tor_proto/util/token_bucket/
bucket.rs

1//! A token bucket implementation.
2
3use std::fmt::Debug;
4use std::time::{Duration, Instant};
5
/// A token bucket.
///
/// Calculations are performed at microsecond resolution.
/// You likely want to call [`refill()`](Self::refill) each time you want to access or perform an
/// operation on the token bucket.
///
/// This is partially inspired by tor's `token_bucket_ctr_t`,
/// but the implementation is quite a bit different.
/// We use larger values here (for example `u64`),
/// and we aim to avoid drift when refills occur at times that aren't exactly in period with the
/// refill rate.
///
/// It's possible that we could relax these requirements to reduce memory usage and computation
/// complexity, but that optimization should probably only be made if/when needed since it would
/// make the code more difficult to reason about, and possibly more complex.
///
/// `I` is the type used to represent points in time; the bucket's methods require
/// `I: TokenBucketInstant`.
#[derive(Debug)]
pub(crate) struct TokenBucket<I> {
    /// The refill rate in tokens/second.
    ///
    /// A rate of 0 means that refills never add tokens.
    rate: u64,
    /// The max amount of tokens in the bucket.
    /// Commonly referred to as the "burst".
    bucket_max: u64,
    /// Current amount of tokens in the bucket.
    ///
    /// Kept at or below `bucket_max`: it starts at the max, refills and adjustments clamp it to
    /// the max, and committed claims only ever decrease it.
    // It's possible that in the future we may want a token bucket to allow negative values. For
    // example we might want to send a few extra bytes over the allowed limit if it would mean that
    // we send a complete TLS record.
    bucket: u64,
    /// Time that the most recent token was added to the bucket.
    ///
    /// While this can be thought of as the last time the bucket was partially refilled, it more
    /// specifically is the time that the most recent token was added. For example if the bucket
    /// refills one token every 100 ms, and the bucket is refilled at time 510 ms, the bucket would
    /// gain 5 tokens and the stored time would be 500 ms.
    added_tokens_at: I,
}
41
42impl<I: TokenBucketInstant> TokenBucket<I> {
43    /// A new [`TokenBucket`] with a given `rate` in tokens/second and a `max` token limit.
44    ///
45    /// The bucket will initially be full.
46    /// The value `max` is commonly referred to as the "burst".
47    pub(crate) fn new(config: &TokenBucketConfig, now: I) -> Self {
48        Self {
49            rate: config.rate,
50            bucket_max: config.bucket_max,
51            bucket: config.bucket_max,
52            added_tokens_at: now,
53        }
54    }
55
56    /// Are there no tokens in the bucket?
57    // remove this if we use it in the future
58    #[cfg_attr(not(test), expect(dead_code))]
59    pub(crate) fn is_empty(&self) -> bool {
60        self.bucket == 0
61    }
62
63    /// The maximum number of tokens that this bucket can hold.
64    pub(crate) fn max(&self) -> u64 {
65        self.bucket_max
66    }
67
68    /// Remove `count` tokens from the bucket.
69    // remove this if we use it in the future
70    #[cfg_attr(not(test), expect(dead_code))]
71    pub(crate) fn drain(&mut self, count: u64) -> Result<BecameEmpty, InsufficientTokensError> {
72        Ok(self.claim(count)?.commit())
73    }
74
75    /// Claim a number of tokens.
76    ///
77    /// The claim will be held by the returned [`ClaimedTokens`], and committed when dropped.
78    ///
79    /// **Note:** You probably want to call [`refill()`](Self::refill) before this.
80    // Since the `ClaimedTokens` holds a `&mut` to this `TokenBucket`, we don't need to worry about
81    // other calls accessing the `TokenBucket` before the `ClaimedTokens` are committed.
82    pub(crate) fn claim(
83        &mut self,
84        count: u64,
85    ) -> Result<ClaimedTokens<I>, InsufficientTokensError> {
86        if count > self.bucket {
87            return Err(InsufficientTokensError {
88                available: self.bucket,
89            });
90        }
91
92        Ok(ClaimedTokens::new(self, count))
93    }
94
95    /// Adjust the refill rate and max tokens of the bucket.
96    ///
97    /// The token bucket is refilled up to `now` before changing the rate.
98    ///
99    /// If the new max is smaller than the existing number of tokens,
100    /// the number of tokens will be reduced to the new max.
101    ///
102    /// A rate and/or max of 0 is allowed.
103    pub(crate) fn adjust(&mut self, now: I, config: &TokenBucketConfig) {
104        // make sure that the bucket gets the tokens it is owed before we change the rate
105        self.refill(now);
106
107        // If the old rate was small (or 0), the `refill()` might not have updated
108        // `added_tokens_at`.
109        //
110        // For example if the bucket has a rate of 0 and was last refilled 10 seconds ago, it will
111        // not have gained any tokens in the last 10 seconds. If we were to only update the rate to
112        // 100 tokens/second now, the bucket would immediately become eligible to refill 1000
113        // tokens. We only want the rate change to become effective now, not in the past, so we
114        // ensure this by resetting `added_tokens_at`.
115        self.added_tokens_at = std::cmp::max(self.added_tokens_at, now);
116
117        self.rate = config.rate;
118        self.bucket_max = config.bucket_max;
119        self.bucket = std::cmp::min(self.bucket, self.bucket_max);
120    }
121
122    /// An estimated time at which the bucket will have `tokens` available.
123    ///
124    /// It is not guaranteed that `tokens` will be available at the returned time.
125    ///
126    /// If there are already enough tokens available, a time in the past may be returned.
127    ///
128    /// A value of `None` implies "never",
129    /// for example if the refill rate is 0,
130    /// the bucket max is too small,
131    /// or the time is too large to be represented as an `I`.
132    pub(crate) fn tokens_available_at(&self, tokens: u64) -> Result<I, NeverEnoughTokensError> {
133        let tokens_needed = tokens.saturating_sub(self.bucket);
134
135        // check if we currently have enough tokens before considering refilling
136        if tokens_needed == 0 {
137            return Ok(self.added_tokens_at);
138        }
139
140        // if the rate is 0, we'll never get more tokens
141        if self.rate == 0 {
142            return Err(NeverEnoughTokensError::ZeroRate);
143        }
144
145        // if more tokens are wanted than the capacity of the bucket, we'll never get enough
146        if tokens > self.bucket_max {
147            return Err(NeverEnoughTokensError::ExceedsMaxTokens);
148        }
149
150        // this may underestimate the time if either argument is very large
151        let time_needed = Self::tokens_to_duration(tokens_needed, self.rate)
152            .ok_or(NeverEnoughTokensError::ZeroRate)?;
153
154        // Always return at least 1 microsecond since:
155        // 1. We don't want to return `Duration::ZERO` if the tokens aren't ready,
156        //    which may occur if the rate is very large (<1 ns/token).
157        // 2. Clocks generally don't operate at <1 us resolution.
158        let time_needed = std::cmp::max(time_needed, Duration::from_micros(1));
159
160        self.added_tokens_at
161            .checked_add(time_needed)
162            .ok_or(NeverEnoughTokensError::InstantNotRepresentable)
163    }
164
165    /// Refill the bucket.
166    pub(crate) fn refill(&mut self, now: I) -> BecameNonEmpty {
167        // time since we last added tokens
168        let elapsed = now.saturating_duration_since(self.added_tokens_at);
169
170        // If we exceeded the threshold, update the timestamp and return.
171        // This is taken from tor, which has the comment below:
172        //
173        // > Skip over updates that include an overflow or a very large jump. This can happen for
174        // > platform specific reasons, such as the old ~48 day windows timer.
175        //
176        // It's unclear if this type of OS bug is still common enough that this check is useful,
177        // but it shouldn't hurt.
178        if elapsed > I::IGNORE_THRESHOLD {
179            tracing::debug!(
180                "Time jump of {elapsed:?} is larger than {:?}; not refilling token bucket",
181                I::IGNORE_THRESHOLD,
182            );
183            self.added_tokens_at = now;
184            return BecameNonEmpty::No;
185        }
186
187        let old_bucket = self.bucket;
188
189        // Compute how much we should increment the bucket by.
190        // This may be underestimated in some cases.
191        let bucket_inc = Self::duration_to_tokens(elapsed, self.rate);
192
193        self.bucket = std::cmp::min(self.bucket_max, self.bucket.saturating_add(bucket_inc));
194
195        // Compute how much we should increment the `last_added_tokens` time by. This avoids
196        // drifting if the `bucket_inc` was underestimated, and avoids rounding errors which could
197        // cause the token bucket to effectively use a lower rate. For example if the rate was
198        // "1 token / sec" and the elapsed time was "1.2 sec", we only want to refill 1 token and
199        // increment the time by 1 second.
200        //
201        // While the docs for `tokens_to_duration` say that a smaller than expected duration may be
202        // returned, we have a test `test_duration_token_round_trip` which ensures that
203        // `tokens_to_duration` returns the expected value when used with the result from
204        // `duration_to_tokens`.
205        let added_tokens_at_inc =
206            Self::tokens_to_duration(bucket_inc, self.rate).unwrap_or(Duration::ZERO);
207
208        self.added_tokens_at = self
209            .added_tokens_at
210            .checked_add(added_tokens_at_inc)
211            .expect("overflowed time");
212        debug_assert!(self.added_tokens_at <= now);
213
214        if old_bucket == 0 && self.bucket != 0 {
215            BecameNonEmpty::Yes
216        } else {
217            BecameNonEmpty::No
218        }
219    }
220
221    /// How long would it take to refill `tokens` at `rate`?
222    ///
223    /// The result is rounded up to the nearest microsecond.
224    /// If the number of `tokens` is large,
225    /// the result may be much lower than the expected duration due to saturating 64-bit arithmetic.
226    ///
227    /// `None` will be returned if the `rate` is 0.
228    fn tokens_to_duration(tokens: u64, rate: u64) -> Option<Duration> {
229        // Perform the calculation in microseconds rather than nanoseconds since timers typically
230        // have microsecond granularity, and it lowers the chance that the calculation overflows the
231        // `u64::MAX` limit compared to nanoseconds. In the case that the calculation saturates, the
232        // returned duration will be shorter than the real value.
233        //
234        // For example with `tokens = u64::MAX` and `rate = u64::MAX` we'd expect a result of 1
235        // second, but:
236        // u64::MAX.saturating_mul(1000 * 1000).div_ceil(u64::MAX) = 1 microsecond
237        //
238        // The `div_ceil` ensures we always round up to the nearest microsecond.
239        //
240        // dimensional analysis:
241        // (tokens) * (microseconds / second) / (tokens / second) = microseconds
242        if rate == 0 {
243            return None;
244        }
245        let micros = tokens.saturating_mul(1000 * 1000).div_ceil(rate);
246        Some(Duration::from_micros(micros))
247    }
248
249    /// How many tokens would be refilled within `time` at `rate`?
250    ///
251    /// The `time` is truncated to microsecond granularity.
252    /// If the `time` or `rate` is large,
253    /// the result may be much lower than the expected number of tokens due to saturating 64-bit
254    /// arithmetic.
255    fn duration_to_tokens(time: Duration, rate: u64) -> u64 {
256        let micros = u64::try_from(time.as_micros()).unwrap_or(u64::MAX);
257        // dimensional analysis:
258        // (tokens / second) * (microseconds) / (microseconds / second) = tokens
259        rate.saturating_mul(micros) / (1000 * 1000)
260    }
261}
262
/// The refill rate and token max for a [`TokenBucket`].
///
/// Used both to create a bucket ([`TokenBucket::new`]) and to reconfigure an existing one
/// ([`TokenBucket::adjust`]).
#[derive(Clone, Debug)]
pub(crate) struct TokenBucketConfig {
    /// The refill rate in tokens/second.
    pub(crate) rate: u64,
    /// The max amount of tokens in the bucket.
    /// Commonly referred to as the "burst".
    pub(crate) bucket_max: u64,
}
272
/// A handle to a number of claimed tokens.
///
/// Dropping this handle will commit the claim.
/// The claim can also be shrunk with [`reduce()`](Self::reduce),
/// committed explicitly with [`commit()`](Self::commit),
/// or abandoned with [`discard()`](Self::discard).
#[derive(Debug)]
pub(crate) struct ClaimedTokens<'a, I> {
    /// The bucket that the claim is for.
    ///
    /// Holding the `&mut` borrow prevents any other access to the bucket while the claim exists.
    bucket: &'a mut TokenBucket<I>,
    /// How many tokens to remove from the bucket.
    count: u64,
}
283
284impl<'a, I> ClaimedTokens<'a, I> {
285    /// Create a new [`ClaimedTokens`] that will remove `count` tokens from the token `bucket` when
286    /// dropped.
287    fn new(bucket: &'a mut TokenBucket<I>, count: u64) -> Self {
288        Self { bucket, count }
289    }
290
291    /// Commit the claimed tokens.
292    ///
293    /// This is equivalent to just dropping the [`ClaimedTokens`], but also returns whether the
294    /// token bucket became empty or not.
295    pub(crate) fn commit(mut self) -> BecameEmpty {
296        self.commit_impl()
297    }
298
299    /// Reduce the claim to a fewer number of tokens than the original claim.
300    ///
301    /// If `count` is larger than the original claim, an error will be returned containing the
302    /// current number of claimed tokens.
303    pub(crate) fn reduce(&mut self, count: u64) -> Result<(), InsufficientTokensError> {
304        if count > self.count {
305            return Err(InsufficientTokensError {
306                available: self.count,
307            });
308        }
309
310        self.count = count;
311        Ok(())
312    }
313
314    /// Discard the claim.
315    ///
316    /// This does not remove any tokens from the token bucket.
317    pub(crate) fn discard(mut self) {
318        self.count = 0;
319    }
320
321    /// The commit implementation.
322    ///
323    /// After calling [`commit_impl()`](Self::commit_impl),
324    /// the [`ClaimedTokens`] should no longer be used and should be dropped immediately.
325    fn commit_impl(&mut self) -> BecameEmpty {
326        // when the `ClaimedTokens` was created by the `TokenBucket`, it should have ensured that
327        // there were enough tokens
328        self.bucket.bucket = self
329            .bucket
330            .bucket
331            .checked_sub(self.count)
332            .unwrap_or_else(|| {
333                panic!(
334                    "claim commit failed: {}, {}",
335                    self.count, self.bucket.bucket,
336                )
337            });
338
339        // when `self` is dropped some time after this function ends,
340        // we don't want to subtract again
341        self.count = 0;
342
343        if self.bucket.bucket > 0 {
344            BecameEmpty::No
345        } else {
346            BecameEmpty::Yes
347        }
348    }
349}
350
351impl<'a, I> std::ops::Drop for ClaimedTokens<'a, I> {
352    fn drop(&mut self) {
353        self.commit_impl();
354    }
355}
356
/// An operation was attempted to reduce the number of tokens,
/// but the token bucket did not have enough tokens.
///
/// Also returned when trying to grow a [`ClaimedTokens`] via `reduce()`, in which case
/// `available` is the size of the existing claim.
#[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
#[error("insufficient tokens for operation")]
pub(crate) struct InsufficientTokensError {
    /// The number of tokens that are available to drain/commit.
    available: u64,
}
365
366impl InsufficientTokensError {
367    /// Get the number of tokens that are available to drain/commit.
368    pub(crate) fn available_tokens(&self) -> u64 {
369        self.available
370    }
371}
372
/// The token bucket will never have the requested number of tokens.
///
/// All variants share the same `Display` message; match on the variant to find out why.
#[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
#[error("there will never be enough tokens for this operation")]
pub(crate) enum NeverEnoughTokensError {
    /// The request exceeds the bucket's maximum number of tokens.
    ExceedsMaxTokens,
    /// The refill rate is 0.
    ZeroRate,
    /// The time is not representable.
    ///
    /// For example if the rate is low and a large number of tokens were requested, it may be
    /// so far in the future that it cannot be represented as a time value.
    InstantNotRepresentable,
}
387
/// The token bucket transitioned from "empty" to "non-empty".
///
/// Returned by [`TokenBucket::refill`], which reports `Yes` only when the bucket held no tokens
/// before the refill and holds some afterwards.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum BecameNonEmpty {
    /// Token bucket became non-empty.
    Yes,
    /// Token bucket did not transition from empty to non-empty.
    No,
}
396
/// The token bucket transitioned from "non-empty" to "empty".
///
/// NOTE(review): committing a [`ClaimedTokens`] reports `Yes` whenever the bucket is empty
/// afterwards, including when it was already empty before the commit (for example a claim of
/// 0 tokens on an empty bucket). Confirm that callers don't rely on this being a strict
/// transition.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum BecameEmpty {
    /// Token bucket became empty.
    Yes,
    /// Token bucket remains non-empty.
    No,
}
405
/// A point in time that a [`TokenBucket`] can be driven with.
///
/// Implementors must represent measurements of a monotonically nondecreasing clock.
pub(crate) trait TokenBucketInstant:
    Copy + Clone + Debug + PartialEq + Eq + PartialOrd + Ord
{
    /// The largest time jump that is treated as genuine.
    ///
    /// Any larger change is assumed to come from a broken monotonic clock,
    /// and the bucket will not be refilled for it.
    const IGNORE_THRESHOLD: Duration;

    /// `self + duration`, or `None` if that can't be represented.
    /// Mirrors [`Instant::checked_add`].
    fn checked_add(&self, duration: Duration) -> Option<Self>;

    /// The time elapsed from `earlier` to `self`, or `None` if it can't be computed.
    /// Mirrors [`Instant::checked_duration_since`].
    fn checked_duration_since(&self, earlier: Self) -> Option<Duration>;

    /// Like [`checked_duration_since()`](Self::checked_duration_since), but yields
    /// [`Duration::ZERO`] instead of `None`.
    /// Mirrors [`Instant::saturating_duration_since`].
    fn saturating_duration_since(&self, earlier: Self) -> Duration {
        match self.checked_duration_since(earlier) {
            Some(elapsed) => elapsed,
            None => Duration::ZERO,
        }
    }
}
428
429impl TokenBucketInstant for Instant {
430    // This value is taken from tor (see `elapsed_ticks <= UINT32_MAX/4` in
431    // `src/lib/evloop/token_bucket.c`).
432    const IGNORE_THRESHOLD: Duration = Duration::from_secs((u32::MAX / 4) as u64);
433
434    #[inline]
435    fn checked_add(&self, duration: Duration) -> Option<Self> {
436        self.checked_add(duration)
437    }
438
439    #[inline]
440    fn checked_duration_since(&self, earlier: Self) -> Option<Duration> {
441        self.checked_duration_since(earlier)
442    }
443
444    #[inline]
445    fn saturating_duration_since(&self, earlier: Self) -> Duration {
446        self.saturating_duration_since(earlier)
447    }
448}
449
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]

    use super::*;

    use rand::Rng;

    /// A simulated millisecond-resolution clock, so that tests can drive the bucket
    /// deterministically.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
    struct MillisTimestamp(u64);

    impl TokenBucketInstant for MillisTimestamp {
        const IGNORE_THRESHOLD: Duration = Duration::from_millis(1_000_000_000);

        fn checked_add(&self, duration: Duration) -> Option<Self> {
            // any sub-millisecond part of `duration` is truncated
            let duration = u64::try_from(duration.as_millis()).ok()?;
            self.0.checked_add(duration).map(Self)
        }

        fn checked_duration_since(&self, earlier: Self) -> Option<Duration> {
            Some(Duration::from_millis(self.0.checked_sub(earlier.0)?))
        }
    }

    #[test]
    fn adjust_now() {
        let time = MillisTimestamp(100);

        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, time);
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);
        assert_eq!(tb.rate, 10);

        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);

        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 40,
            },
        );
        assert_eq!(tb.bucket, 40);
        assert_eq!(tb.bucket_max, 40);

        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 40);
        assert_eq!(tb.bucket_max, 100);

        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 200,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 40);
        assert_eq!(tb.bucket_max, 100);
        assert_eq!(tb.rate, 200);
    }

    #[test]
    fn adjust_future() {
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(100));
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);
        assert_eq!(tb.rate, 10);

        // at 300 ms: increase rate and max; bucket was already full, so doesn't gain any tokens
        tb.adjust(
            MillisTimestamp(300),
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 200);

        // at 500 ms: no changes; bucket is refilled during `adjust()`, so gains 4 tokens
        tb.adjust(
            MillisTimestamp(500),
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 104);
        assert_eq!(tb.bucket_max, 200);

        // at 700 ms: lower rate and max; bucket is lowered to new max, so loses 4 tokens
        tb.adjust(
            MillisTimestamp(700),
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);

        // at 900 ms: raise rate and max; rate was previously 0 so doesn't gain any tokens
        tb.adjust(
            MillisTimestamp(900),
            &TokenBucketConfig {
                rate: 100,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 200);
    }

    #[test]
    fn adjust_zero() {
        let time = MillisTimestamp(100);

        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };

        let mut tb = TokenBucket::new(&config, time);
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 200);
        assert_eq!(tb.rate, 0);
        // bucket should not increase
        tb.refill(MillisTimestamp(10_000_000));
        assert_eq!(tb.bucket, 100);

        let mut tb = TokenBucket::new(&config, time);
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 10,
                bucket_max: 0,
            },
        );
        assert_eq!(tb.bucket, 0);
        assert_eq!(tb.bucket_max, 0);
        assert_eq!(tb.rate, 10);
        // bucket should stay empty
        tb.refill(MillisTimestamp(10_000_000));
        assert_eq!(tb.bucket, 0);

        let mut tb = TokenBucket::new(&config, time);
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 0,
            },
        );
        assert_eq!(tb.bucket, 0);
        assert_eq!(tb.bucket_max, 0);
        assert_eq!(tb.rate, 0);
        // bucket should stay empty
        tb.refill(MillisTimestamp(10_000_000));
        assert_eq!(tb.bucket, 0);
    }

    #[test]
    fn is_empty() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(100));
        assert!(!tb.is_empty());

        tb.drain(99).unwrap();
        assert!(!tb.is_empty());

        tb.drain(1).unwrap();
        assert!(tb.is_empty());

        tb.refill(MillisTimestamp(199));
        assert!(tb.is_empty());

        tb.refill(MillisTimestamp(200));
        assert!(!tb.is_empty());
    }

    #[test]
    fn correctness() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(100));

        tb.drain(50).unwrap();
        assert_eq!(tb.bucket, 50);

        tb.refill(MillisTimestamp(1100));
        assert_eq!(tb.bucket, 60);

        tb.drain(50).unwrap();
        assert_eq!(tb.bucket, 10);

        tb.refill(MillisTimestamp(2100));
        assert_eq!(tb.bucket, 20);

        tb.refill(MillisTimestamp(2101));
        assert_eq!(tb.bucket, 20);
        tb.refill(MillisTimestamp(2199));
        assert_eq!(tb.bucket, 20);
        tb.refill(MillisTimestamp(2200));
        assert_eq!(tb.bucket, 21);
    }

    #[test]
    fn rounding() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(0));
        tb.drain(100).unwrap();

        // ensure that refilling at 150 ms does not change the `added_tokens_at` time to 150 ms,
        // otherwise the next refill wouldn't occur until 250 ms instead of 200 ms
        tb.refill(MillisTimestamp(99));
        assert_eq!(tb.bucket, 0);
        tb.refill(MillisTimestamp(150));
        assert_eq!(tb.bucket, 1);
        tb.refill(MillisTimestamp(199));
        assert_eq!(tb.bucket, 1);
        tb.refill(MillisTimestamp(200));
        assert_eq!(tb.bucket, 2);
    }

    #[test]
    fn tokens_available_at() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(0));

        // bucket is empty at 0 ms, next token at 100 ms
        tb.drain(100).unwrap();

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(0)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(200)));

        // bucket is still empty at 40 ms, next token at 100 ms
        tb.refill(MillisTimestamp(40));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(0)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(200)));

        // bucket has 1 token at 100 ms, next token at 200 ms
        tb.refill(MillisTimestamp(100));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(200)));

        // bucket is empty at 100 ms, next token at 200 ms
        tb.drain(1).unwrap();

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(300)));

        // bucket is empty at 140 ms, next token at 200 ms
        tb.refill(MillisTimestamp(140));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(300)));

        // bucket has 1 token at 210 ms, next token at 300 ms
        tb.refill(MillisTimestamp(210));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(300)));

        use NeverEnoughTokensError as NETE;

        assert_eq!(tb.tokens_available_at(100), Ok(MillisTimestamp(10_100)));
        assert_eq!(tb.tokens_available_at(101), Err(NETE::ExceedsMaxTokens));
        assert_eq!(
            tb.tokens_available_at(u64::MAX),
            Err(NETE::ExceedsMaxTokens),
        );

        // set the refill rate to 0; note that adjusting the rate also resets `added_tokens_at`
        tb.adjust(
            MillisTimestamp(210),
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 100,
            },
        );

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(210)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(210)));
        assert_eq!(tb.tokens_available_at(2), Err(NETE::ZeroRate));
    }

    #[test]
    fn test_duration_token_round_trip() {
        let tokens_to_duration = TokenBucket::<Instant>::tokens_to_duration;
        let duration_to_tokens = TokenBucket::<Instant>::duration_to_tokens;

        // start with some hand-picked cases
        let mut duration_rate_pairs = vec![
            (Duration::from_nanos(0), 1),
            (Duration::from_nanos(1), 1),
            (Duration::from_micros(2), 1),
            (Duration::MAX, 1),
            (Duration::from_nanos(0), 3),
            (Duration::from_nanos(1), 3),
            (Duration::from_micros(2), 3),
            (Duration::MAX, 3),
            (Duration::from_nanos(0), 1000),
            (Duration::from_nanos(1), 1000),
            (Duration::from_micros(2), 1000),
            (Duration::MAX, 1000),
            (Duration::from_nanos(0), u64::MAX),
            (Duration::from_nanos(1), u64::MAX),
            (Duration::from_micros(2), u64::MAX),
            (Duration::MAX, u64::MAX),
        ];

        let mut rng = rand::rng();

        // add some fuzzing
        for _ in 0..10_000 {
            let secs: u64 = rng.random();
            let nanos: u32 = rng.random();
            // This is the same value as `Duration::new(secs, nanos)`, but built with checked
            // arithmetic: `Duration::new` panics if carrying the excess nanoseconds into the
            // seconds overflows, and catching that panic would be noisy (and wouldn't work at
            // all with `panic = "abort"`). Unrepresentable cases are skipped.
            let Some(random_duration) =
                Duration::from_secs(secs).checked_add(Duration::from_nanos(u64::from(nanos)))
            else {
                continue;
            };
            let random_rate = rng.random();
            duration_rate_pairs.push((random_duration, random_rate));
        }

        // for various combinations of durations and rates, we ensure that after an initial
        // `duration_to_tokens` calculation which may truncate, a round-trip between
        // `tokens_to_duration` and `duration_to_tokens` isn't lossy
        for (original_duration, rate) in duration_rate_pairs {
            // this may give a smaller number of tokens than expected (see docs on
            // `TokenBucket::duration_to_tokens`)
            let tokens = duration_to_tokens(original_duration, rate);

            // we want to ensure that converting these `tokens` to a duration and then back to
            // tokens is not lossy, which implies that `tokens_to_duration` is returning the
            // expected value and not an underestimate
            let duration = tokens_to_duration(tokens, rate).unwrap();
            assert_eq!(tokens, duration_to_tokens(duration, rate));
        }
    }
}