use std::pin::Pin;

use futures::future::FusedFuture;

use tor_rtcompat::TimeoutError;

use super::*;

/// A helper for running a fallible operation repeatedly,
/// according to a [`BackoffSchedule`].
pub(super) struct Runner<B: BackoffSchedule, R: Runtime> {
    /// A description of the operation we are trying to do,
    /// used for logging and error reporting.
    doing: String,
    /// The backoff schedule that decides whether, and for how long, to wait between retries.
    schedule: B,
    /// The runtime, used for sleeping and for applying timeouts.
    runtime: R,
}

impl<B: BackoffSchedule, R: Runtime> Runner<B, R> {
    /// Create a new `Runner` for retrying the operation described by `doing`.
    pub(super) fn new(doing: String, schedule: B, runtime: R) -> Self {
        Self {
            doing,
            schedule,
            runtime,
        }
    }

    /// Run `fallible_fn`, retrying as dictated by the [`BackoffSchedule`]
    /// of this `Runner`.
    ///
    /// Stops when the operation succeeds, or when the schedule tells us to
    /// give up (because of a fatal error, too many attempts, or a timeout).
    pub(super) async fn run<T, E, F>(
        mut self,
        mut fallible_fn: impl FnMut() -> F,
    ) -> Result<T, BackoffError<E>>
    where
        E: RetriableError,
        F: Future<Output = Result<T, E>> + Send,
    {
        let mut retry_count = 0;
        let mut errors = RetryError::in_attempt_to(self.doing.clone());

        // If the schedule has an overall timeout, this future fires when it expires;
        // otherwise it never completes.
        let mut overall_timeout = match self.schedule.overall_timeout() {
            Some(timeout) => Either::Left(Box::pin(self.runtime.sleep(timeout))),
            None => Either::Right(future::pending()),
        }
        .fuse();

        loop {
            // Bail if we have exceeded the maximum number of retries allowed by the schedule.
            if matches!(
                self.schedule.max_retries(),
                Some(max_retry_count) if retry_count >= max_retry_count
            ) {
                return Err(BackoffError::MaxRetryCountExceeded(errors));
            }

            // Wrap the next attempt in the per-attempt timeout, if the schedule has one.
            let mut fallible_op = optionally_timeout(
                &self.runtime,
                fallible_fn(),
                self.schedule.single_attempt_timeout(),
            );

            trace!(attempt = (retry_count + 1), "{}", self.doing);

            select_biased! {
                () = overall_timeout => {
                    // The schedule's overall timeout has expired, so stop retrying.
                    return Err(BackoffError::Timeout(errors))
                }
                res = fallible_op => {
                    // The outer `Result` tells us whether the attempt timed out;
                    // the inner one is the result of the operation itself.
                    let (should_retry, delay) = match res {
                        Ok(Ok(res)) => return Ok(res),
                        Ok(Err(e)) => {
                            let should_retry = e.should_retry();

                            debug!(
                                attempt=(retry_count + 1), can_retry=should_retry,
                                "failed to {}: {e}", self.doing
                            );

                            errors.push(e.clone());
                            (should_retry, self.schedule.next_delay(&e))
                        }
                        Err(e) => {
                            trace!("fallible operation timed out; retrying");
                            (e.should_retry(), self.schedule.next_delay(&e))
                        },
                    };

                    if should_retry {
                        retry_count += 1;

                        // The schedule can also stop the retries by declining
                        // to provide a next delay.
                        let Some(delay) = delay else {
                            return Err(BackoffError::ExplicitStop(errors));
                        };

                        // Wait for the scheduled delay before trying again.
                        let () = self.runtime.sleep(delay).await;

                        continue;
                    }

                    return Err(BackoffError::FatalError(errors));
                },
            }
        }
    }
}

/// Wrap `future` with a timeout, if one is specified.
///
/// If `timeout` is `None`, the returned future simply yields `Ok(future.await)`.
fn optionally_timeout<'f, R, F>(
    runtime: &R,
    future: F,
    timeout: Option<Duration>,
) -> Pin<Box<dyn FusedFuture<Output = Result<F::Output, TimeoutError>> + Send + 'f>>
where
    R: Runtime,
    F: Future + Send + 'f,
{
    match timeout {
        Some(timeout) => Box::pin(runtime.timeout(timeout, future).fuse()),
        None => Box::pin(future.map(Ok)),
    }
}

/// A trait that specifies the retry behavior of a [`Runner`].
pub(super) trait BackoffSchedule {
    /// The maximum number of retries, or `None` if there is no limit.
    fn max_retries(&self) -> Option<usize>;

    /// The total amount of time we are willing to spend retrying,
    /// or `None` if there is no overall timeout.
    fn overall_timeout(&self) -> Option<Duration>;

    /// The timeout to apply to each individual attempt,
    /// or `None` if individual attempts should not be timed out.
    fn single_attempt_timeout(&self) -> Option<Duration>;

    /// Return how long to wait before retrying after `error`,
    /// or `None` to stop retrying altogether.
    fn next_delay<E: RetriableError>(&mut self, error: &E) -> Option<Duration>;
}
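
// Illustrative sketch (not part of the original module): a minimal
// `BackoffSchedule` implementation with capped exponential delays.
// The struct name, constants, and doubling policy are assumptions chosen
// for the example, not a schedule used elsewhere in this crate.
#[allow(dead_code)]
struct ExampleExponentialBackoff {
    /// The delay to use before the next retry.
    next: Duration,
}

impl BackoffSchedule for ExampleExponentialBackoff {
    fn max_retries(&self) -> Option<usize> {
        // Give up after a fixed number of attempts.
        Some(8)
    }

    fn overall_timeout(&self) -> Option<Duration> {
        // Place no bound on the total time spent retrying.
        None
    }

    fn single_attempt_timeout(&self) -> Option<Duration> {
        // Cancel any single attempt that runs for more than a minute.
        Some(Duration::from_secs(60))
    }

    fn next_delay<E: RetriableError>(&mut self, _error: &E) -> Option<Duration> {
        // Double the delay after every failure, capping it at ten seconds.
        let delay = self.next;
        self.next = (self.next * 2).min(Duration::from_secs(10));
        Some(delay)
    }
}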

/// The error type returned when running a fallible operation with a [`Runner`].
#[derive(Clone, Debug, thiserror::Error)]
pub(super) enum BackoffError<E> {
    /// A fatal (non-transient) error occurred.
    #[error("A fatal (non-transient) error occurred")]
    FatalError(RetryError<E>),

    /// We ran out of retries.
    #[error("Ran out of retries")]
    MaxRetryCountExceeded(RetryError<E>),

    /// We exceeded the overall timeout.
    #[error("Timeout exceeded")]
    Timeout(RetryError<E>),

    /// The [`BackoffSchedule`] told us to stop retrying.
    #[error("Stopped retrying as requested by BackoffSchedule")]
    ExplicitStop(RetryError<E>),
}

impl<E> From<BackoffError<E>> for RetryError<E> {
    fn from(e: BackoffError<E>) -> Self {
        match e {
            BackoffError::FatalError(e)
            | BackoffError::MaxRetryCountExceeded(e)
            | BackoffError::Timeout(e)
            | BackoffError::ExplicitStop(e) => e,
        }
    }
}

/// A potentially retriable error.
///
/// Implementors should only return `true` from
/// [`should_retry`](RetriableError::should_retry) if the error is transient,
/// i.e. if a later attempt might succeed.
pub(super) trait RetriableError: StdError + Clone {
    /// Whether the operation that produced this error should be retried.
    fn should_retry(&self) -> bool;
}

impl RetriableError for TimeoutError {
    fn should_retry(&self) -> bool {
        true
    }
}
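
// Illustrative sketch (not part of the original module): how a caller might
// combine `Runner` with a schedule and a fallible async operation.
// `example_retry` and `ExampleExponentialBackoff` (defined above) are
// assumptions for illustration only.
#[allow(dead_code)]
async fn example_retry<R, T, E, F>(
    runtime: R,
    op: impl FnMut() -> F,
) -> Result<T, RetryError<E>>
where
    R: Runtime,
    E: RetriableError,
    F: Future<Output = Result<T, E>> + Send,
{
    // Describe the operation (for logging and error reports), pick a schedule,
    // and let the runner drive `op` until it succeeds or the schedule gives up.
    let runner = Runner::new(
        "run the example operation".to_string(),
        ExampleExponentialBackoff {
            next: Duration::from_millis(100),
        },
        runtime,
    );

    // Collapse the `BackoffError` into the underlying `RetryError`,
    // using the `From` impl above.
    runner.run(op).await.map_err(RetryError::from)
}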

#[cfg(test)]
mod tests {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;

    use std::iter;
    use std::sync::Arc;
    use std::sync::RwLock;

    use oneshot_fused_workaround as oneshot;
    use tor_rtcompat::{SleepProvider, ToplevelBlockOn};
    use tor_rtmock::MockRuntime;

    /// A short delay returned by the test schedules between retries.
    const SHORT_DELAY: Duration = Duration::from_millis(10);
    /// The overall timeout used by `BackoffWithTimeout`.
    const TIMEOUT: Duration = Duration::from_millis(100);
    /// The per-attempt timeout used by `BackoffWithSingleTimeout`.
    const SINGLE_TIMEOUT: Duration = Duration::from_millis(50);
    /// The maximum number of retries used by the test schedules that have one.
    const MAX_RETRIES: usize = 5;

    /// Implement `BackoffSchedule` for the specified type, using the given
    /// values for the maximum retries, overall timeout, single-attempt timeout,
    /// and next delay.
    macro_rules! impl_backoff_sched {
        ($name:ty, $max_retries:expr, $timeout:expr, $single_timeout:expr, $next_delay:expr) => {
            impl BackoffSchedule for $name {
                fn max_retries(&self) -> Option<usize> {
                    $max_retries
                }

                fn overall_timeout(&self) -> Option<Duration> {
                    $timeout
                }

                fn single_attempt_timeout(&self) -> Option<Duration> {
                    $single_timeout
                }

                #[allow(unused_variables)]
                fn next_delay<E: RetriableError>(&mut self, error: &E) -> Option<Duration> {
                    $next_delay
                }
            }
        };
    }

    /// A schedule with a maximum retry count and no timeouts.
    struct BackoffWithMaxRetries;

    impl_backoff_sched!(
        BackoffWithMaxRetries,
        Some(MAX_RETRIES),
        None,
        None,
        Some(SHORT_DELAY)
    );

    /// A schedule with an overall timeout and no retry limit.
    struct BackoffWithTimeout;

    impl_backoff_sched!(
        BackoffWithTimeout,
        None,
        Some(TIMEOUT),
        None,
        Some(SHORT_DELAY)
    );

    /// A schedule with a per-attempt timeout and a maximum retry count.
    struct BackoffWithSingleTimeout;

    impl_backoff_sched!(
        BackoffWithSingleTimeout,
        Some(MAX_RETRIES),
        None,
        Some(SINGLE_TIMEOUT),
        Some(SHORT_DELAY)
    );

    /// A test error, which can be either fatal or transient.
    #[derive(Debug, Copy, Clone, thiserror::Error)]
    enum TestError {
        /// A fatal error: the runner should give up immediately.
        #[error("A fatal test error")]
        Fatal,
        /// A transient error: the runner should retry.
        #[error("A transient test error")]
        Transient,
    }

    impl RetriableError for TestError {
        fn should_retry(&self) -> bool {
            match self {
                Self::Fatal => false,
                Self::Transient => true,
            }
        }
    }

    /// Run a mock-time test of `Runner::run` with the given schedule.
    ///
    /// The fallible operation increments a shared retry counter, optionally
    /// sleeps for `sleep_for`, and then fails with the next error from `errors`.
    /// The test asserts that it runs `expected_run_count` times and that the
    /// whole run takes roughly `expected_duration` of mock time.
    fn run_test<E: RetriableError + Send + Sync + 'static>(
        sleep_for: Option<Duration>,
        schedule: impl BackoffSchedule + Send + 'static,
        errors: impl Iterator<Item = E> + Send + Sync + 'static,
        expected_run_count: usize,
        description: &'static str,
        expected_duration: Duration,
    ) {
        let runtime = MockRuntime::new();

        runtime.clone().block_on(async move {
            let runner = Runner {
                doing: description.into(),
                schedule,
                runtime: runtime.clone(),
            };

            let retry_count = Arc::new(RwLock::new(0));
            let (tx, rx) = oneshot::channel();

            let start = runtime.now();
            runtime
                .mock_task()
                .spawn_identified(format!("retry runner task: {description}"), {
                    let retry_count = Arc::clone(&retry_count);
                    let errors = Arc::new(RwLock::new(errors));
                    let runtime = runtime.clone();
                    async move {
                        if let Ok(()) = runner
                            .run(|| async {
                                *retry_count.write().unwrap() += 1;

                                if let Some(dur) = sleep_for {
                                    runtime.sleep(dur).await;
                                }

                                Err::<(), _>(errors.write().unwrap().next().unwrap())
                            })
                            .await
                        {
                            unreachable!();
                        }

                        let () = tx.send(()).unwrap();
                    }
                });

            // Drive the runner one attempt at a time, advancing the mock clock
            // past each attempt and each retry delay.
            for i in 1..=expected_run_count {
                runtime.mock_task().progress_until_stalled().await;
                if let Some(sleep_for) = sleep_for {
                    // Let the operation's own sleep elapse, but never past the
                    // single-attempt timeout.
                    runtime
                        .mock_sleep()
                        .advance(std::cmp::min(SINGLE_TIMEOUT, sleep_for));
                }
                runtime.mock_task().progress_until_stalled().await;
                runtime.mock_sleep().advance(SHORT_DELAY);
                assert_eq!(*retry_count.read().unwrap(), i);
            }

            let () = rx.await.unwrap();
            let end = runtime.now();

            assert_eq!(*retry_count.read().unwrap(), expected_run_count);
            assert!(duration_close_to(end - start, expected_duration));
        });
    }

    /// Return true if `d1` is no less than `d2`, and no more than `SHORT_DELAY` greater.
    fn duration_close_to(d1: Duration, d2: Duration) -> bool {
        d1 >= d2 && d1 <= d2 + SHORT_DELAY
    }

    #[test]
    fn max_retries() {
        run_test(
            None,
            BackoffWithMaxRetries,
            iter::repeat(TestError::Transient),
            MAX_RETRIES,
            "backoff with max_retries and no timeout (transient errors)",
            Duration::from_millis(SHORT_DELAY.as_millis() as u64 * MAX_RETRIES as u64),
        );
    }

    #[test]
    fn max_retries_fatal() {
        use TestError::*;

        /// The number of transient errors that occur before the fatal one.
        const RETRIES_UNTIL_FATAL: usize = 3;
        /// The total number of runs: the transient failures, plus the final fatal one.
        const EXPECTED_TOTAL_RUNS: usize = RETRIES_UNTIL_FATAL + 1;

        run_test(
            None,
            BackoffWithMaxRetries,
            std::iter::repeat_n(Transient, RETRIES_UNTIL_FATAL)
                .chain([Fatal])
                .chain(iter::repeat(Transient)),
            EXPECTED_TOTAL_RUNS,
            "backoff with max_retries and no timeout (transient errors followed by a fatal error)",
            Duration::from_millis(SHORT_DELAY.as_millis() as u64 * EXPECTED_TOTAL_RUNS as u64),
        );
    }

    #[test]
    fn timeout() {
        use TestError::*;

        // The runner should stop when the overall timeout expires,
        // after retrying once every SHORT_DELAY until then.
        let expected_run_count = TIMEOUT.as_millis() / SHORT_DELAY.as_millis();

        run_test(
            None,
            BackoffWithTimeout,
            iter::repeat(Transient),
            expected_run_count as usize,
            "backoff with timeout and no max_retries (transient errors)",
            TIMEOUT,
        );
    }

    #[test]
    fn single_timeout() {
        use TestError::*;

        // Each attempt runs up against the single-attempt timeout before failing,
        // and is followed by a SHORT_DELAY before the next attempt.
        let expected_duration = Duration::from_millis(
            (SHORT_DELAY.as_millis() + SINGLE_TIMEOUT.as_millis()) as u64 * MAX_RETRIES as u64,
        );

        run_test(
            // Sleep for longer than the single-attempt timeout,
            // so that every attempt times out.
            Some(SINGLE_TIMEOUT * 2),
            BackoffWithSingleTimeout,
            iter::repeat(Transient),
            MAX_RETRIES,
            "backoff with single timeout and max_retries and no overall timeout",
            expected_duration,
        );
    }
}