tor_hsservice/publish/backoff.rs

//! Helpers for retrying a fallible operation according to a backoff schedule.
//!
//! [`Runner::run`] retries the specified operation according to the [`BackoffSchedule`] of the
//! [`Runner`]. Users can customize the backoff behavior by implementing [`BackoffSchedule`].
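//!
//! # Example
//!
//! A rough usage sketch, assuming a [`Runtime`] named `runtime` is in scope; the
//! `DescUploadSchedule` type and the `upload_descriptor` function are hypothetical,
//! shown here only to illustrate the shape of the API:
//!
//! ```ignore
//! let schedule = DescUploadSchedule::default();
//! let runner = Runner::new("upload a descriptor".into(), schedule, runtime.clone());
//!
//! // Retry `upload_descriptor` according to `schedule`, returning the first `Ok(_)` value,
//! // or a `BackoffError` wrapping the accumulated errors if the runner gives up.
//! let outcome = runner.run(|| upload_descriptor()).await;
//! ```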

// TODO: this is a (somewhat) general-purpose utility, so it should probably be factored out of
// tor-hsservice

use std::pin::Pin;

use futures::future::FusedFuture;

use tor_rtcompat::TimeoutError;

use super::*;

/// A runner for a fallible operation, which retries on failure according to a [`BackoffSchedule`].
pub(super) struct Runner<B: BackoffSchedule, R: Runtime> {
    /// A description of the operation we are trying to do.
    doing: String,
    /// The backoff schedule.
    schedule: B,
    /// The runtime.
    runtime: R,
}

impl<B: BackoffSchedule, R: Runtime> Runner<B, R> {
    /// Create a new `Runner`.
    pub(super) fn new(doing: String, schedule: B, runtime: R) -> Self {
        Self {
            doing,
            schedule,
            runtime,
        }
    }

    /// Run `fallible_fn`, retrying according to the [`BackoffSchedule`] of this `Runner`.
    ///
    /// If `fallible_fn` eventually returns `Ok(_)`, return that output. Otherwise,
    /// keep retrying until either `fallible_fn` has failed too many times, or until
    /// a fatal error occurs.
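    ///
    /// # Example
    ///
    /// A rough sketch of how a caller might consume the result; `do_upload` and the
    /// `handle_*` functions are hypothetical:
    ///
    /// ```ignore
    /// match runner.run(|| do_upload()).await {
    ///     Ok(v) => handle_success(v),
    ///     Err(BackoffError::FatalError(errors)) => handle_fatal(errors),
    ///     // MaxRetryCountExceeded, Timeout, or ExplicitStop: we gave up retrying.
    ///     Err(e) => handle_gave_up(e),
    /// }
    /// ```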
    #[allow(clippy::cognitive_complexity)] // TODO: Refactor
    pub(super) async fn run<T, E, F>(
        mut self,
        mut fallible_fn: impl FnMut() -> F,
    ) -> Result<T, BackoffError<E>>
    where
        E: RetriableError,
        F: Future<Output = Result<T, E>> + Send,
    {
        let mut retry_count = 0;
        let mut errors = RetryError::in_attempt_to(self.doing.clone());

        // When this timeout elapses, the `Runner` will stop retrying the fallible operation.
        //
        // An `overall_timeout` of `None` means there is no time limit for the retries.
        let mut overall_timeout = match self.schedule.overall_timeout() {
            Some(timeout) => Either::Left(Box::pin(self.runtime.sleep(timeout))),
            None => Either::Right(future::pending()),
        }
        .fuse();

        loop {
            // Bail if we've exceeded the number of allowed retries.
            if matches!(self.schedule.max_retries(), Some(max_retry_count) if retry_count >= max_retry_count)
            {
                return Err(BackoffError::MaxRetryCountExceeded(errors));
            }

            let mut fallible_op = optionally_timeout(
                &self.runtime,
                fallible_fn(),
                self.schedule.single_attempt_timeout(),
            );

            trace!(attempt = (retry_count + 1), "{}", self.doing);

            select_biased! {
                () = overall_timeout => {
                    // The timeout has elapsed, so stop retrying and return the errors
                    // accumulated so far.
                    return Err(BackoffError::Timeout(errors))
                }
                res = fallible_op => {
                    // TODO: the error branches in the match below have different error types,
                    // so we must compute should_retry and delay separately, on each branch.
                    //
                    // We could refactor this to extract the error using
                    // let err = match res { ... } and call err.should_retry()
                    // and next_delay() after the match, but this will involve
                    // rethinking the BackoffSchedule trait and/or RetriableError
                    // (currently RetriableError is Clone, so it's not object safe)
                    let (should_retry, delay) = match res {
                        Ok(Ok(res)) => return Ok(res),
                        Ok(Err(e)) => {
                            // The operation failed: check if we can retry it.
                            let should_retry = e.should_retry();

                            debug!(
                                attempt=(retry_count + 1), can_retry=should_retry,
                                "failed to {}: {e}", self.doing
                            );

                            errors.push(e.clone());
                            (should_retry, self.schedule.next_delay(&e))
                        }
                        Err(e) => {
                            trace!("fallible operation timed out; retrying");
                            (e.should_retry(), self.schedule.next_delay(&e))
                        },
                    };

                    if should_retry {
                        retry_count += 1;

                        let Some(delay) = delay else {
                            return Err(BackoffError::ExplicitStop(errors));
                        };

                        // Introduce the specified delay between retries
                        let () = self.runtime.sleep(delay).await;

                        // Try again unless the entire operation has timed out.
                        continue;
                    }

                    return Err(BackoffError::FatalError(errors));
                },
            }
        }
    }
}

/// Wrap a [`Future`] with an optional timeout.
///
/// If `timeout` is `Some`, returns a [`Timeout`](tor_rtcompat::Timeout)
/// that resolves to the value of `future` if the future completes within `timeout`,
/// or a [`TimeoutError`] if it does not.
/// If `timeout` is `None`, returns a new future which maps the specified `future`'s
/// output type to a `Result::Ok`.
fn optionally_timeout<'f, R, F>(
    runtime: &R,
    future: F,
    timeout: Option<Duration>,
) -> Pin<Box<dyn FusedFuture<Output = Result<F::Output, TimeoutError>> + Send + 'f>>
where
    R: Runtime,
    F: Future + Send + 'f,
{
    match timeout {
        Some(timeout) => Box::pin(runtime.timeout(timeout, future).fuse()),
        None => Box::pin(future.map(Ok)),
    }
}

/// A trait that specifies the parameters for retrying a fallible operation.
pub(super) trait BackoffSchedule {
    /// The maximum number of retries.
    ///
    /// A return value of `None` indicates there is no upper limit for the number of retries, and
    /// that the operation should be retried until [`BackoffSchedule::overall_timeout`] time
    /// elapses (or indefinitely, if [`BackoffSchedule::overall_timeout`] returns `None`).
    fn max_retries(&self) -> Option<usize>;

    /// The total amount of time allowed for the retriable operation.
    ///
    /// A return value of `None` indicates the operation should be retried until
    /// [`BackoffSchedule::max_retries`] number of retries are exceeded (or indefinitely, if
    /// [`BackoffSchedule::max_retries`] returns `None`).
    fn overall_timeout(&self) -> Option<Duration>;

    /// The maximum amount of time allowed for a single attempt of the operation.
    fn single_attempt_timeout(&self) -> Option<Duration>;

    /// Return the delay to introduce before the next retry.
    ///
    /// The `error` parameter contains the error returned by the fallible operation. This enables
    /// implementors to (optionally) implement adaptive backoff. For example, if the operation is
    /// sending an HTTP request, and the error is a 429 (Too Many Requests) HTTP response with a
    /// `Retry-After` header, the implementor can implement a backoff schedule where the next retry
    /// is delayed by the value specified in the `Retry-After` header. A simpler, non-adaptive
    /// schedule is sketched in the example below this trait.
    fn next_delay<E: RetriableError>(&mut self, error: &E) -> Option<Duration>;
}
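
/// A minimal, non-adaptive [`BackoffSchedule`]: exponential backoff with a cap.
///
/// This type is an illustrative sketch only and is not used by the publisher; the constants
/// below are arbitrary, and a real schedule may also want to adapt its delays based on the
/// `error` passed to [`BackoffSchedule::next_delay`].
#[allow(dead_code)]
struct ExponentialBackoff {
    /// The delay to introduce before the next retry.
    next: Duration,
}

#[allow(dead_code)]
impl BackoffSchedule for ExponentialBackoff {
    fn max_retries(&self) -> Option<usize> {
        Some(6)
    }

    fn overall_timeout(&self) -> Option<Duration> {
        Some(Duration::from_secs(60))
    }

    fn single_attempt_timeout(&self) -> Option<Duration> {
        Some(Duration::from_secs(10))
    }

    fn next_delay<E: RetriableError>(&mut self, _error: &E) -> Option<Duration> {
        // Double the delay on each retry, capping it at 30 seconds.
        let delay = self.next;
        self.next = (self.next * 2).min(Duration::from_secs(30));
        Some(delay)
    }
}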

/// The type of error encountered while running a fallible operation.
#[derive(Clone, Debug, thiserror::Error)]
pub(super) enum BackoffError<E> {
    /// A fatal (non-transient) error occurred.
    #[error("A fatal (non-transient) error occurred")]
    FatalError(RetryError<E>),

    /// Ran out of retries.
    #[error("Ran out of retries")]
    MaxRetryCountExceeded(RetryError<E>),

    /// Exceeded the maximum allowed time.
    #[error("Timeout exceeded")]
    Timeout(RetryError<E>),

    /// The [`BackoffSchedule`] told us to stop retrying.
    #[error("Stopped retrying as requested by BackoffSchedule")]
    ExplicitStop(RetryError<E>),
}

impl<E> From<BackoffError<E>> for RetryError<E> {
    fn from(e: BackoffError<E>) -> Self {
        match e {
            BackoffError::FatalError(e)
            | BackoffError::MaxRetryCountExceeded(e)
            | BackoffError::Timeout(e)
            | BackoffError::ExplicitStop(e) => e,
        }
    }
}

/// A trait for representing retriable errors.
pub(super) trait RetriableError: StdError + Clone {
    /// Whether this error is transient.
    fn should_retry(&self) -> bool;
}

impl RetriableError for TimeoutError {
    fn should_retry(&self) -> bool {
        true
    }
}

#[cfg(test)]
mod tests {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use std::sync::Arc;

    use std::iter;
    use std::sync::RwLock;

    use oneshot_fused_workaround as oneshot;
    use tor_rtcompat::{SleepProvider, ToplevelBlockOn};
    use tor_rtmock::MockRuntime;

    const SHORT_DELAY: Duration = Duration::from_millis(10);
    const TIMEOUT: Duration = Duration::from_millis(100);
    const SINGLE_TIMEOUT: Duration = Duration::from_millis(50);
    const MAX_RETRIES: usize = 5;

    macro_rules! impl_backoff_sched {
        ($name:ty, $max_retries:expr, $timeout:expr, $single_timeout:expr, $next_delay:expr) => {
            impl BackoffSchedule for $name {
                fn max_retries(&self) -> Option<usize> {
                    $max_retries
                }

                fn overall_timeout(&self) -> Option<Duration> {
                    $timeout
                }

                fn single_attempt_timeout(&self) -> Option<Duration> {
                    $single_timeout
                }

                #[allow(unused_variables)]
                fn next_delay<E: RetriableError>(&mut self, error: &E) -> Option<Duration> {
                    $next_delay
                }
            }
        };
    }

    struct BackoffWithMaxRetries;

    impl_backoff_sched!(
        BackoffWithMaxRetries,
        Some(MAX_RETRIES),
        None,
        None,
        Some(SHORT_DELAY)
    );

    struct BackoffWithTimeout;

    impl_backoff_sched!(
        BackoffWithTimeout,
        None,
        Some(TIMEOUT),
        None,
        Some(SHORT_DELAY)
    );

    struct BackoffWithSingleTimeout;

    impl_backoff_sched!(
        BackoffWithSingleTimeout,
        Some(MAX_RETRIES),
        None,
        Some(SINGLE_TIMEOUT),
        Some(SHORT_DELAY)
    );

    /// A potentially retriable error.
    #[derive(Debug, Copy, Clone, thiserror::Error)]
    enum TestError {
        /// A fatal error
        #[error("A fatal test error")]
        Fatal,
        /// A transient error
        #[error("A transient test error")]
        Transient,
    }

    impl RetriableError for TestError {
        fn should_retry(&self) -> bool {
            match self {
                Self::Fatal => false,
                Self::Transient => true,
            }
        }
    }

    /// Run a single [`Runner`] test.
    fn run_test<E: RetriableError + Send + Sync + 'static>(
        sleep_for: Option<Duration>,
        schedule: impl BackoffSchedule + Send + 'static,
        errors: impl Iterator<Item = E> + Send + Sync + 'static,
        expected_run_count: usize,
        description: &'static str,
        expected_duration: Duration,
    ) {
        let runtime = MockRuntime::new();

        runtime.clone().block_on(async move {
            let runner = Runner {
                doing: description.into(),
                schedule,
                runtime: runtime.clone(),
            };

            let retry_count = Arc::new(RwLock::new(0));
            let (tx, rx) = oneshot::channel();

            let start = runtime.now();
            runtime
                .mock_task()
                .spawn_identified(format!("retry runner task: {description}"), {
                    let retry_count = Arc::clone(&retry_count);
                    let errors = Arc::new(RwLock::new(errors));
                    let runtime = runtime.clone();
                    async move {
                        if let Ok(()) = runner
                            .run(|| async {
                                *retry_count.write().unwrap() += 1;

                                if let Some(dur) = sleep_for {
                                    runtime.sleep(dur).await;
                                }

                                Err::<(), _>(errors.write().unwrap().next().unwrap())
                            })
                            .await
                        {
                            unreachable!();
                        }

                        let () = tx.send(()).unwrap();
                    }
                });

            // The expected retry count may be unknown (for example, if we set a timeout but no
            // upper limit for the number of retries, it's impossible to tell exactly how many
            // times the operation will be retried)
            for i in 1..=expected_run_count {
                runtime.mock_task().progress_until_stalled().await;
                // If our fallible_op is sleeping, advance the time until after it times out or
                // finishes sleeping.
                if let Some(sleep_for) = sleep_for {
                    runtime
                        .mock_sleep()
                        .advance(std::cmp::min(SINGLE_TIMEOUT, sleep_for));
                }
                runtime.mock_task().progress_until_stalled().await;
                runtime.mock_sleep().advance(SHORT_DELAY);
                assert_eq!(*retry_count.read().unwrap(), i);
            }

            let () = rx.await.unwrap();
            let end = runtime.now();

            assert_eq!(*retry_count.read().unwrap(), expected_run_count);
            assert!(duration_close_to(end - start, expected_duration));
        });
    }

    /// Return true if `d1` is in the range `[d2, d2 + SHORT_DELAY]`.
    ///
    /// TODO: lifted from tor-circmgr
    fn duration_close_to(d1: Duration, d2: Duration) -> bool {
        d1 >= d2 && d1 <= d2 + SHORT_DELAY
    }

    #[test]
    fn max_retries() {
        run_test(
            None,
            BackoffWithMaxRetries,
            iter::repeat(TestError::Transient),
            MAX_RETRIES,
            "backoff with max_retries and no timeout (transient errors)",
            Duration::from_millis(SHORT_DELAY.as_millis() as u64 * MAX_RETRIES as u64),
        );
    }

    #[test]
    fn max_retries_fatal() {
        use TestError::*;

        /// The number of transient errors that happen before the final, fatal error.
        const RETRIES_UNTIL_FATAL: usize = 3;
        /// The total number of times we expect the fallible function to be called.
        /// The first RETRIES_UNTIL_FATAL times, a transient error is returned.
        /// The last call corresponds to the fatal error.
        const EXPECTED_TOTAL_RUNS: usize = RETRIES_UNTIL_FATAL + 1;

        run_test(
            None,
            BackoffWithMaxRetries,
            std::iter::repeat_n(Transient, RETRIES_UNTIL_FATAL)
                .chain([Fatal])
                .chain(iter::repeat(Transient)),
            EXPECTED_TOTAL_RUNS,
            "backoff with max_retries and no timeout (transient errors followed by a fatal error)",
            Duration::from_millis(SHORT_DELAY.as_millis() as u64 * EXPECTED_TOTAL_RUNS as u64),
        );
    }

    #[test]
    fn timeout() {
        use TestError::*;

        let expected_run_count = TIMEOUT.as_millis() / SHORT_DELAY.as_millis();

        run_test(
            None,
            BackoffWithTimeout,
            iter::repeat(Transient),
            expected_run_count as usize,
            "backoff with timeout and no max_retries (transient errors)",
            TIMEOUT,
        );
    }

    #[test]
    fn single_timeout() {
        use TestError::*;

        // Each attempt will time out after SINGLE_TIMEOUT time units,
        // and the backoff runner sleeps for SHORT_DELAY units in between retries
        let expected_duration = Duration::from_millis(
            (SHORT_DELAY.as_millis() + SINGLE_TIMEOUT.as_millis()) as u64 * MAX_RETRIES as u64,
        );

        run_test(
            // Sleep for more than SINGLE_TIMEOUT units
            // to trigger the single_attempt_timeout() timeout
            Some(SINGLE_TIMEOUT * 2),
            BackoffWithSingleTimeout,
            iter::repeat(Transient),
            MAX_RETRIES,
            "backoff with single timeout and max_retries and no overall timeout",
            expected_duration,
        );
    }

    // TODO (#1120): needs tests for the remaining corner cases
}