//! Helpers for retrying a fallible operation according to a backoff schedule.
//!
//! [`Runner::run`] retries the specified operation according to the [`BackoffSchedule`] of the
//! [`Runner`]. Users can customize the backoff behavior by implementing [`BackoffSchedule`].
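//!
//! A rough usage sketch (not compiled; `MySchedule`, `fetch_thing`, and `runtime` are
//! illustrative placeholders rather than items defined in this module):
//!
//! ```ignore
//! // `MySchedule` implements `BackoffSchedule`; `fetch_thing` is an async fallible operation
//! // whose error type implements `RetriableError`.
//! let runner = Runner::new("fetch the thing".to_string(), MySchedule, runtime);
//! match runner.run(|| fetch_thing()).await {
//!     Ok(thing) => debug!("fetched: {thing:?}"),
//!     Err(e) => warn!("giving up: {e}"),
//! }
//! ```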
// TODO: this is a (somewhat) general-purpose utility, so it should probably be factored out of
// tor-hsservice

use std::pin::Pin;

use futures::future::FusedFuture;

use tor_rtcompat::TimeoutError;

use super::*;

/// A runner for a fallible operation, which retries on failure according to a [`BackoffSchedule`].
pub(super) struct Runner<B: BackoffSchedule, R: Runtime> {
    /// A description of the operation we are trying to do.
    doing: String,
    /// The backoff schedule.
    schedule: B,
    /// The runtime.
    runtime: R,
}

impl<B: BackoffSchedule, R: Runtime> Runner<B, R> {
    /// Create a new `Runner`.
    pub(super) fn new(doing: String, schedule: B, runtime: R) -> Self {
        Self {
            doing,
            schedule,
            runtime,
        }
    }

    /// Run `fallible_fn`, retrying according to the [`BackoffSchedule`] of this `Runner`.
    ///
    /// If `fallible_fn` eventually returns `Ok(_)`, return that output. Otherwise,
    /// keep retrying until either `fallible_fn` has failed too many times, or until
    /// a fatal error occurs.
    pub(super) async fn run<T, E, F>(
        mut self,
        mut fallible_fn: impl FnMut() -> F,
    ) -> Result<T, BackoffError<E>>
    where
        E: RetriableError,
        F: Future<Output = Result<T, E>> + Send,
    {
        let mut retry_count = 0;
        let mut errors = RetryError::in_attempt_to(self.doing.clone());

        // When this timeout elapses, the `Runner` will stop retrying the fallible operation.
        //
        // An `overall_timeout` of `None` means there is no time limit for the retries.
        let mut overall_timeout = match self.schedule.overall_timeout() {
            Some(timeout) => Either::Left(Box::pin(self.runtime.sleep(timeout))),
            None => Either::Right(future::pending()),
        }
        .fuse();

        loop {
            // Bail if we've exceeded the number of allowed retries.
            if matches!(self.schedule.max_retries(), Some(max_retry_count) if retry_count >= max_retry_count)
            {
                return Err(BackoffError::MaxRetryCountExceeded(errors));
            }

            let mut fallible_op = optionally_timeout(
                &self.runtime,
                fallible_fn(),
                self.schedule.single_attempt_timeout(),
            );

            trace!(attempt = (retry_count + 1), "{}", self.doing);

            select_biased! {
                () = overall_timeout => {
                    // The timeout has elapsed, so stop retrying and return the errors
                    // accumulated so far.
                    return Err(BackoffError::Timeout(errors))
                }
                res = fallible_op => {
                    // TODO: the error branches in the match below have different error types,
                    // so we must compute should_retry and delay separately, on each branch.
                    //
                    // We could refactor this to extract the error using
                    // let err = match res { ... } and call err.should_retry()
                    // and next_delay() after the match, but this will involve
                    // rethinking the BackoffSchedule trait and/or RetriableError
                    // (currently RetriableError is Clone, so it's not object safe).
                    let (should_retry, delay) = match res {
                        Ok(Ok(res)) => return Ok(res),
                        Ok(Err(e)) => {
                            // The operation failed: check if we can retry it.
                            let should_retry = e.should_retry();

                            debug!(
                                attempt=(retry_count + 1), can_retry=should_retry,
                                "failed to {}: {e}", self.doing
                            );

                            errors.push(e.clone());
                            (e.should_retry(), self.schedule.next_delay(&e))
                        }
                        Err(e) => {
                            trace!("fallible operation timed out; retrying");
                            (e.should_retry(), self.schedule.next_delay(&e))
                        },
                    };

                    if should_retry {
                        retry_count += 1;

                        let Some(delay) = delay else {
                            return Err(BackoffError::ExplicitStop(errors));
                        };

                        // Introduce the specified delay between retries
                        let () = self.runtime.sleep(delay).await;

                        // Try again unless the entire operation has timed out.
                        continue;
                    }

                    return Err(BackoffError::FatalError(errors));
                },
            }
        }
    }
}

/// Wrap a [`Future`] with an optional timeout.
///
/// If `timeout` is `Some`, wraps `future` in a [`Timeout`](tor_rtcompat::Timeout)
/// that resolves to the future's output if it completes within `timeout`,
/// or to a [`TimeoutError`] if it does not.
/// If `timeout` is `None`, returns a new future which maps the specified `future`'s
/// output type to a `Result::Ok`.
fn optionally_timeout<'f, R, F>(
    runtime: &R,
    future: F,
    timeout: Option<Duration>,
) -> Pin<Box<dyn FusedFuture<Output = Result<F::Output, TimeoutError>> + Send + 'f>>
where
    R: Runtime,
    F: Future + Send + 'f,
{
    match timeout {
        Some(timeout) => Box::pin(runtime.timeout(timeout, future).fuse()),
        None => Box::pin(future.map(Ok)),
    }
}

/// A trait that specifies the parameters for retrying a fallible operation.
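///
/// For illustration only, a minimal schedule might retry a handful of times with a fixed
/// delay between attempts (the type and numbers below are hypothetical, not something this
/// module defines):
///
/// ```ignore
/// struct FixedDelay;
///
/// impl BackoffSchedule for FixedDelay {
///     fn max_retries(&self) -> Option<usize> {
///         Some(3)
///     }
///
///     fn overall_timeout(&self) -> Option<Duration> {
///         None
///     }
///
///     fn single_attempt_timeout(&self) -> Option<Duration> {
///         Some(Duration::from_secs(10))
///     }
///
///     fn next_delay<E: RetriableError>(&mut self, _error: &E) -> Option<Duration> {
///         Some(Duration::from_secs(1))
///     }
/// }
/// ```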
pub(super) trait BackoffSchedule {
    /// The maximum number of retries.
    ///
    /// A return value of `None` indicates there is no upper limit for the number of retries,
    /// and that the operation should be retried until [`BackoffSchedule::overall_timeout`] time
    /// elapses (or indefinitely, if [`BackoffSchedule::overall_timeout`] returns `None`).
    fn max_retries(&self) -> Option<usize>;

    /// The total amount of time allowed for the retriable operation.
    ///
    /// A return value of `None` indicates the operation should be retried until
    /// the [`BackoffSchedule::max_retries`] limit is reached (or indefinitely, if
    /// [`BackoffSchedule::max_retries`] returns `None`).
    fn overall_timeout(&self) -> Option<Duration>;

    /// The total amount of time allowed for a single attempt of the operation.
    fn single_attempt_timeout(&self) -> Option<Duration>;

    /// Return the delay to introduce before the next retry.
    ///
    /// The `error` parameter contains the error returned by the fallible operation. This enables
    /// implementors to (optionally) implement adaptive backoff. For example, if the operation is
    /// sending an HTTP request, and the error is a 429 (Too Many Requests) response with a
    /// `Retry-After` header, the implementation can delay the next retry by the value specified
    /// in the `Retry-After` header.
    fn next_delay<E: RetriableError>(&mut self, error: &E) -> Option<Duration>;
}

/// The type of error encountered while running a fallible operation.
#[derive(Clone, Debug, thiserror::Error)]
pub(super) enum BackoffError<E> {
    /// A fatal (non-transient) error occurred.
    #[error("A fatal (non-transient) error occurred")]
    FatalError(RetryError<E>),

    /// Ran out of retries.
    #[error("Ran out of retries")]
    MaxRetryCountExceeded(RetryError<E>),

    /// Exceeded the maximum allowed time.
    #[error("Timeout exceeded")]
    Timeout(RetryError<E>),

    /// The [`BackoffSchedule`] told us to stop retrying.
    #[error("Stopped retrying as requested by BackoffSchedule")]
    ExplicitStop(RetryError<E>),
}

impl<E> From<BackoffError<E>> for RetryError<E> {
    fn from(e: BackoffError<E>) -> Self {
        match e {
            BackoffError::FatalError(e)
            | BackoffError::MaxRetryCountExceeded(e)
            | BackoffError::Timeout(e)
            | BackoffError::ExplicitStop(e) => e,
        }
    }
}

/// A trait for representing retriable errors.
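///
/// A minimal sketch of an implementation (the error type here is hypothetical, purely for
/// illustration):
///
/// ```ignore
/// #[derive(Clone, Debug, thiserror::Error)]
/// enum UploadError {
///     #[error("connection reset")]
///     ConnectionReset, // transient: worth retrying
///     #[error("invalid credentials")]
///     BadCredentials, // retrying won't help
/// }
///
/// impl RetriableError for UploadError {
///     fn should_retry(&self) -> bool {
///         matches!(self, UploadError::ConnectionReset)
///     }
/// }
/// ```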
pub(super) trait RetriableError: StdError + Clone {
    /// Whether this error is transient.
    fn should_retry(&self) -> bool;
}

impl RetriableError for TimeoutError {
    fn should_retry(&self) -> bool {
        true
    }
}

#[cfg(test)]
mod tests {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use std::sync::Arc;

    use std::iter;
    use std::sync::RwLock;

    use oneshot_fused_workaround as oneshot;
    use tor_rtcompat::{BlockOn, SleepProvider};
    use tor_rtmock::MockRuntime;

    const SHORT_DELAY: Duration = Duration::from_millis(10);
    const TIMEOUT: Duration = Duration::from_millis(100);
    const SINGLE_TIMEOUT: Duration = Duration::from_millis(50);
    const MAX_RETRIES: usize = 5;

    macro_rules! impl_backoff_sched {
        ($name:ty, $max_retries:expr, $timeout:expr, $single_timeout:expr, $next_delay:expr) => {
            impl BackoffSchedule for $name {
                fn max_retries(&self) -> Option<usize> {
                    $max_retries
                }

                fn overall_timeout(&self) -> Option<Duration> {
                    $timeout
                }

                fn single_attempt_timeout(&self) -> Option<Duration> {
                    $single_timeout
                }

                #[allow(unused_variables)]
                fn next_delay<E: RetriableError>(&mut self, error: &E) -> Option<Duration> {
                    $next_delay
                }
            }
        };
    }

    struct BackoffWithMaxRetries;

    impl_backoff_sched!(
        BackoffWithMaxRetries,
        Some(MAX_RETRIES),
        None,
        None,
        Some(SHORT_DELAY)
    );

    struct BackoffWithTimeout;

    impl_backoff_sched!(
        BackoffWithTimeout,
        None,
        Some(TIMEOUT),
        None,
        Some(SHORT_DELAY)
    );

    struct BackoffWithSingleTimeout;

    impl_backoff_sched!(
        BackoffWithSingleTimeout,
        Some(MAX_RETRIES),
        None,
        Some(SINGLE_TIMEOUT),
        Some(SHORT_DELAY)
    );

    /// A potentially retriable error.
    #[derive(Debug, Copy, Clone, thiserror::Error)]
    enum TestError {
        /// A fatal error
        #[error("A fatal test error")]
        Fatal,
        /// A transient error
        #[error("A transient test error")]
        Transient,
    }

    impl RetriableError for TestError {
        fn should_retry(&self) -> bool {
            match self {
                Self::Fatal => false,
                Self::Transient => true,
            }
        }
    }

    /// Run a single [`Runner`] test.
    fn run_test<E: RetriableError + Send + Sync + 'static>(
        sleep_for: Option<Duration>,
        schedule: impl BackoffSchedule + Send + 'static,
        errors: impl Iterator<Item = E> + Send + Sync + 'static,
        expected_run_count: usize,
        description: &'static str,
        expected_duration: Duration,
    ) {
        let runtime = MockRuntime::new();

        runtime.clone().block_on(async move {
            let runner = Runner {
                doing: description.into(),
                schedule,
                runtime: runtime.clone(),
            };

            let retry_count = Arc::new(RwLock::new(0));
            let (tx, rx) = oneshot::channel();

            let start = runtime.now();
            runtime
                .mock_task()
                .spawn_identified(format!("retry runner task: {description}"), {
                    let retry_count = Arc::clone(&retry_count);
                    let errors = Arc::new(RwLock::new(errors));
                    let runtime = runtime.clone();
                    async move {
                        if let Ok(()) = runner
                            .run(|| async {
                                *retry_count.write().unwrap() += 1;

                                if let Some(dur) = sleep_for {
                                    runtime.sleep(dur).await;
                                }

                                Err::<(), _>(errors.write().unwrap().next().unwrap())
                            })
                            .await
                        {
                            unreachable!();
                        }

                        let () = tx.send(()).unwrap();
                    }
                });

            // The expected retry count may be unknown (for example, if we set a timeout but no
            // upper limit for the number of retries, it's impossible to tell exactly how many
            // times the operation will be retried)
            for i in 1..=expected_run_count {
                runtime.mock_task().progress_until_stalled().await;
                // If our fallible_op is sleeping, advance the time until after it times out or
                // finishes sleeping.
                if let Some(sleep_for) = sleep_for {
                    runtime
                        .mock_sleep()
                        .advance(std::cmp::min(SINGLE_TIMEOUT, sleep_for));
                }
                runtime.mock_task().progress_until_stalled().await;
                runtime.mock_sleep().advance(SHORT_DELAY);
                assert_eq!(*retry_count.read().unwrap(), i);
            }

            let () = rx.await.unwrap();
            let end = runtime.now();

            assert_eq!(*retry_count.read().unwrap(), expected_run_count);
            assert!(duration_close_to(end - start, expected_duration));
        });
    }

    /// Return true if `d1` is in the range `[d2, d2 + SHORT_DELAY]`.
    ///
    /// TODO: lifted from tor-circmgr
    fn duration_close_to(d1: Duration, d2: Duration) -> bool {
        d1 >= d2 && d1 <= d2 + SHORT_DELAY
    }

    #[test]
    fn max_retries() {
        run_test(
            None,
            BackoffWithMaxRetries,
            iter::repeat(TestError::Transient),
            MAX_RETRIES,
            "backoff with max_retries and no timeout (transient errors)",
            Duration::from_millis(SHORT_DELAY.as_millis() as u64 * MAX_RETRIES as u64),
        );
    }

    #[test]
    fn max_retries_fatal() {
        use TestError::*;

        /// The number of transient errors that happen before the final, fatal error.
        const RETRIES_UNTIL_FATAL: usize = 3;
        /// The total number of times we expect the fallible function to be called.
        /// The first RETRIES_UNTIL_FATAL times, a transient error is returned.
        /// The last call corresponds to the fatal error.
        const EXPECTED_TOTAL_RUNS: usize = RETRIES_UNTIL_FATAL + 1;

        run_test(
            None,
            BackoffWithMaxRetries,
            iter::repeat(Transient)
                .take(RETRIES_UNTIL_FATAL)
                .chain([Fatal])
                .chain(iter::repeat(Transient)),
            EXPECTED_TOTAL_RUNS,
            "backoff with max_retries and no timeout (transient errors followed by a fatal error)",
            Duration::from_millis(SHORT_DELAY.as_millis() as u64 * EXPECTED_TOTAL_RUNS as u64),
        );
    }

    #[test]
    fn timeout() {
        use TestError::*;

        let expected_run_count = TIMEOUT.as_millis() / SHORT_DELAY.as_millis();

        run_test(
            None,
            BackoffWithTimeout,
            iter::repeat(Transient),
            expected_run_count as usize,
            "backoff with timeout and no max_retries (transient errors)",
            TIMEOUT,
        );
    }

    #[test]
    fn single_timeout() {
        use TestError::*;

        // Each attempt will time out after SINGLE_TIMEOUT time units,
        // and the backoff runner sleeps for SHORT_DELAY units in between retries.
        let expected_duration = Duration::from_millis(
            (SHORT_DELAY.as_millis() + SINGLE_TIMEOUT.as_millis()) as u64 * MAX_RETRIES as u64,
        );

        run_test(
            // Sleep for more than SINGLE_TIMEOUT units
            // to trigger the single_attempt_timeout() timeout
            Some(SINGLE_TIMEOUT * 2),
            BackoffWithSingleTimeout,
            iter::repeat(Transient),
            MAX_RETRIES,
            "backoff with single timeout and max_retries and no overall timeout",
            expected_duration,
        );
    }

    // TODO (#1120): needs tests for the remaining corner cases
}