//! Helpers for retrying a fallible operation according to a backoff schedule.
//!
//! [`Runner::run`] retries the specified operation according to the [`BackoffSchedule`] of the
//! [`Runner`]. Users can customize the backoff behavior by implementing [`BackoffSchedule`].
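//!
//! For illustration, a caller might drive [`Runner::run`] with a schedule like the one sketched
//! below. `QuickRetries`, `do_flaky_thing`, and `runtime` are placeholders, not items defined in
//! this crate:
//!
//! ```ignore
//! /// A hypothetical schedule: at most 3 retries, 100 ms apart, giving up after 10 seconds.
//! struct QuickRetries;
//!
//! impl BackoffSchedule for QuickRetries {
//!     fn max_retries(&self) -> Option<usize> {
//!         Some(3)
//!     }
//!     fn overall_timeout(&self) -> Option<Duration> {
//!         Some(Duration::from_secs(10))
//!     }
//!     fn single_attempt_timeout(&self) -> Option<Duration> {
//!         None
//!     }
//!     fn next_delay<E: RetriableError>(&mut self, _error: &E) -> Option<Duration> {
//!         Some(Duration::from_millis(100))
//!     }
//! }
//!
//! let runner = Runner::new("do the flaky thing".to_string(), QuickRetries, runtime.clone());
//! let outcome = runner.run(|| async { do_flaky_thing().await }).await;
//! ```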

// TODO: this is a (somewhat) general-purpose utility, so it should probably be factored out of
// tor-hsservice

use std::pin::Pin;

use futures::future::FusedFuture;

use tor_rtcompat::TimeoutError;

use super::*;

/// A runner for a fallible operation, which retries on failure according to a [`BackoffSchedule`].
pub(super) struct Runner<B: BackoffSchedule, R: Runtime> {
    /// A description of the operation we are trying to do.
    doing: String,
    /// The backoff schedule.
    schedule: B,
    /// The runtime.
    runtime: R,
}

impl<B: BackoffSchedule, R: Runtime> Runner<B, R> {
    /// Create a new `Runner`.
    pub(super) fn new(doing: String, schedule: B, runtime: R) -> Self {
        Self {
            doing,
            schedule,
            runtime,
        }
    }

    /// Run `fallible_fn`, retrying according to the [`BackoffSchedule`] of this `Runner`.
    ///
    /// If `fallible_fn` eventually returns `Ok(_)`, return that output. Otherwise,
    /// keep retrying until either `fallible_fn` has failed too many times, or until
    /// a fatal error occurs.
    #[allow(clippy::cognitive_complexity)] // TODO: Refactor
    pub(super) async fn run<T, E, F>(
        mut self,
        mut fallible_fn: impl FnMut() -> F,
    ) -> Result<T, BackoffError<E>>
    where
        E: RetriableError,
        F: Future<Output = Result<T, E>> + Send,
    {
        let mut retry_count = 0;
        let mut errors = RetryError::in_attempt_to(self.doing.clone());

        // When this timeout elapses, the `Runner` will stop retrying the fallible operation.
        //
        // An `overall_timeout` of `None` means there is no time limit for the retries.
        let mut overall_timeout = match self.schedule.overall_timeout() {
            Some(timeout) => Either::Left(Box::pin(self.runtime.sleep(timeout))),
            None => Either::Right(future::pending()),
        }
        .fuse();

        loop {
            // Bail if we've exceeded the number of allowed retries.
            if matches!(self.schedule.max_retries(), Some(max_retry_count) if retry_count >= max_retry_count)
            {
                return Err(BackoffError::MaxRetryCountExceeded(errors));
            }

            let mut fallible_op = optionally_timeout(
                &self.runtime,
                fallible_fn(),
                self.schedule.single_attempt_timeout(),
            );

            trace!(attempt = (retry_count + 1), "{}", self.doing);

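            // Note: `select_biased!` polls its branches in the order they are listed,
            // so the overall timeout is always checked before the current attempt's outcome.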
            select_biased! {
                () = overall_timeout => {
                    // The timeout has elapsed, so stop retrying and return the errors
                    // accumulated so far.
                    return Err(BackoffError::Timeout(errors))
                }
                res = fallible_op => {
                    // TODO: the error branches in the match below have different error types,
                    // so we must compute should_retry and delay separately, on each branch.
                    //
                    // We could refactor this to extract the error using
                    // let err = match res { ... } and call err.should_retry()
                    // and next_delay() after the match, but this will involve
                    // rethinking the BackoffSchedule trait and/or RetriableError
                    // (currently RetriableError is Clone, so it's not object safe).
                    let (should_retry, delay) = match res {
                        Ok(Ok(res)) => return Ok(res),
                        Ok(Err(e)) => {
                            // The operation failed: check if we can retry it.
                            let should_retry = e.should_retry();

                            debug!(
                                attempt=(retry_count + 1), can_retry=should_retry,
                                "failed to {}: {e}", self.doing
                            );

                            errors.push(e.clone());
                            (e.should_retry(), self.schedule.next_delay(&e))
                        }
                        Err(e) => {
                            trace!("fallible operation timed out; retrying");
                            (e.should_retry(), self.schedule.next_delay(&e))
                        },
                    };

                    if should_retry {
                        retry_count += 1;

                        let Some(delay) = delay else {
                            return Err(BackoffError::ExplicitStop(errors));
                        };

                        // Introduce the specified delay between retries
                        let () = self.runtime.sleep(delay).await;

                        // Try again unless the entire operation has timed out.
                        continue;
                    }

                    return Err(BackoffError::FatalError(errors));
                },
            }
        }
    }
}

/// Wrap a [`Future`] with an optional timeout.
///
/// If `timeout` is `Some`, returns a [`Timeout`](tor_rtcompat::Timeout)
/// that resolves to the value of `future` if the future completes within `timeout`,
/// or a [`TimeoutError`] if it does not.
/// If `timeout` is `None`, returns a new future which maps the specified `future`'s
/// output type to a `Result::Ok`.
fn optionally_timeout<'f, R, F>(
    runtime: &R,
    future: F,
    timeout: Option<Duration>,
) -> Pin<Box<dyn FusedFuture<Output = Result<F::Output, TimeoutError>> + Send + 'f>>
where
    R: Runtime,
    F: Future + Send + 'f,
{
    match timeout {
        Some(timeout) => Box::pin(runtime.timeout(timeout, future).fuse()),
        None => Box::pin(future.map(Ok)),
    }
}

/// A trait that specifies the parameters for retrying a fallible operation.
pub(super) trait BackoffSchedule {
    /// The maximum number of retries.
    ///
    /// A return value of `None` indicates there is no upper limit for the number of retries, and
    /// that the operation should be retried until [`BackoffSchedule::overall_timeout`] time
    /// elapses (or indefinitely, if [`BackoffSchedule::overall_timeout`] returns `None`).
    fn max_retries(&self) -> Option<usize>;

    /// The total amount of time allowed for the retriable operation.
    ///
    /// A return value of `None` indicates the operation should be retried until the number of
    /// retries exceeds [`BackoffSchedule::max_retries`] (or indefinitely, if
    /// [`BackoffSchedule::max_retries`] returns `None`).
    fn overall_timeout(&self) -> Option<Duration>;

    /// The total amount of time allowed for a single operation.
    fn single_attempt_timeout(&self) -> Option<Duration>;

    /// Return the delay to introduce before the next retry.
    ///
    /// The `error` parameter contains the error returned by the fallible operation. This enables
    /// implementors to (optionally) implement adaptive backoff. For example, if the operation is
    /// sending an HTTP request, and the error is a 429 (Too Many Requests) HTTP response with a
    /// `Retry-After` header, the implementor can implement a backoff schedule where the next retry
    /// is delayed by the value specified in the `Retry-After` header.
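    ///
    /// For illustration only, a simple exponential schedule (not one defined in this module)
    /// might implement this method by ignoring `error` and doubling a stored delay; the
    /// `delay` and `max_delay` fields here are hypothetical `Duration` fields on the implementor:
    ///
    /// ```ignore
    /// fn next_delay<E: RetriableError>(&mut self, _error: &E) -> Option<Duration> {
    ///     let delay = self.delay;
    ///     // Double the delay for the next attempt, capping it at `max_delay`.
    ///     self.delay = (self.delay * 2).min(self.max_delay);
    ///     Some(delay)
    /// }
    /// ```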
    fn next_delay<E: RetriableError>(&mut self, error: &E) -> Option<Duration>;
}

/// The type of error encountered while running a fallible operation.
#[derive(Clone, Debug, thiserror::Error)]
pub(super) enum BackoffError<E> {
    /// A fatal (non-transient) error occurred.
    #[error("A fatal (non-transient) error occurred")]
    FatalError(RetryError<E>),

    /// Ran out of retries.
    #[error("Ran out of retries")]
    MaxRetryCountExceeded(RetryError<E>),

    /// Exceeded the maximum allowed time.
    #[error("Timeout exceeded")]
    Timeout(RetryError<E>),

    /// The [`BackoffSchedule`] told us to stop retrying.
    #[error("Stopped retrying as requested by BackoffSchedule")]
    ExplicitStop(RetryError<E>),
}

impl<E> From<BackoffError<E>> for RetryError<E> {
    fn from(e: BackoffError<E>) -> Self {
        match e {
            BackoffError::FatalError(e)
            | BackoffError::MaxRetryCountExceeded(e)
            | BackoffError::Timeout(e)
            | BackoffError::ExplicitStop(e) => e,
        }
    }
}

/// A trait for representing retriable errors.
pub(super) trait RetriableError: StdError + Clone {
    /// Whether this error is transient.
    fn should_retry(&self) -> bool;
}

impl RetriableError for TimeoutError {
    fn should_retry(&self) -> bool {
        true
    }
}

#[cfg(test)]
mod tests {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use std::sync::Arc;

    use std::iter;
    use std::sync::RwLock;

    use oneshot_fused_workaround as oneshot;
    use tor_rtcompat::{SleepProvider, ToplevelBlockOn};
    use tor_rtmock::MockRuntime;

    const SHORT_DELAY: Duration = Duration::from_millis(10);
    const TIMEOUT: Duration = Duration::from_millis(100);
    const SINGLE_TIMEOUT: Duration = Duration::from_millis(50);
    const MAX_RETRIES: usize = 5;

    macro_rules! impl_backoff_sched {
        ($name:ty, $max_retries:expr, $timeout:expr, $single_timeout:expr, $next_delay:expr) => {
            impl BackoffSchedule for $name {
                fn max_retries(&self) -> Option<usize> {
                    $max_retries
                }

                fn overall_timeout(&self) -> Option<Duration> {
                    $timeout
                }

                fn single_attempt_timeout(&self) -> Option<Duration> {
                    $single_timeout
                }

                #[allow(unused_variables)]
                fn next_delay<E: RetriableError>(&mut self, error: &E) -> Option<Duration> {
                    $next_delay
                }
            }
        };
    }

    struct BackoffWithMaxRetries;

    impl_backoff_sched!(
        BackoffWithMaxRetries,
        Some(MAX_RETRIES),
        None,
        None,
        Some(SHORT_DELAY)
    );

    struct BackoffWithTimeout;

    impl_backoff_sched!(
        BackoffWithTimeout,
        None,
        Some(TIMEOUT),
        None,
        Some(SHORT_DELAY)
    );

    struct BackoffWithSingleTimeout;

    impl_backoff_sched!(
        BackoffWithSingleTimeout,
        Some(MAX_RETRIES),
        None,
        Some(SINGLE_TIMEOUT),
        Some(SHORT_DELAY)
    );

    /// A potentially retriable error.
    #[derive(Debug, Copy, Clone, thiserror::Error)]
    enum TestError {
        /// A fatal error
        #[error("A fatal test error")]
        Fatal,
        /// A transient error
        #[error("A transient test error")]
        Transient,
    }

    impl RetriableError for TestError {
        fn should_retry(&self) -> bool {
            match self {
                Self::Fatal => false,
                Self::Transient => true,
            }
        }
    }

    /// Run a single [`Runner`] test.
    fn run_test<E: RetriableError + Send + Sync + 'static>(
        sleep_for: Option<Duration>,
        schedule: impl BackoffSchedule + Send + 'static,
        errors: impl Iterator<Item = E> + Send + Sync + 'static,
        expected_run_count: usize,
        description: &'static str,
        expected_duration: Duration,
    ) {
        let runtime = MockRuntime::new();

        runtime.clone().block_on(async move {
            let runner = Runner {
                doing: description.into(),
                schedule,
                runtime: runtime.clone(),
            };

            let retry_count = Arc::new(RwLock::new(0));
            let (tx, rx) = oneshot::channel();

            let start = runtime.now();
            runtime
                .mock_task()
                .spawn_identified(format!("retry runner task: {description}"), {
                    let retry_count = Arc::clone(&retry_count);
                    let errors = Arc::new(RwLock::new(errors));
                    let runtime = runtime.clone();
                    async move {
                        if let Ok(()) = runner
                            .run(|| async {
                                *retry_count.write().unwrap() += 1;

                                if let Some(dur) = sleep_for {
                                    runtime.sleep(dur).await;
                                }

                                Err::<(), _>(errors.write().unwrap().next().unwrap())
                            })
                            .await
                        {
                            unreachable!();
                        }

                        let () = tx.send(()).unwrap();
                    }
                });

            // The expected retry count may be unknown (for example, if we set a timeout but no
            // upper limit for the number of retries, it's impossible to tell exactly how many
            // times the operation will be retried)
            for i in 1..=expected_run_count {
                runtime.mock_task().progress_until_stalled().await;
                // If our fallible_op is sleeping, advance the time until after it times out or
                // finishes sleeping.
                if let Some(sleep_for) = sleep_for {
                    runtime
                        .mock_sleep()
                        .advance(std::cmp::min(SINGLE_TIMEOUT, sleep_for));
                }
                runtime.mock_task().progress_until_stalled().await;
                runtime.mock_sleep().advance(SHORT_DELAY);
                assert_eq!(*retry_count.read().unwrap(), i);
            }

            let () = rx.await.unwrap();
            let end = runtime.now();

            assert_eq!(*retry_count.read().unwrap(), expected_run_count);
            assert!(duration_close_to(end - start, expected_duration));
        });
    }

    /// Return true if d1 is in range [d2...d2 + 0.01sec]
    ///
    /// TODO: lifted from tor-circmgr
    fn duration_close_to(d1: Duration, d2: Duration) -> bool {
        d1 >= d2 && d1 <= d2 + SHORT_DELAY
    }

    #[test]
    fn max_retries() {
        run_test(
            None,
            BackoffWithMaxRetries,
            iter::repeat(TestError::Transient),
            MAX_RETRIES,
            "backoff with max_retries and no timeout (transient errors)",
            Duration::from_millis(SHORT_DELAY.as_millis() as u64 * MAX_RETRIES as u64),
        );
    }

    #[test]
    fn max_retries_fatal() {
        use TestError::*;

        /// The number of transient errors that happen before the final, fatal error.
        const RETRIES_UNTIL_FATAL: usize = 3;
        /// The total number of times we expect the fallible function to be called.
        /// The first RETRIES_UNTIL_FATAL times, a transient error is returned.
        /// The last call corresponds to the fatal error.
        const EXPECTED_TOTAL_RUNS: usize = RETRIES_UNTIL_FATAL + 1;

        run_test(
            None,
            BackoffWithMaxRetries,
            std::iter::repeat_n(Transient, RETRIES_UNTIL_FATAL)
                .chain([Fatal])
                .chain(iter::repeat(Transient)),
            EXPECTED_TOTAL_RUNS,
            "backoff with max_retries and no timeout (transient errors followed by a fatal error)",
            Duration::from_millis(SHORT_DELAY.as_millis() as u64 * EXPECTED_TOTAL_RUNS as u64),
        );
    }

    #[test]
    fn timeout() {
        use TestError::*;

        let expected_run_count = TIMEOUT.as_millis() / SHORT_DELAY.as_millis();

        run_test(
            None,
            BackoffWithTimeout,
            iter::repeat(Transient),
            expected_run_count as usize,
            "backoff with timeout and no max_retries (transient errors)",
            TIMEOUT,
        );
    }

    #[test]
    fn single_timeout() {
        use TestError::*;

        // Each attempt will time out after SINGLE_TIMEOUT time units,
        // and the backoff runner sleeps for SHORT_DELAY units in between retries.
        let expected_duration = Duration::from_millis(
            (SHORT_DELAY.as_millis() + SINGLE_TIMEOUT.as_millis()) as u64 * MAX_RETRIES as u64,
        );

        run_test(
            // Sleep for more than SINGLE_TIMEOUT units
            // to trigger the single_attempt_timeout() timeout
            Some(SINGLE_TIMEOUT * 2),
            BackoffWithSingleTimeout,
            iter::repeat(Transient),
            MAX_RETRIES,
            "backoff with single timeout and max_retries and no overall timeout",
            expected_duration,
        );
    }

    // TODO (#1120): needs tests for the remaining corner cases
}