tor_hsservice/publish/
backoff.rs

//! Helpers for retrying a fallible operation according to a backoff schedule.
//!
//! [`Runner::run`] retries the specified operation according to the [`BackoffSchedule`] of the
//! [`Runner`]. Users can customize the backoff behavior by implementing [`BackoffSchedule`].
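//!
//! For example, a minimal usage sketch (the `ExampleSchedule` type and the `upload_descriptor`
//! function below are hypothetical, shown only to illustrate the API):
//!
//! ```ignore
//! // A hypothetical schedule: retry up to 3 times, waiting 5 seconds between attempts.
//! let schedule = ExampleSchedule::new(3, Duration::from_secs(5));
//! let runner = Runner::new("upload descriptor".into(), schedule, runtime.clone());
//!
//! // `run` keeps calling the closure until it returns `Ok(_)`, a fatal error occurs,
//! // or the schedule tells us to stop retrying.
//! let outcome = runner.run(|| upload_descriptor()).await;
//! ```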

// TODO: this is a (somewhat) general-purpose utility, so it should probably be factored out of
// tor-hsservice

use std::pin::Pin;

use futures::future::FusedFuture;

use tor_rtcompat::TimeoutError;

use super::*;

/// A runner for a fallible operation, which retries on failure according to a [`BackoffSchedule`].
pub(super) struct Runner<B: BackoffSchedule, R: Runtime> {
    /// A description of the operation we are trying to do.
    doing: String,
    /// The backoff schedule.
    schedule: B,
    /// The runtime.
    runtime: R,
}

impl<B: BackoffSchedule, R: Runtime> Runner<B, R> {
    /// Create a new `Runner`.
    pub(super) fn new(doing: String, schedule: B, runtime: R) -> Self {
        Self {
            doing,
            schedule,
            runtime,
        }
    }

    /// Run `fallible_fn`, retrying according to the [`BackoffSchedule`] of this `Runner`.
    ///
    /// If `fallible_fn` eventually returns `Ok(_)`, return that output. Otherwise,
    /// keep retrying until either `fallible_fn` has failed too many times, or until
    /// a fatal error occurs.
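    ///
    /// For example (a sketch; `fetch_thing` is a hypothetical fallible async function):
    ///
    /// ```ignore
    /// match runner.run(|| fetch_thing()).await {
    ///     Ok(thing) => { /* the operation eventually succeeded */ }
    ///     Err(BackoffError::FatalError(_)) => { /* a non-retriable error occurred */ }
    ///     Err(BackoffError::MaxRetryCountExceeded(_))
    ///     | Err(BackoffError::Timeout(_))
    ///     | Err(BackoffError::ExplicitStop(_)) => { /* we gave up retrying */ }
    /// }
    /// ```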
    pub(super) async fn run<T, E, F>(
        mut self,
        mut fallible_fn: impl FnMut() -> F,
    ) -> Result<T, BackoffError<E>>
    where
        E: RetriableError,
        F: Future<Output = Result<T, E>> + Send,
    {
        let mut retry_count = 0;
        let mut errors = RetryError::in_attempt_to(self.doing.clone());

        // When this timeout elapses, the `Runner` will stop retrying the fallible operation.
        //
        // An `overall_timeout` of `None` means there is no time limit for the retries.
        let mut overall_timeout = match self.schedule.overall_timeout() {
            Some(timeout) => Either::Left(Box::pin(self.runtime.sleep(timeout))),
            None => Either::Right(future::pending()),
        }
        .fuse();

        loop {
            // Bail if we've exceeded the number of allowed retries.
            if matches!(self.schedule.max_retries(), Some(max_retry_count) if retry_count >= max_retry_count)
            {
                return Err(BackoffError::MaxRetryCountExceeded(errors));
            }

            let mut fallible_op = optionally_timeout(
                &self.runtime,
                fallible_fn(),
                self.schedule.single_attempt_timeout(),
            );

            trace!(attempt = (retry_count + 1), "{}", self.doing);

            select_biased! {
                () = overall_timeout => {
                    // The timeout has elapsed, so stop retrying and return the errors
                    // accumulated so far.
                    return Err(BackoffError::Timeout(errors))
                }
                res = fallible_op => {
                    // TODO: the error branches in the match below have different error types,
                    // so we must compute should_retry and delay separately, on each branch.
                    //
                    // We could refactor this to extract the error using
                    // let err = match res { ... } and call err.should_retry()
                    // and next_delay() after the match, but this will involve
                    // rethinking the BackoffSchedule trait and/or RetriableError
                    // (currently RetriableError is Clone, so it's not object safe).
                    let (should_retry, delay) = match res {
                        Ok(Ok(res)) => return Ok(res),
                        Ok(Err(e)) => {
                            // The operation failed: check if we can retry it.
                            let should_retry = e.should_retry();

                            debug!(
                                attempt=(retry_count + 1), can_retry=should_retry,
                                "failed to {}: {e}", self.doing
                            );

                            errors.push(e.clone());
                            (e.should_retry(), self.schedule.next_delay(&e))
                        }
                        Err(e) => {
                            trace!("fallible operation timed out; retrying");
                            (e.should_retry(), self.schedule.next_delay(&e))
                        },
                    };

                    if should_retry {
                        retry_count += 1;

                        let Some(delay) = delay else {
                            return Err(BackoffError::ExplicitStop(errors));
                        };

                        // Introduce the specified delay between retries
                        let () = self.runtime.sleep(delay).await;

                        // Try again unless the entire operation has timed out.
                        continue;
                    }

                    return Err(BackoffError::FatalError(errors));
                },
            }
        }
    }
}

/// Wrap a [`Future`] with an optional timeout.
///
/// If `timeout` is `Some`, returns a [`Timeout`](tor_rtcompat::Timeout)
/// that resolves to the value of `future` if the future completes within `timeout`,
/// or a [`TimeoutError`] if it does not.
/// If `timeout` is `None`, returns a new future which maps the specified `future`'s
/// output type to a `Result::Ok`.
fn optionally_timeout<'f, R, F>(
    runtime: &R,
    future: F,
    timeout: Option<Duration>,
) -> Pin<Box<dyn FusedFuture<Output = Result<F::Output, TimeoutError>> + Send + 'f>>
where
    R: Runtime,
    F: Future + Send + 'f,
{
    match timeout {
        Some(timeout) => Box::pin(runtime.timeout(timeout, future).fuse()),
        None => Box::pin(future.map(Ok)),
    }
}

/// A trait that specifies the parameters for retrying a fallible operation.
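///
/// A minimal implementation sketch (the `FixedDelay` type below is hypothetical; it is not
/// defined in this module):
///
/// ```ignore
/// struct FixedDelay {
///     /// The constant delay to introduce between retries.
///     delay: Duration,
/// }
///
/// impl BackoffSchedule for FixedDelay {
///     // Give up after 5 retries...
///     fn max_retries(&self) -> Option<usize> {
///         Some(5)
///     }
///     // ...with no overall deadline...
///     fn overall_timeout(&self) -> Option<Duration> {
///         None
///     }
///     // ...no per-attempt timeout...
///     fn single_attempt_timeout(&self) -> Option<Duration> {
///         None
///     }
///     // ...and the same delay after every failure.
///     fn next_delay<E: RetriableError>(&mut self, _error: &E) -> Option<Duration> {
///         Some(self.delay)
///     }
/// }
/// ```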
pub(super) trait BackoffSchedule {
    /// The maximum number of retries.
    ///
    /// A return value of `None` indicates there is no upper limit for the number of retries, and that
    /// the operation should be retried until [`BackoffSchedule::overall_timeout`] time elapses (or
    /// indefinitely, if [`BackoffSchedule::overall_timeout`] returns `None`).
    fn max_retries(&self) -> Option<usize>;

    /// The total amount of time allowed for the retriable operation.
    ///
    /// A return value of `None` indicates the operation should be retried until
    /// [`BackoffSchedule::max_retries`] is exceeded (or indefinitely, if
    /// [`BackoffSchedule::max_retries`] returns `None`).
    fn overall_timeout(&self) -> Option<Duration>;

    /// The total amount of time allowed for a single attempt of the operation.
    fn single_attempt_timeout(&self) -> Option<Duration>;

    /// Return the delay to introduce before the next retry.
    ///
    /// The `error` parameter contains the error returned by the fallible operation. This enables
    /// implementors to (optionally) implement adaptive backoff. For example, if the operation is
    /// sending an HTTP request, and the error is a 429 (Too Many Requests) HTTP response with a
    /// `Retry-After` header, the implementor can implement a backoff schedule where the next retry
    /// is delayed by the value specified in the `Retry-After` header.
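    ///
    /// As a simple sketch, an exponential backoff (assuming a hypothetical `cur_delay: Duration`
    /// field on the implementor; it is not part of this trait) might look like:
    ///
    /// ```ignore
    /// fn next_delay<E: RetriableError>(&mut self, _error: &E) -> Option<Duration> {
    ///     // Wait for the current delay, and double it (up to a cap) for the next failure.
    ///     let delay = self.cur_delay;
    ///     self.cur_delay = (delay * 2).min(Duration::from_secs(60));
    ///     Some(delay)
    /// }
    /// ```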
    fn next_delay<E: RetriableError>(&mut self, error: &E) -> Option<Duration>;
}

/// The type of error encountered while running a fallible operation.
#[derive(Clone, Debug, thiserror::Error)]
pub(super) enum BackoffError<E> {
    /// A fatal (non-transient) error occurred.
    #[error("A fatal (non-transient) error occurred")]
    FatalError(RetryError<E>),

    /// Ran out of retries.
    #[error("Ran out of retries")]
    MaxRetryCountExceeded(RetryError<E>),

    /// Exceeded the maximum allowed time.
    #[error("Timeout exceeded")]
    Timeout(RetryError<E>),

    /// The [`BackoffSchedule`] told us to stop retrying.
    #[error("Stopped retrying as requested by BackoffSchedule")]
    ExplicitStop(RetryError<E>),
}

impl<E> From<BackoffError<E>> for RetryError<E> {
    fn from(e: BackoffError<E>) -> Self {
        match e {
            BackoffError::FatalError(e)
            | BackoffError::MaxRetryCountExceeded(e)
            | BackoffError::Timeout(e)
            | BackoffError::ExplicitStop(e) => e,
        }
    }
}

/// A trait for representing retriable errors.
pub(super) trait RetriableError: StdError + Clone {
    /// Whether this error is transient.
    fn should_retry(&self) -> bool;
}

impl RetriableError for TimeoutError {
    fn should_retry(&self) -> bool {
        true
    }
}

#[cfg(test)]
mod tests {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use std::sync::Arc;

    use std::iter;
    use std::sync::RwLock;

    use oneshot_fused_workaround as oneshot;
    use tor_rtcompat::{SleepProvider, ToplevelBlockOn};
    use tor_rtmock::MockRuntime;

    const SHORT_DELAY: Duration = Duration::from_millis(10);
    const TIMEOUT: Duration = Duration::from_millis(100);
    const SINGLE_TIMEOUT: Duration = Duration::from_millis(50);
    const MAX_RETRIES: usize = 5;

    macro_rules! impl_backoff_sched {
        ($name:ty, $max_retries:expr, $timeout:expr, $single_timeout:expr, $next_delay:expr) => {
            impl BackoffSchedule for $name {
                fn max_retries(&self) -> Option<usize> {
                    $max_retries
                }

                fn overall_timeout(&self) -> Option<Duration> {
                    $timeout
                }

                fn single_attempt_timeout(&self) -> Option<Duration> {
                    $single_timeout
                }

                #[allow(unused_variables)]
                fn next_delay<E: RetriableError>(&mut self, error: &E) -> Option<Duration> {
                    $next_delay
                }
            }
        };
    }

    struct BackoffWithMaxRetries;

    impl_backoff_sched!(
        BackoffWithMaxRetries,
        Some(MAX_RETRIES),
        None,
        None,
        Some(SHORT_DELAY)
    );

    struct BackoffWithTimeout;

    impl_backoff_sched!(
        BackoffWithTimeout,
        None,
        Some(TIMEOUT),
        None,
        Some(SHORT_DELAY)
    );

    struct BackoffWithSingleTimeout;

    impl_backoff_sched!(
        BackoffWithSingleTimeout,
        Some(MAX_RETRIES),
        None,
        Some(SINGLE_TIMEOUT),
        Some(SHORT_DELAY)
    );

    /// A potentially retriable error.
    #[derive(Debug, Copy, Clone, thiserror::Error)]
    enum TestError {
        /// A fatal error
        #[error("A fatal test error")]
        Fatal,
        /// A transient error
        #[error("A transient test error")]
        Transient,
    }

    impl RetriableError for TestError {
        fn should_retry(&self) -> bool {
            match self {
                Self::Fatal => false,
                Self::Transient => true,
            }
        }
    }

    /// Run a single [`Runner`] test.
    fn run_test<E: RetriableError + Send + Sync + 'static>(
        sleep_for: Option<Duration>,
        schedule: impl BackoffSchedule + Send + 'static,
        errors: impl Iterator<Item = E> + Send + Sync + 'static,
        expected_run_count: usize,
        description: &'static str,
        expected_duration: Duration,
    ) {
        let runtime = MockRuntime::new();

        runtime.clone().block_on(async move {
            let runner = Runner {
                doing: description.into(),
                schedule,
                runtime: runtime.clone(),
            };

            let retry_count = Arc::new(RwLock::new(0));
            let (tx, rx) = oneshot::channel();

            let start = runtime.now();
            runtime
                .mock_task()
                .spawn_identified(format!("retry runner task: {description}"), {
                    let retry_count = Arc::clone(&retry_count);
                    let errors = Arc::new(RwLock::new(errors));
                    let runtime = runtime.clone();
                    async move {
                        if let Ok(()) = runner
                            .run(|| async {
                                *retry_count.write().unwrap() += 1;

                                if let Some(dur) = sleep_for {
                                    runtime.sleep(dur).await;
                                }

                                Err::<(), _>(errors.write().unwrap().next().unwrap())
                            })
                            .await
                        {
                            unreachable!();
                        }

                        let () = tx.send(()).unwrap();
                    }
                });

            // The expected retry count may be unknown (for example, if we set a timeout but no
            // upper limit for the number of retries, it's impossible to tell exactly how many
            // times the operation will be retried)
            for i in 1..=expected_run_count {
                runtime.mock_task().progress_until_stalled().await;
                // If our fallible_op is sleeping, advance the time until after it times out or
                // finishes sleeping.
                if let Some(sleep_for) = sleep_for {
                    runtime
                        .mock_sleep()
                        .advance(std::cmp::min(SINGLE_TIMEOUT, sleep_for));
                }
                runtime.mock_task().progress_until_stalled().await;
                runtime.mock_sleep().advance(SHORT_DELAY);
                assert_eq!(*retry_count.read().unwrap(), i);
            }

            let () = rx.await.unwrap();
            let end = runtime.now();

            assert_eq!(*retry_count.read().unwrap(), expected_run_count);
            assert!(duration_close_to(end - start, expected_duration));
        });
    }

    /// Return true if d1 is in range [d2...d2 + 0.01sec]
    ///
    /// TODO: lifted from tor-circmgr
    fn duration_close_to(d1: Duration, d2: Duration) -> bool {
        d1 >= d2 && d1 <= d2 + SHORT_DELAY
    }

    #[test]
    fn max_retries() {
        run_test(
            None,
            BackoffWithMaxRetries,
            iter::repeat(TestError::Transient),
            MAX_RETRIES,
            "backoff with max_retries and no timeout (transient errors)",
            Duration::from_millis(SHORT_DELAY.as_millis() as u64 * MAX_RETRIES as u64),
        );
    }

    #[test]
    fn max_retries_fatal() {
        use TestError::*;

        /// The number of transient errors that happen before the final, fatal error.
        const RETRIES_UNTIL_FATAL: usize = 3;
        /// The total number of times we expect the fallible function to be called.
        /// The first RETRIES_UNTIL_FATAL times, a transient error is returned.
        /// The last call corresponds to the fatal error.
        const EXPECTED_TOTAL_RUNS: usize = RETRIES_UNTIL_FATAL + 1;

        run_test(
            None,
            BackoffWithMaxRetries,
            std::iter::repeat_n(Transient, RETRIES_UNTIL_FATAL)
                .chain([Fatal])
                .chain(iter::repeat(Transient)),
            EXPECTED_TOTAL_RUNS,
            "backoff with max_retries and no timeout (transient errors followed by a fatal error)",
            Duration::from_millis(SHORT_DELAY.as_millis() as u64 * EXPECTED_TOTAL_RUNS as u64),
        );
    }

    #[test]
    fn timeout() {
        use TestError::*;

        let expected_run_count = TIMEOUT.as_millis() / SHORT_DELAY.as_millis();

        run_test(
            None,
            BackoffWithTimeout,
            iter::repeat(Transient),
            expected_run_count as usize,
            "backoff with timeout and no max_retries (transient errors)",
            TIMEOUT,
        );
    }

    #[test]
    fn single_timeout() {
        use TestError::*;

        // Each attempt will time out after SINGLE_TIMEOUT time units,
        // and the backoff runner sleeps for SHORT_DELAY units in between retries.
        let expected_duration = Duration::from_millis(
            (SHORT_DELAY.as_millis() + SINGLE_TIMEOUT.as_millis()) as u64 * MAX_RETRIES as u64,
        );

        run_test(
            // Sleep for more than SINGLE_TIMEOUT units
            // to trigger the single_attempt_timeout() timeout
            Some(SINGLE_TIMEOUT * 2),
            BackoffWithSingleTimeout,
            iter::repeat(Transient),
            MAX_RETRIES,
            "backoff with single timeout and max_retries and no overall timeout",
            expected_duration,
        );
    }

    // TODO (#1120): needs tests for the remaining corner cases
}