tor_proto/util/token_bucket/bucket.rs
//! A token bucket implementation.

use std::fmt::Debug;
use std::time::{Duration, Instant};

/// A token bucket.
///
/// Calculations are performed at microsecond resolution.
/// You likely want to call [`refill()`](Self::refill) before each access or operation on the
/// token bucket.
///
/// This is partially inspired by tor's `token_bucket_ctr_t`,
/// but the implementation is quite a bit different.
/// We use larger values here (for example `u64`),
/// and we aim to avoid drift when refills occur at times that aren't exactly in period with the
/// refill rate.
///
/// It's possible that we could relax these requirements to reduce memory usage and computational
/// cost, but that optimization should probably only be made if/when needed, since it would make
/// the code more difficult to reason about, and possibly more complex.
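///
/// A minimal usage sketch (assuming the [`TokenBucketInstant`] impl for [`Instant`] below;
/// the rate and burst values are arbitrary):
///
/// ```ignore
/// let config = TokenBucketConfig { rate: 100, bucket_max: 500 };
/// let mut bucket = TokenBucket::new(&config, Instant::now());
///
/// // later, before spending 10 tokens on some rate-limited operation:
/// bucket.refill(Instant::now());
/// match bucket.drain(10) {
///     Ok(_) => { /* proceed with the operation */ }
///     Err(e) => { /* only `e.available_tokens()` tokens are available; try again later */ }
/// }
/// ```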
#[derive(Debug)]
pub(crate) struct TokenBucket<I> {
    /// The refill rate in tokens/second.
    rate: u64,
    /// The maximum number of tokens in the bucket.
    /// Commonly referred to as the "burst".
    bucket_max: u64,
    /// Current number of tokens in the bucket.
    // It's possible that in the future we may want a token bucket to allow negative values. For
    // example we might want to send a few extra bytes over the allowed limit if it would mean that
    // we send a complete TLS record.
    bucket: u64,
    /// Time that the most recent token was added to the bucket.
    ///
    /// While this can be thought of as the last time the bucket was partially refilled, it is
    /// more specifically the time that the most recent token was added. For example if the bucket
    /// refills one token every 100 ms, and the bucket is refilled at time 510 ms, the bucket would
    /// gain 5 tokens and the stored time would be 500 ms.
    added_tokens_at: I,
}

impl<I: TokenBucketInstant> TokenBucket<I> {
    /// A new [`TokenBucket`] with the rate (in tokens/second) and max token limit given by
    /// `config`.
    ///
    /// The bucket will initially be full.
    /// The max is commonly referred to as the "burst".
    pub(crate) fn new(config: &TokenBucketConfig, now: I) -> Self {
        Self {
            rate: config.rate,
            bucket_max: config.bucket_max,
            bucket: config.bucket_max,
            added_tokens_at: now,
        }
    }

    /// Are there no tokens in the bucket?
    // remove this attribute if we use this method in the future
    #[cfg_attr(not(test), expect(dead_code))]
    pub(crate) fn is_empty(&self) -> bool {
        self.bucket == 0
    }

    /// The maximum number of tokens that this bucket can hold.
    pub(crate) fn max(&self) -> u64 {
        self.bucket_max
    }

    /// Remove `count` tokens from the bucket.
    ///
    /// If fewer than `count` tokens are available, an error is returned and no tokens are removed.
    // remove this attribute if we use this method in the future
    #[cfg_attr(not(test), expect(dead_code))]
    pub(crate) fn drain(&mut self, count: u64) -> Result<BecameEmpty, InsufficientTokensError> {
        Ok(self.claim(count)?.commit())
    }

    /// Claim a number of tokens.
    ///
    /// The claim will be held by the returned [`ClaimedTokens`], and committed when dropped.
    ///
    /// **Note:** You probably want to call [`refill()`](Self::refill) before this.
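    ///
    /// A sketch of the claim/commit flow (error handling elided):
    ///
    /// ```ignore
    /// let mut claim = bucket.claim(64)?;
    /// // we only ended up needing 48 of the 64 claimed tokens
    /// claim.reduce(48)?;
    /// // remove the 48 tokens from the bucket
    /// claim.commit();
    /// ```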
    // Since the `ClaimedTokens` holds a `&mut` to this `TokenBucket`, we don't need to worry about
    // other calls accessing the `TokenBucket` before the `ClaimedTokens` are committed.
    pub(crate) fn claim(
        &mut self,
        count: u64,
    ) -> Result<ClaimedTokens<I>, InsufficientTokensError> {
        if count > self.bucket {
            return Err(InsufficientTokensError {
                available: self.bucket,
            });
        }

        Ok(ClaimedTokens::new(self, count))
    }

    /// Adjust the refill rate and max tokens of the bucket.
    ///
    /// The token bucket is refilled up to `now` before changing the rate.
    ///
    /// If the new max is smaller than the existing number of tokens,
    /// the number of tokens will be reduced to the new max.
    ///
    /// A rate and/or max of 0 is allowed.
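    ///
    /// A sketch of a rate change (the values are arbitrary):
    ///
    /// ```ignore
    /// // halve the rate and burst; any tokens above the new max are discarded
    /// bucket.adjust(Instant::now(), &TokenBucketConfig { rate: 50, bucket_max: 250 });
    /// ```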
    pub(crate) fn adjust(&mut self, now: I, config: &TokenBucketConfig) {
        // make sure that the bucket gets the tokens it is owed before we change the rate
        self.refill(now);

        // If the old rate was small (or 0), the `refill()` might not have updated
        // `added_tokens_at`.
        //
        // For example if the bucket has a rate of 0 and was last refilled 10 seconds ago, it will
        // not have gained any tokens in the last 10 seconds. If we were to only update the rate to
        // 100 tokens/second now, the bucket would immediately become eligible to refill 1000
        // tokens. We only want the rate change to become effective now, not in the past, so we
        // ensure this by resetting `added_tokens_at`.
        self.added_tokens_at = std::cmp::max(self.added_tokens_at, now);

        self.rate = config.rate;
        self.bucket_max = config.bucket_max;
        self.bucket = std::cmp::min(self.bucket, self.bucket_max);
    }

    /// An estimated time at which the bucket will have `tokens` available.
    ///
    /// It is not guaranteed that `tokens` will be available at the returned time.
    ///
    /// If there are already enough tokens available, a time in the past may be returned.
    ///
    /// An error implies "never",
    /// for example if the refill rate is 0,
    /// the bucket max is too small,
    /// or the time is too large to be represented as an `I`.
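    ///
    /// For example, with a rate of 10 tokens/second and an empty bucket whose most recent token
    /// was added at time `t`, `tokens_available_at(5)` would return `t + 500 ms`:
    ///
    /// ```ignore
    /// match bucket.tokens_available_at(5) {
    ///     Ok(t) => { /* sleep until `t`, then `refill()` and try the claim */ }
    ///     Err(_) => { /* the bucket will never hold 5 tokens */ }
    /// }
    /// ```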
    pub(crate) fn tokens_available_at(&self, tokens: u64) -> Result<I, NeverEnoughTokensError> {
        let tokens_needed = tokens.saturating_sub(self.bucket);

        // check if we currently have enough tokens before considering refilling
        if tokens_needed == 0 {
            return Ok(self.added_tokens_at);
        }

        // if the rate is 0, we'll never get more tokens
        if self.rate == 0 {
            return Err(NeverEnoughTokensError::ZeroRate);
        }

        // if more tokens are wanted than the capacity of the bucket, we'll never get enough
        if tokens > self.bucket_max {
            return Err(NeverEnoughTokensError::ExceedsMaxTokens);
        }

        // this may underestimate the time if either argument is very large
        let time_needed = Self::tokens_to_duration(tokens_needed, self.rate)
            .ok_or(NeverEnoughTokensError::ZeroRate)?;

        // Always return at least 1 microsecond since:
        // 1. We don't want to return `Duration::ZERO` if the tokens aren't ready,
        //    which may occur if the rate is very large (<1 ns/token).
        // 2. Clocks generally don't operate at <1 us resolution.
        let time_needed = std::cmp::max(time_needed, Duration::from_micros(1));

        self.added_tokens_at
            .checked_add(time_needed)
            .ok_or(NeverEnoughTokensError::InstantNotRepresentable)
    }

    /// Refill the bucket with any tokens accrued up to `now`.
    ///
    /// Returns whether the bucket transitioned from "empty" to "non-empty".
    pub(crate) fn refill(&mut self, now: I) -> BecameNonEmpty {
        // time since we last added tokens
        let elapsed = now.saturating_duration_since(self.added_tokens_at);

        // If the elapsed time exceeds the threshold, update the timestamp and return without
        // refilling.
        // This is taken from tor, which has the comment below:
        //
        // > Skip over updates that include an overflow or a very large jump. This can happen for
        // > platform specific reasons, such as the old ~48 day windows timer.
        //
        // It's unclear if this type of OS bug is still common enough that this check is useful,
        // but it shouldn't hurt.
        if elapsed > I::IGNORE_THRESHOLD {
            tracing::debug!(
                "Time jump of {elapsed:?} is larger than {:?}; not refilling token bucket",
                I::IGNORE_THRESHOLD,
            );
            self.added_tokens_at = now;
            return BecameNonEmpty::No;
        }

        let old_bucket = self.bucket;

        // Compute how much we should increment the bucket by.
        // This may be underestimated in some cases.
        let bucket_inc = Self::duration_to_tokens(elapsed, self.rate);

        self.bucket = std::cmp::min(self.bucket_max, self.bucket.saturating_add(bucket_inc));

        // Compute how much we should increment the `added_tokens_at` time by. This avoids
        // drifting if the `bucket_inc` was underestimated, and avoids rounding errors which could
        // cause the token bucket to effectively use a lower rate. For example if the rate was
        // "1 token / sec" and the elapsed time was "1.2 sec", we only want to refill 1 token and
        // increment the time by 1 second.
        //
        // While the docs for `tokens_to_duration` say that a smaller than expected duration may be
        // returned, we have a test `test_duration_token_round_trip` which ensures that
        // `tokens_to_duration` returns the expected value when used with the result from
        // `duration_to_tokens`.
        let added_tokens_at_inc =
            Self::tokens_to_duration(bucket_inc, self.rate).unwrap_or(Duration::ZERO);

        self.added_tokens_at = self
            .added_tokens_at
            .checked_add(added_tokens_at_inc)
            .expect("overflowed time");
        debug_assert!(self.added_tokens_at <= now);

        if old_bucket == 0 && self.bucket != 0 {
            BecameNonEmpty::Yes
        } else {
            BecameNonEmpty::No
        }
    }

    /// How long would it take to refill `tokens` at `rate`?
    ///
    /// The result is rounded up to the nearest microsecond.
    /// If the number of `tokens` is large,
    /// the result may be much lower than the expected duration due to saturating 64-bit
    /// arithmetic.
    ///
    /// `None` will be returned if the `rate` is 0.
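    ///
    /// A couple of worked cases (a sketch; this is a private helper, so these lines are not run
    /// as a doctest):
    ///
    /// ```ignore
    /// // 5 tokens at 10 tokens/second takes half a second
    /// assert_eq!(tokens_to_duration(5, 10), Some(Duration::from_millis(500)));
    /// // rounding up: even at the max rate, one token takes a full microsecond
    /// assert_eq!(tokens_to_duration(1, u64::MAX), Some(Duration::from_micros(1)));
    /// ```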
    fn tokens_to_duration(tokens: u64, rate: u64) -> Option<Duration> {
        // Perform the calculation in microseconds rather than nanoseconds since timers typically
        // have microsecond granularity, and it lowers the chance that the calculation overflows
        // the `u64::MAX` limit compared to nanoseconds. In the case that the calculation
        // saturates, the returned duration will be shorter than the real value.
        //
        // For example with `tokens = u64::MAX` and `rate = u64::MAX` we'd expect a result of 1
        // second, but:
        // u64::MAX.saturating_mul(1000 * 1000).div_ceil(u64::MAX) = 1 microsecond
        //
        // The `div_ceil` ensures we always round up to the nearest microsecond.
        //
        // dimensional analysis:
        // (tokens) * (microseconds / second) / (tokens / second) = microseconds
        if rate == 0 {
            return None;
        }
        let micros = tokens.saturating_mul(1000 * 1000).div_ceil(rate);
        Some(Duration::from_micros(micros))
    }

    /// How many tokens would be refilled within `time` at `rate`?
    ///
    /// The `time` is truncated to microsecond granularity.
    /// If the `time` or `rate` is large,
    /// the result may be much lower than the expected number of tokens due to saturating 64-bit
    /// arithmetic.
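    ///
    /// A worked case (a sketch; this is a private helper, so not run as a doctest):
    ///
    /// ```ignore
    /// // 1.25 seconds at 10 tokens/second yields 12 tokens; the half token is truncated
    /// assert_eq!(duration_to_tokens(Duration::from_millis(1250), 10), 12);
    /// ```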
    fn duration_to_tokens(time: Duration, rate: u64) -> u64 {
        let micros = u64::try_from(time.as_micros()).unwrap_or(u64::MAX);
        // dimensional analysis:
        // (tokens / second) * (microseconds) / (microseconds / second) = tokens
        rate.saturating_mul(micros) / (1000 * 1000)
    }
}

/// The refill rate and token max for a [`TokenBucket`].
#[derive(Clone, Debug)]
pub(crate) struct TokenBucketConfig {
    /// The refill rate in tokens/second.
    pub(crate) rate: u64,
    /// The maximum number of tokens in the bucket.
    /// Commonly referred to as the "burst".
    pub(crate) bucket_max: u64,
}

/// A handle to a number of claimed tokens.
///
/// Dropping this handle will commit the claim.
#[derive(Debug)]
pub(crate) struct ClaimedTokens<'a, I> {
    /// The bucket that the claim is for.
    bucket: &'a mut TokenBucket<I>,
    /// How many tokens to remove from the bucket.
    count: u64,
}

impl<'a, I> ClaimedTokens<'a, I> {
    /// Create a new [`ClaimedTokens`] that will remove `count` tokens from the token `bucket` when
    /// dropped.
    fn new(bucket: &'a mut TokenBucket<I>, count: u64) -> Self {
        Self { bucket, count }
    }

    /// Commit the claimed tokens.
    ///
    /// This is equivalent to just dropping the [`ClaimedTokens`], but also returns whether the
    /// token bucket became empty or not.
    pub(crate) fn commit(mut self) -> BecameEmpty {
        self.commit_impl()
    }

    /// Reduce the claim to fewer tokens than were originally claimed.
    ///
    /// If `count` is larger than the current claim, an error will be returned containing the
    /// current number of claimed tokens.
    pub(crate) fn reduce(&mut self, count: u64) -> Result<(), InsufficientTokensError> {
        if count > self.count {
            return Err(InsufficientTokensError {
                available: self.count,
            });
        }

        self.count = count;
        Ok(())
    }

    /// Discard the claim.
    ///
    /// This does not remove any tokens from the token bucket.
    pub(crate) fn discard(mut self) {
        self.count = 0;
    }

    /// The commit implementation.
    ///
    /// After calling [`commit_impl()`](Self::commit_impl),
    /// the [`ClaimedTokens`] should no longer be used and should be dropped immediately.
    fn commit_impl(&mut self) -> BecameEmpty {
        // when the `ClaimedTokens` was created by the `TokenBucket`, it should have ensured that
        // there were enough tokens
        self.bucket.bucket = self
            .bucket
            .bucket
            .checked_sub(self.count)
            .unwrap_or_else(|| {
                panic!(
                    "claim commit failed: claimed {} tokens, but the bucket only has {}",
                    self.count, self.bucket.bucket,
                )
            });

        // when `self` is dropped some time after this function ends,
        // we don't want to subtract again
        self.count = 0;

        if self.bucket.bucket > 0 {
            BecameEmpty::No
        } else {
            BecameEmpty::Yes
        }
    }
}

impl<'a, I> std::ops::Drop for ClaimedTokens<'a, I> {
    fn drop(&mut self) {
        self.commit_impl();
    }
}

/// An operation attempted to take more tokens than the token bucket (or claim) had available.
#[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
#[error("insufficient tokens for operation")]
pub(crate) struct InsufficientTokensError {
    /// The number of tokens that are available to drain/commit.
    available: u64,
}

impl InsufficientTokensError {
    /// Get the number of tokens that are available to drain/commit.
    pub(crate) fn available_tokens(&self) -> u64 {
        self.available
    }
}

/// The token bucket will never have the requested number of tokens.
#[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
#[error("there will never be enough tokens for this operation")]
pub(crate) enum NeverEnoughTokensError {
    /// The request exceeds the bucket's maximum number of tokens.
    ExceedsMaxTokens,
    /// The refill rate is 0.
    ZeroRate,
    /// The time is not representable.
    ///
    /// For example if the rate is low and a large number of tokens were requested, the time may
    /// be too far in the future to be represented as a time value.
    InstantNotRepresentable,
}

/// The token bucket transitioned from "empty" to "non-empty".
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum BecameNonEmpty {
    /// Token bucket became non-empty.
    Yes,
    /// Token bucket remains empty.
    No,
}

/// The token bucket transitioned from "non-empty" to "empty".
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum BecameEmpty {
    /// Token bucket became empty.
    Yes,
    /// Token bucket remains non-empty.
    No,
}

/// Any type implementing this trait must represent a measurement of a monotonically
/// nondecreasing clock.
pub(crate) trait TokenBucketInstant:
    Copy + Clone + Debug + PartialEq + Eq + PartialOrd + Ord
{
    /// An unrealistically large time jump.
    ///
    /// We assume that any time change larger than this indicates a broken monotonic clock,
    /// and the bucket will not be refilled.
    const IGNORE_THRESHOLD: Duration;

    /// See [`Instant::checked_add`].
    fn checked_add(&self, duration: Duration) -> Option<Self>;

    /// See [`Instant::checked_duration_since`].
    fn checked_duration_since(&self, earlier: Self) -> Option<Duration>;

    /// See [`Instant::saturating_duration_since`].
    fn saturating_duration_since(&self, earlier: Self) -> Duration {
        self.checked_duration_since(earlier).unwrap_or_default()
    }
}

impl TokenBucketInstant for Instant {
    // This value is taken from tor (see `elapsed_ticks <= UINT32_MAX/4` in
    // `src/lib/evloop/token_bucket.c`).
    const IGNORE_THRESHOLD: Duration = Duration::from_secs((u32::MAX / 4) as u64);

    #[inline]
    fn checked_add(&self, duration: Duration) -> Option<Self> {
        self.checked_add(duration)
    }

    #[inline]
    fn checked_duration_since(&self, earlier: Self) -> Option<Duration> {
        self.checked_duration_since(earlier)
    }

    #[inline]
    fn saturating_duration_since(&self, earlier: Self) -> Duration {
        self.saturating_duration_since(earlier)
    }
}

#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]

    use super::*;

    use rand::Rng;

    /// A simple millisecond-resolution timestamp for tests.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
    struct MillisTimestamp(u64);

    impl TokenBucketInstant for MillisTimestamp {
        const IGNORE_THRESHOLD: Duration = Duration::from_millis(1_000_000_000);

        fn checked_add(&self, duration: Duration) -> Option<Self> {
            let duration = u64::try_from(duration.as_millis()).ok()?;
            self.0.checked_add(duration).map(Self)
        }

        fn checked_duration_since(&self, earlier: Self) -> Option<Duration> {
            Some(Duration::from_millis(self.0.checked_sub(earlier.0)?))
        }
    }

    #[test]
    fn adjust_now() {
        let time = MillisTimestamp(100);

        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, time);
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);
        assert_eq!(tb.rate, 10);

        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);

        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 40,
            },
        );
        assert_eq!(tb.bucket, 40);
        assert_eq!(tb.bucket_max, 40);

        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 40);
        assert_eq!(tb.bucket_max, 100);

        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 200,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 40);
        assert_eq!(tb.bucket_max, 100);
        assert_eq!(tb.rate, 200);
    }

    #[test]
    fn adjust_future() {
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(100));
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);
        assert_eq!(tb.rate, 10);

        // at 300 ms: increase rate and max; bucket was already full, so doesn't gain any tokens
        tb.adjust(
            MillisTimestamp(300),
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 200);

        // at 500 ms: no changes; bucket is refilled during `adjust()`, so gains 4 tokens
        tb.adjust(
            MillisTimestamp(500),
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 104);
        assert_eq!(tb.bucket_max, 200);

        // at 700 ms: lower rate and max; bucket is lowered to new max, so loses 4 tokens
        tb.adjust(
            MillisTimestamp(700),
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);

        // at 900 ms: raise rate and max; rate was previously 0 so doesn't gain any tokens
        tb.adjust(
            MillisTimestamp(900),
            &TokenBucketConfig {
                rate: 100,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 200);
    }

    #[test]
    fn adjust_zero() {
        let time = MillisTimestamp(100);

        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };

        let mut tb = TokenBucket::new(&config, time);
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 200);
        assert_eq!(tb.rate, 0);
        // bucket should not increase
        tb.refill(MillisTimestamp(10_000_000));
        assert_eq!(tb.bucket, 100);

        let mut tb = TokenBucket::new(&config, time);
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 10,
                bucket_max: 0,
            },
        );
        assert_eq!(tb.bucket, 0);
        assert_eq!(tb.bucket_max, 0);
        assert_eq!(tb.rate, 10);
        // bucket should stay empty
        tb.refill(MillisTimestamp(10_000_000));
        assert_eq!(tb.bucket, 0);

        let mut tb = TokenBucket::new(&config, time);
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 0,
            },
        );
        assert_eq!(tb.bucket, 0);
        assert_eq!(tb.bucket_max, 0);
        assert_eq!(tb.rate, 0);
        // bucket should stay empty
        tb.refill(MillisTimestamp(10_000_000));
        assert_eq!(tb.bucket, 0);
    }

    #[test]
    fn is_empty() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(100));
        assert!(!tb.is_empty());

        tb.drain(99).unwrap();
        assert!(!tb.is_empty());

        tb.drain(1).unwrap();
        assert!(tb.is_empty());

        tb.refill(MillisTimestamp(199));
        assert!(tb.is_empty());

        tb.refill(MillisTimestamp(200));
        assert!(!tb.is_empty());
    }

    #[test]
    fn correctness() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(100));

        tb.drain(50).unwrap();
        assert_eq!(tb.bucket, 50);

        tb.refill(MillisTimestamp(1100));
        assert_eq!(tb.bucket, 60);

        tb.drain(50).unwrap();
        assert_eq!(tb.bucket, 10);

        tb.refill(MillisTimestamp(2100));
        assert_eq!(tb.bucket, 20);

        tb.refill(MillisTimestamp(2101));
        assert_eq!(tb.bucket, 20);
        tb.refill(MillisTimestamp(2199));
        assert_eq!(tb.bucket, 20);
        tb.refill(MillisTimestamp(2200));
        assert_eq!(tb.bucket, 21);
    }

    #[test]
    fn rounding() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(0));
        tb.drain(100).unwrap();

        // ensure that refilling at 150 ms does not change the `added_tokens_at` time to 150 ms,
        // otherwise the next refill wouldn't occur until 250 ms instead of 200 ms
        tb.refill(MillisTimestamp(99));
        assert_eq!(tb.bucket, 0);
        tb.refill(MillisTimestamp(150));
        assert_eq!(tb.bucket, 1);
        tb.refill(MillisTimestamp(199));
        assert_eq!(tb.bucket, 1);
        tb.refill(MillisTimestamp(200));
        assert_eq!(tb.bucket, 2);
    }

    #[test]
    fn tokens_available_at() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(0));

        // bucket is empty at 0 ms, next token at 100 ms
        tb.drain(100).unwrap();

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(0)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(200)));

        // bucket is still empty at 40 ms, next token at 100 ms
        tb.refill(MillisTimestamp(40));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(0)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(200)));

        // bucket has 1 token at 100 ms, next token at 200 ms
        tb.refill(MillisTimestamp(100));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(200)));

        // bucket is empty at 100 ms, next token at 200 ms
        tb.drain(1).unwrap();

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(300)));

        // bucket is empty at 140 ms, next token at 200 ms
        tb.refill(MillisTimestamp(140));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(300)));

        // bucket has 1 token at 210 ms, next token at 300 ms
        tb.refill(MillisTimestamp(210));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(300)));

        use NeverEnoughTokensError as NETE;

        assert_eq!(tb.tokens_available_at(100), Ok(MillisTimestamp(10_100)));
        assert_eq!(tb.tokens_available_at(101), Err(NETE::ExceedsMaxTokens));
        assert_eq!(
            tb.tokens_available_at(u64::MAX),
            Err(NETE::ExceedsMaxTokens),
        );

        // set the refill rate to 0; note that adjusting the rate also resets `added_tokens_at`
        tb.adjust(
            MillisTimestamp(210),
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 100,
            },
        );

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(210)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(210)));
        assert_eq!(tb.tokens_available_at(2), Err(NETE::ZeroRate));
    }

    #[test]
    fn test_duration_token_round_trip() {
        let tokens_to_duration = TokenBucket::<Instant>::tokens_to_duration;
        let duration_to_tokens = TokenBucket::<Instant>::duration_to_tokens;

        // start with some hand-picked cases
        let mut duration_rate_pairs = vec![
            (Duration::from_nanos(0), 1),
            (Duration::from_nanos(1), 1),
            (Duration::from_micros(2), 1),
            (Duration::MAX, 1),
            (Duration::from_nanos(0), 3),
            (Duration::from_nanos(1), 3),
            (Duration::from_micros(2), 3),
            (Duration::MAX, 3),
            (Duration::from_nanos(0), 1000),
            (Duration::from_nanos(1), 1000),
            (Duration::from_micros(2), 1000),
            (Duration::MAX, 1000),
            (Duration::from_nanos(0), u64::MAX),
            (Duration::from_nanos(1), u64::MAX),
            (Duration::from_micros(2), u64::MAX),
            (Duration::MAX, u64::MAX),
        ];

        let mut rng = rand::rng();

        // add some fuzzing
        for _ in 0..10_000 {
            let secs = rng.random();
            let nanos = rng.random();
            // `Duration::new()` may panic, so just skip if there's a panic rather than trying to
            // write our own logic to avoid the panic in the first place
            let Ok(random_duration) = std::panic::catch_unwind(|| Duration::new(secs, nanos))
            else {
                continue;
            };
            let random_rate = rng.random();
            duration_rate_pairs.push((random_duration, random_rate));
        }

        // for various combinations of durations and rates, we ensure that after an initial
        // `duration_to_tokens` calculation which may truncate, a round-trip between
        // `tokens_to_duration` and `duration_to_tokens` isn't lossy
        for (original_duration, rate) in duration_rate_pairs {
            // this may give a smaller number of tokens than expected (see docs on
            // `TokenBucket::duration_to_tokens`)
            let tokens = duration_to_tokens(original_duration, rate);

            // we want to ensure that converting these `tokens` to a duration and then back to
            // tokens is not lossy, which implies that `tokens_to_duration` is returning the
            // expected value and not a truncated value due to saturating arithmetic
            let duration = tokens_to_duration(tokens, rate).unwrap();
            assert_eq!(tokens, duration_to_tokens(duration, rate));
        }
    }
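
    /// An illustrative sketch of the claim API: discarding, reducing, and committing claims.
    /// (Added as an example; it exercises only behavior documented above.)
    #[test]
    fn claim_reduce_discard() {
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(0));

        // a discarded claim removes no tokens
        let claim = tb.claim(30).unwrap();
        claim.discard();
        assert_eq!(tb.bucket, 100);

        // a reduced claim removes only the reduced number of tokens
        let mut claim = tb.claim(30).unwrap();
        claim.reduce(10).unwrap();
        assert_eq!(claim.commit(), BecameEmpty::No);
        assert_eq!(tb.bucket, 90);

        // a claim cannot be "reduced" to a larger number of tokens
        let mut claim = tb.claim(30).unwrap();
        assert_eq!(claim.reduce(40).unwrap_err().available_tokens(), 30);
        // dropping the claim commits the original 30 tokens
        drop(claim);
        assert_eq!(tb.bucket, 60);

        // committing a claim for all remaining tokens empties the bucket
        let claim = tb.claim(60).unwrap();
        assert_eq!(claim.commit(), BecameEmpty::Yes);
        assert!(tb.is_empty());

        // claiming more than is available fails
        assert_eq!(tb.claim(1).unwrap_err().available_tokens(), 0);
    }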
}