use std::pin::Pin;

use futures::future::FusedFuture;

use tor_rtcompat::TimeoutError;

use super::*;

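/// A helper for retrying a fallible operation according to a [`BackoffSchedule`].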
pub(super) struct Runner<B: BackoffSchedule, R: Runtime> {
    /// A description of the operation we are trying to do, for logging and error reporting.
    doing: String,
    /// The backoff schedule, which controls retries, delays, and timeouts.
    schedule: B,
    /// The runtime, used for sleeping and for timeouts.
    runtime: R,
}

impl<B: BackoffSchedule, R: Runtime> Runner<B, R> {
    /// Create a new `Runner`.
    pub(super) fn new(doing: String, schedule: B, runtime: R) -> Self {
        Self {
            doing,
            schedule,
            runtime,
        }
    }

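    /// Run `fallible_fn`, retrying according to the [`BackoffSchedule`] of this `Runner`.
    ///
    /// Returns the first successful result, or a [`BackoffError`] once the
    /// schedule says we can retry no longer.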
    #[allow(clippy::cognitive_complexity)]
    pub(super) async fn run<T, E, F>(
        mut self,
        mut fallible_fn: impl FnMut() -> F,
    ) -> Result<T, BackoffError<E>>
    where
        E: RetriableError,
        F: Future<Output = Result<T, E>> + Send,
    {
        let mut retry_count = 0;
        let mut errors = RetryError::in_attempt_to(self.doing.clone());

        // If the schedule specifies an overall timeout, start counting it down now;
        // otherwise, use a future that never resolves.
        let mut overall_timeout = match self.schedule.overall_timeout() {
            Some(timeout) => Either::Left(Box::pin(self.runtime.sleep(timeout))),
            None => Either::Right(future::pending()),
        }
        .fuse();

        loop {
            // Stop if we have reached the maximum number of retries, if there is one.
            if matches!(self.schedule.max_retries(), Some(max_retry_count) if retry_count >= max_retry_count)
            {
                return Err(BackoffError::MaxRetryCountExceeded(errors));
            }

            // Wrap the operation in a per-attempt timeout, if the schedule specifies one.
            let mut fallible_op = optionally_timeout(
                &self.runtime,
                fallible_fn(),
                self.schedule.single_attempt_timeout(),
            );

            trace!(attempt = (retry_count + 1), "{}", self.doing);

            select_biased! {
                () = overall_timeout => {
                    // The overall timeout expired before the operation succeeded.
                    return Err(BackoffError::Timeout(errors))
                }
                res = fallible_op => {
                    let (should_retry, delay) = match res {
                        Ok(Ok(res)) => return Ok(res),
                        Ok(Err(e)) => {
                            let should_retry = e.should_retry();

                            debug!(
                                attempt=(retry_count + 1), can_retry=should_retry,
                                "failed to {}: {e}", self.doing
                            );

                            errors.push(e.clone());
                            (should_retry, self.schedule.next_delay(&e))
                        }
                        Err(e) => {
                            // The single-attempt timeout expired.
                            trace!("fallible operation timed out; retrying");
                            (e.should_retry(), self.schedule.next_delay(&e))
                        },
                    };


                    if should_retry {
                        retry_count += 1;

                        // A `None` delay means the schedule wants us to stop retrying.
                        let Some(delay) = delay else {
                            return Err(BackoffError::ExplicitStop(errors));
                        };

                        // Wait out the scheduled delay before the next attempt.
                        let () = self.runtime.sleep(delay).await;

                        continue;
                    }

                    return Err(BackoffError::FatalError(errors));
                },
            }
        }
    }
}

/// Wrap `future` in a timeout of `timeout`, if one is given.
///
/// If `timeout` is `None`, the future is simply mapped to `Ok`, so it never times out.
fn optionally_timeout<'f, R, F>(
    runtime: &R,
    future: F,
    timeout: Option<Duration>,
) -> Pin<Box<dyn FusedFuture<Output = Result<F::Output, TimeoutError>> + Send + 'f>>
where
    R: Runtime,
    F: Future + Send + 'f,
{
    match timeout {
        Some(timeout) => Box::pin(runtime.timeout(timeout, future).fuse()),
        None => Box::pin(future.map(Ok)),
    }
}

/// A trait that specifies the backoff behavior of a [`Runner`].
pub(super) trait BackoffSchedule {
    /// The maximum number of retries, or `None` if there is no limit.
    fn max_retries(&self) -> Option<usize>;

    /// The total amount of time allowed for all attempts, or `None` if there is no limit.
    fn overall_timeout(&self) -> Option<Duration>;

    /// The amount of time allowed for a single attempt, or `None` if there is no limit.
    fn single_attempt_timeout(&self) -> Option<Duration>;

    /// Return the delay to wait before the next attempt, or `None` to stop retrying.
    fn next_delay<E: RetriableError>(&mut self, error: &E) -> Option<Duration>;
}
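
// Illustrative sketch (not part of the original module): one way a `BackoffSchedule`
// could be implemented, using simple exponential backoff. The `ExponentialBackoff`
// type, its fields, and the timeouts below are hypothetical, chosen only to show how
// the trait's methods fit together with the `Runner` above.
#[allow(dead_code)]
struct ExponentialBackoff {
    /// The delay to use before the next retry.
    next: Duration,
    /// The largest delay we are willing to wait between retries.
    max: Duration,
}

impl BackoffSchedule for ExponentialBackoff {
    fn max_retries(&self) -> Option<usize> {
        // No retry limit; rely on the overall timeout instead.
        None
    }

    fn overall_timeout(&self) -> Option<Duration> {
        Some(Duration::from_secs(60))
    }

    fn single_attempt_timeout(&self) -> Option<Duration> {
        Some(Duration::from_secs(10))
    }

    fn next_delay<E: RetriableError>(&mut self, _error: &E) -> Option<Duration> {
        // Returning `None` here would make the runner stop with
        // `BackoffError::ExplicitStop`; instead, keep retrying and double the
        // delay after each attempt, up to `max`.
        let delay = self.next;
        self.next = std::cmp::min(self.next.saturating_mul(2), self.max);
        Some(delay)
    }
}
//
// With such a schedule, a caller might drive the runner along these (hypothetical) lines:
//
//     Runner::new("perform the operation".to_string(), schedule, runtime)
//         .run(|| async { do_the_fallible_thing().await })
//         .await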

/// The error returned by [`Runner::run`] when it gives up retrying.
#[derive(Clone, Debug, thiserror::Error)]
pub(super) enum BackoffError<E> {
    /// The operation failed with a non-transient error.
    #[error("A fatal (non-transient) error occurred")]
    FatalError(RetryError<E>),

    /// The maximum number of retries was reached.
    #[error("Ran out of retries")]
    MaxRetryCountExceeded(RetryError<E>),

    /// The overall timeout was exceeded.
    #[error("Timeout exceeded")]
    Timeout(RetryError<E>),

    /// The [`BackoffSchedule`] told us to stop retrying.
    #[error("Stopped retrying as requested by BackoffSchedule")]
    ExplicitStop(RetryError<E>),
}

impl<E> From<BackoffError<E>> for RetryError<E> {
    fn from(e: BackoffError<E>) -> Self {
        match e {
            BackoffError::FatalError(e)
            | BackoffError::MaxRetryCountExceeded(e)
            | BackoffError::Timeout(e)
            | BackoffError::ExplicitStop(e) => e,
        }
    }
}

/// An error that can potentially be retried.
pub(super) trait RetriableError: StdError + Clone {
    /// Whether we should retry the operation that produced this error.
    fn should_retry(&self) -> bool;
}

impl RetriableError for TimeoutError {
    fn should_retry(&self) -> bool {
        // A timed-out attempt should always be retried.
        true
    }
}

#[cfg(test)]
mod tests {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use std::sync::Arc;

    use std::iter;
    use std::sync::RwLock;

    use oneshot_fused_workaround as oneshot;
    use tor_rtcompat::{SleepProvider, ToplevelBlockOn};
    use tor_rtmock::MockRuntime;

    /// The delay the test schedules return between attempts.
    const SHORT_DELAY: Duration = Duration::from_millis(10);
    /// The overall timeout used by `BackoffWithTimeout`.
    const TIMEOUT: Duration = Duration::from_millis(100);
    /// The single-attempt timeout used by `BackoffWithSingleTimeout`.
    const SINGLE_TIMEOUT: Duration = Duration::from_millis(50);
    /// The maximum number of retries used by the test schedules that have one.
    const MAX_RETRIES: usize = 5;

    /// A helper macro for implementing `BackoffSchedule` for the test schedules.
    macro_rules! impl_backoff_sched {
        ($name:ty, $max_retries:expr, $timeout:expr, $single_timeout:expr, $next_delay:expr) => {
            impl BackoffSchedule for $name {
                fn max_retries(&self) -> Option<usize> {
                    $max_retries
                }

                fn overall_timeout(&self) -> Option<Duration> {
                    $timeout
                }

                fn single_attempt_timeout(&self) -> Option<Duration> {
                    $single_timeout
                }

                #[allow(unused_variables)]
                fn next_delay<E: RetriableError>(&mut self, error: &E) -> Option<Duration> {
                    $next_delay
                }
            }
        };
    }

    /// A schedule that gives up after `MAX_RETRIES` attempts.
    struct BackoffWithMaxRetries;

    impl_backoff_sched!(
        BackoffWithMaxRetries,
        Some(MAX_RETRIES),
        None,
        None,
        Some(SHORT_DELAY)
    );

    /// A schedule with no retry limit, bounded by an overall timeout of `TIMEOUT`.
    struct BackoffWithTimeout;

    impl_backoff_sched!(
        BackoffWithTimeout,
        None,
        Some(TIMEOUT),
        None,
        Some(SHORT_DELAY)
    );

    /// A schedule with a per-attempt timeout of `SINGLE_TIMEOUT` and a limit of `MAX_RETRIES` attempts.
    struct BackoffWithSingleTimeout;

    impl_backoff_sched!(
        BackoffWithSingleTimeout,
        Some(MAX_RETRIES),
        None,
        Some(SINGLE_TIMEOUT),
        Some(SHORT_DELAY)
    );

    /// A test error, either fatal or transient.
    #[derive(Debug, Copy, Clone, thiserror::Error)]
    enum TestError {
        /// A fatal (non-retriable) error.
        #[error("A fatal test error")]
        Fatal,
        /// A transient (retriable) error.
        #[error("A transient test error")]
        Transient,
    }

    impl RetriableError for TestError {
        fn should_retry(&self) -> bool {
            match self {
                Self::Fatal => false,
                Self::Transient => true,
            }
        }
    }

    /// Run a `Runner` with the specified `schedule` against an operation that always
    /// fails with the errors from `errors`, and assert that it makes `expected_run_count`
    /// attempts and takes about `expected_duration`.
    ///
    /// If `sleep_for` is set, each attempt sleeps for that long before failing.
    fn run_test<E: RetriableError + Send + Sync + 'static>(
        sleep_for: Option<Duration>,
        schedule: impl BackoffSchedule + Send + 'static,
        errors: impl Iterator<Item = E> + Send + Sync + 'static,
        expected_run_count: usize,
        description: &'static str,
        expected_duration: Duration,
    ) {
        let runtime = MockRuntime::new();

        runtime.clone().block_on(async move {
            let runner = Runner {
                doing: description.into(),
                schedule,
                runtime: runtime.clone(),
            };

            let retry_count = Arc::new(RwLock::new(0));
            let (tx, rx) = oneshot::channel();

            let start = runtime.now();
            runtime
                .mock_task()
                .spawn_identified(format!("retry runner task: {description}"), {
                    let retry_count = Arc::clone(&retry_count);
                    let errors = Arc::new(RwLock::new(errors));
                    let runtime = runtime.clone();
                    async move {
                        if let Ok(()) = runner
                            .run(|| async {
                                *retry_count.write().unwrap() += 1;

                                if let Some(dur) = sleep_for {
                                    runtime.sleep(dur).await;
                                }

                                Err::<(), _>(errors.write().unwrap().next().unwrap())
                            })
                            .await
                        {
                            // The operation always fails, so the runner can never succeed.
                            unreachable!();
                        }

                        let () = tx.send(()).unwrap();
                    }
                });

            // Drive the mock runtime, advancing time by one attempt per iteration.
            for i in 1..=expected_run_count {
                runtime.mock_task().progress_until_stalled().await;
                if let Some(sleep_for) = sleep_for {
                    // Let the attempt's sleep (or its single-attempt timeout) elapse.
                    runtime
                        .mock_sleep()
                        .advance(std::cmp::min(SINGLE_TIMEOUT, sleep_for));
                }
                runtime.mock_task().progress_until_stalled().await;
                runtime.mock_sleep().advance(SHORT_DELAY);
                assert_eq!(*retry_count.read().unwrap(), i);
            }

            let () = rx.await.unwrap();
            let end = runtime.now();

            assert_eq!(*retry_count.read().unwrap(), expected_run_count);
            assert!(duration_close_to(end - start, expected_duration));
        });
    }

    /// Return true if `d1` is equal to `d2`, or greater by at most `SHORT_DELAY`.
    fn duration_close_to(d1: Duration, d2: Duration) -> bool {
        d1 >= d2 && d1 <= d2 + SHORT_DELAY
    }

    #[test]
    fn max_retries() {
        run_test(
            None,
            BackoffWithMaxRetries,
            iter::repeat(TestError::Transient),
            MAX_RETRIES,
            "backoff with max_retries and no timeout (transient errors)",
            Duration::from_millis(SHORT_DELAY.as_millis() as u64 * MAX_RETRIES as u64),
        );
    }

    #[test]
    fn max_retries_fatal() {
        use TestError::*;

        // The number of attempts that fail with a transient error before the fatal one.
        const RETRIES_UNTIL_FATAL: usize = 3;
        // The total number of runs: the transient failures, plus the final attempt that fails fatally.
        const EXPECTED_TOTAL_RUNS: usize = RETRIES_UNTIL_FATAL + 1;

        run_test(
            None,
            BackoffWithMaxRetries,
            std::iter::repeat_n(Transient, RETRIES_UNTIL_FATAL)
                .chain([Fatal])
                .chain(iter::repeat(Transient)),
            EXPECTED_TOTAL_RUNS,
            "backoff with max_retries and no timeout (transient errors followed by a fatal error)",
            Duration::from_millis(SHORT_DELAY.as_millis() as u64 * EXPECTED_TOTAL_RUNS as u64),
        );
    }

    #[test]
    fn timeout() {
        use TestError::*;

        let expected_run_count = TIMEOUT.as_millis() / SHORT_DELAY.as_millis();

        run_test(
            None,
            BackoffWithTimeout,
            iter::repeat(Transient),
            expected_run_count as usize,
            "backoff with timeout and no max_retries (transient errors)",
            TIMEOUT,
        );
    }

    #[test]
    fn single_timeout() {
        use TestError::*;

        // Each attempt hits its single-attempt timeout, and is followed by a short delay.
        let expected_duration = Duration::from_millis(
            (SHORT_DELAY.as_millis() + SINGLE_TIMEOUT.as_millis()) as u64 * MAX_RETRIES as u64,
        );

        run_test(
            // Sleep for longer than the single-attempt timeout, so every attempt times out.
            Some(SINGLE_TIMEOUT * 2),
            BackoffWithSingleTimeout,
            iter::repeat(Transient),
            MAX_RETRIES,
            "backoff with single timeout and max_retries and no overall timeout",
            expected_duration,
        );
    }
}