tor_proto/util/token_bucket/
writer.rs

1//! An [`AsyncWrite`] rate limiter.
2
3use std::future::Future;
4use std::num::NonZero;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use std::time::{Duration, Instant};
8
9use futures::io::Error;
10use futures::AsyncWrite;
11use sync_wrapper::SyncFuture;
12use tor_rtcompat::SleepProvider;
13
14use super::bucket::{NeverEnoughTokensError, TokenBucket, TokenBucketConfig};
15
/// A rate-limited async [writer](AsyncWrite).
///
/// This can be used as a wrapper around an existing [`AsyncWrite`] writer.
#[derive(educe::Educe)]
#[educe(Debug)]
#[pin_project::pin_project]
pub(crate) struct RateLimitedWriter<W: AsyncWrite, P: SleepProvider> {
    /// The token bucket.
    bucket: TokenBucket<Instant>,
    /// The sleep provider, for getting the current time and creating new sleep futures.
    ///
    /// While we use [`Instant`] for the time, we should always get the time from this
    /// [`SleepProvider`].
    /// For example, use [`SleepProvider::now()`], not [`Instant::now()`].
    #[educe(Debug(ignore))]
    sleep_provider: P,
    /// See [`RateLimitedWriterConfig::wake_when_bytes_available`].
    wake_when_bytes_available: NonZero<u64>,
    /// The inner writer.
    #[educe(Debug(ignore))]
    #[pin]
    inner: W,
    /// We need to store the sleep future if [`AsyncWrite::poll_write()`] blocks.
    ///
    /// `None` means either "no sleep in progress" or "sleep forever"
    /// (see `register_sleep()`).
    #[educe(Debug(ignore))]
    #[pin]
    sleep_fut: Option<SyncFuture<P::SleepFuture>>,
}
43
impl<W, P> RateLimitedWriter<W, P>
where
    W: AsyncWrite,
    P: SleepProvider,
{
    /// Create a new [`RateLimitedWriter`].
    // We take the rate and bucket max directly rather than a `TokenBucket` to ensure that the token
    // bucket only ever uses times from `sleep_provider`.
    pub(crate) fn new(writer: W, config: &RateLimitedWriterConfig, sleep_provider: P) -> Self {
        let bucket_config = TokenBucketConfig {
            rate: config.rate,
            bucket_max: config.burst,
        };
        Self::from_token_bucket(
            writer,
            // the bucket's initial time comes from `sleep_provider`, never `Instant::now()`
            TokenBucket::new(&bucket_config, sleep_provider.now()),
            config.wake_when_bytes_available,
            sleep_provider,
        )
    }

    /// Create a new [`RateLimitedWriter`] from a [`TokenBucket`].
    ///
    /// The token bucket must have only been used with times created by `sleep_provider`.
    #[cfg_attr(test, visibility::make(pub(super)))]
    fn from_token_bucket(
        writer: W,
        bucket: TokenBucket<Instant>,
        wake_when_bytes_available: NonZero<u64>,
        sleep_provider: P,
    ) -> Self {
        Self {
            bucket,
            sleep_provider,
            wake_when_bytes_available,
            inner: writer,
            // no sleep is in progress until a `poll_write()` runs out of tokens
            sleep_fut: None,
        }
    }

    /// Access the inner [`AsyncWrite`] writer.
    pub(crate) fn inner(&self) -> &W {
        &self.inner
    }

    /// Adjust the refill rate and burst.
    ///
    /// A rate and/or burst of 0 is allowed.
    ///
    /// `now` should be a time obtained from this writer's [`SleepProvider`].
    pub(crate) fn adjust(
        self: &mut Pin<&mut Self>,
        now: Instant,
        config: &RateLimitedWriterConfig,
    ) {
        let self_ = self.as_mut().project();

        // destructuring allows us to make sure we aren't forgetting to handle any fields
        let RateLimitedWriterConfig {
            rate,
            burst,
            wake_when_bytes_available,
        } = *config;

        let bucket_config = TokenBucketConfig {
            rate,
            bucket_max: burst,
        };

        self_.bucket.adjust(now, &bucket_config);
        *self_.wake_when_bytes_available = wake_when_bytes_available;
    }

    /// The sleep provider.
    ///
    /// We don't want this to be generally accessible, only to other token bucket-related modules
    /// like [`DynamicRateLimitedWriter`](super::dynamic_writer::DynamicRateLimitedWriter).
    pub(super) fn sleep_provider(&self) -> &P {
        &self.sleep_provider
    }

    /// Configure this writer to sleep for `duration`.
    ///
    /// A `duration` of `None` is interpreted as "forever".
    ///
    /// Replaces any previously stored sleep future, and returns the result of polling the
    /// new one ([`Poll::Pending`] when sleeping "forever", since no future is stored).
    ///
    /// It's considered a bug if asked to sleep for `Duration::ZERO` time.
    fn register_sleep(
        sleep_fut: &mut Pin<&mut Option<SyncFuture<P::SleepFuture>>>,
        sleep_provider: &mut P,
        cx: &mut Context<'_>,
        duration: Option<Duration>,
    ) -> Poll<()> {
        match duration {
            None => {
                // sleep "forever": discard any existing sleep future; nothing here will
                // ever wake the task
                sleep_fut.as_mut().set(None);
                Poll::Pending
            }
            Some(duration) => {
                debug_assert_ne!(duration, Duration::ZERO, "asked to sleep for 0 time");
                sleep_fut
                    .as_mut()
                    .set(Some(SyncFuture::new(sleep_provider.sleep(duration))));
                // poll the new future immediately so that `cx`'s waker is registered with it
                sleep_fut
                    .as_mut()
                    .as_pin_mut()
                    .expect("but we just set it to `Some`?!")
                    .poll(cx)
            }
        }
    }
}
153
154impl<W, P> AsyncWrite for RateLimitedWriter<W, P>
155where
156    W: AsyncWrite,
157    P: SleepProvider,
158{
159    fn poll_write(
160        mut self: Pin<&mut Self>,
161        cx: &mut Context<'_>,
162        mut buf: &[u8],
163    ) -> Poll<Result<usize, Error>> {
164        let mut self_ = self.as_mut().project();
165
166        // this should be optimized to a no-op on at least x86-64
167        fn to_u64(x: usize) -> u64 {
168            x.try_into().expect("failed usize to u64 conversion")
169        }
170
171        // for an empty buffer, just defer to the inner writer's impl
172        if buf.is_empty() {
173            return self_.inner.poll_write(cx, buf);
174        }
175
176        let now = self_.sleep_provider.now();
177
178        // refill the bucket and attempt to claim all of the bytes
179        self_.bucket.refill(now);
180        let claim = self_.bucket.claim(to_u64(buf.len()));
181
182        let mut claim = match claim {
183            // claim was successful
184            Ok(x) => x,
185            // not enough tokens, so let's use a smaller buffer
186            Err(e) => {
187                let available = e.available_tokens();
188
189                // need to drop the old claim so that we can access the token bucket again
190                drop(claim);
191
192                // if no tokens in bucket, we must sleep
193                if available == 0 {
194                    // number of tokens we'll wait for
195                    let wake_at_tokens = to_u64(buf.len());
196
197                    // If the user wants to write X tokens, we don't necessarily want to sleep until
198                    // we have room for X tokens. We also don't want to wake every time that a
199                    // single byte can be written. We allow the user to configure this threshold
200                    // with `RateLimitedWriterConfig::wake_when_bytes_available`.
201                    let wake_at_tokens =
202                        std::cmp::min(wake_at_tokens, self_.wake_when_bytes_available.get());
203
204                    // max number of tokens the bucket can hold
205                    let bucket_max = self_.bucket.max();
206
207                    // how long to sleep for; `None` indicates to sleep forever
208                    let sleep_for = if bucket_max == 0 {
209                        // bucket can't hold any tokens, so sleep forever
210                        None
211                    } else {
212                        // if the bucket has a max of X tokens, we should never try to wait for >X
213                        // tokens
214                        let wake_at_tokens = std::cmp::min(wake_at_tokens, bucket_max);
215
216                        // if we asked for 0 tokens, we'd get a time of ~now, which is not what we
217                        // want
218                        debug_assert!(wake_at_tokens > 0);
219
220                        let wake_at = self_.bucket.tokens_available_at(wake_at_tokens);
221                        let sleep_for = wake_at.map(|x| x.saturating_duration_since(now));
222
223                        match sleep_for {
224                            Ok(x) => Some(x),
225                            Err(NeverEnoughTokensError::ExceedsMaxTokens) => {
226                                panic!(
227                                    "exceeds max tokens, but we took the max into account above"
228                                );
229                            }
230                            // we aren't refilling, so sleep forever
231                            Err(NeverEnoughTokensError::ZeroRate) => None,
232                            // too far in the future to be represented, so sleep forever
233                            Err(NeverEnoughTokensError::InstantNotRepresentable) => None,
234                        }
235                    };
236
237                    // configure the sleep future and poll it to register
238                    let poll = Self::register_sleep(
239                        &mut self_.sleep_fut,
240                        self_.sleep_provider,
241                        cx,
242                        sleep_for,
243                    );
244                    return match poll {
245                        // wait for the sleep to finish
246                        Poll::Pending => Poll::Pending,
247                        // The sleep is already ready?! A recursive call here isn't great, but
248                        // there's not much else we can do here. Hopefully this second `poll_write`
249                        // will succeed since we should now have enough tokens.
250                        Poll::Ready(()) => self.poll_write(cx, buf),
251                    };
252                }
253
254                /// Convert a `u64` to `usize`, saturating if size of `usize` is smaller than `u64`.
255                // This is a separate function to ensure we don't accidentally try to convert a
256                // signed integer into a `usize`, in which case `unwrap_or(MAX)` wouldn't make
257                // sense.
258                fn to_usize_saturating(x: u64) -> usize {
259                    x.try_into().unwrap_or(usize::MAX)
260                }
261
262                // There are tokens, so try to write as many as are available.
263                let available_usize = to_usize_saturating(available);
264                buf = &buf[0..available_usize];
265                self_.bucket.claim(to_u64(buf.len())).unwrap_or_else(|_| {
266                    panic!(
267                        "bucket has {available} tokens available, but can't claim {}?",
268                        buf.len(),
269                    )
270                })
271            }
272        };
273
274        let rv = self_.inner.poll_write(cx, buf);
275
276        match rv {
277            // no bytes were written, so discard the claim
278            Poll::Pending | Poll::Ready(Err(_)) => claim.discard(),
279            // `x` bytes were written, so only commit those tokens
280            Poll::Ready(Ok(x)) => {
281                if x <= buf.len() {
282                    claim
283                        .reduce(to_u64(x))
284                        .expect("can't commit fewer tokens?!");
285                    claim.commit();
286                } else {
287                    cfg_if::cfg_if! {
288                        if #[cfg(debug_assertions)] {
289                            panic!(
290                                "Writer is claiming it wrote more bytes {x} than we gave it {}",
291                                buf.len(),
292                            );
293                        } else {
294                            // the best we can do is to just claim the original amount
295                            claim.commit();
296                        }
297                    }
298                }
299            }
300        };
301
302        rv
303    }
304
305    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
306        self.project().inner.poll_flush(cx)
307    }
308
309    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
310        // some implementers of `AsyncWrite` (like `Vec`) don't do anything other than flush when
311        // closed and will continue to accept bytes even after being closed, so we must continue to
312        // apply rate limiting even after being closed
313        self.project().inner.poll_close(cx)
314    }
315}
316
/// A module to make it easier to implement tokio traits without putting `cfg()` conditionals
/// everywhere.
#[cfg(feature = "tokio")]
mod tokio_impl {
    use super::*;

    use tokio_crate::io::AsyncWrite as TokioAsyncWrite;
    use tokio_util::compat::FuturesAsyncWriteCompatExt;

    use std::io::Result as IoResult;

    impl<W, P> TokioAsyncWrite for RateLimitedWriter<W, P>
    where
        W: AsyncWrite,
        P: SleepProvider,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<IoResult<usize>> {
            // Delegate to our `futures::AsyncWrite` impl through a tokio<->futures compat
            // wrapper. A fresh wrapper is built on each call; presumably `Compat` carries no
            // state that must persist between polls -- TODO confirm against `tokio_util` docs.
            TokioAsyncWrite::poll_write(Pin::new(&mut self.compat_write()), cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            TokioAsyncWrite::poll_flush(Pin::new(&mut self.compat_write()), cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            // tokio's "shutdown" corresponds to `poll_close` in the `futures` trait
            TokioAsyncWrite::poll_shutdown(Pin::new(&mut self.compat_write()), cx)
        }
    }
}
350
/// The refill rate and burst for a [`RateLimitedWriter`].
#[derive(Clone, Debug)]
pub(crate) struct RateLimitedWriterConfig {
    /// The refill rate in bytes/second.
    ///
    /// A rate of 0 is allowed (writes will block indefinitely once the bucket drains).
    pub(crate) rate: u64,
    /// The "burst" in bytes; i.e. the maximum number of tokens the bucket can hold.
    pub(crate) burst: u64,
    /// When polled, block until at most this many bytes are available.
    ///
    /// Or in other words, wake when we can write this many bytes, even if the provided buffer is
    /// larger.
    ///
    /// For example if a user attempts to write a large buffer, we usually don't want to block until
    /// the entire buffer can be written. We'd prefer several partial writes to a single large
    /// write. So instead of blocking until the entire buffer can be written, we only block until
    /// at most this many bytes are available.
    pub(crate) wake_when_bytes_available: NonZero<u64>,
}
369
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]

    use super::*;

    use futures::task::SpawnExt;
    use futures::{AsyncWriteExt, FutureExt};

    /// End-to-end check of rate limiting and the `wake_when_bytes_available` threshold
    /// against a mock runtime.
    #[test]
    fn writer() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            let start = rt.now();

            // increases 10 tokens/second (one every 100 ms)
            let config = TokenBucketConfig {
                rate: 10,
                bucket_max: 100,
            };
            let mut tb = TokenBucket::new(&config, start);
            // drain the bucket
            tb.drain(100).unwrap();

            let wake_when_bytes_available = NonZero::new(15).unwrap();

            let mut writer = Vec::new();
            let mut writer = RateLimitedWriter::from_token_bucket(
                &mut writer,
                tb,
                wake_when_bytes_available,
                rt.clone(),
            );

            // drive time forward from 0 to 20_000 ms in 50 ms intervals
            let rt_clone = rt.clone();
            rt.spawn(async move {
                for _ in 0..400 {
                    rt_clone.progress_until_stalled().await;
                    rt_clone.advance_by(Duration::from_millis(50)).await;
                }
            })
            .unwrap();

            // try writing 60 bytes, which sleeps until we can write at least 15 of them;
            // at 10 tokens/second, 15 tokens accumulate after 1500 ms
            assert_eq!(15, writer.write(&[0; 60]).await.unwrap());
            assert_eq!(1500, rt.now().duration_since(start).as_millis());

            // wait 2 seconds
            rt.sleep(Duration::from_millis(2000)).await;

            // ensure that we can write immediately, and that we can write
            // 2000 ms / (100 ms/token) = 20 bytes
            assert_eq!(
                Some(20),
                writer.write(&[0; 60]).now_or_never().map(Result::unwrap),
            );
        });
    }

    /// Test that writing to a token bucket which has a rate and/or max of 0 works as expected.
    #[test]
    fn rate_burst_zero() {
        let configs = [
            // non-zero rate, zero max
            TokenBucketConfig {
                rate: 10,
                bucket_max: 0,
            },
            // zero rate, non-zero max
            TokenBucketConfig {
                rate: 0,
                bucket_max: 10,
            },
            // zero rate, zero max
            TokenBucketConfig {
                rate: 0,
                bucket_max: 0,
            },
        ];
        for config in configs {
            tor_rtmock::MockRuntime::test_with_various(|rt| {
                let config = config.clone();
                async move {
                    // an empty token bucket
                    let mut tb = TokenBucket::new(&config, rt.now());
                    tb.drain(tb.max()).unwrap();
                    assert!(tb.is_empty());

                    let wake_when_bytes_available = NonZero::new(2).unwrap();

                    let mut writer = Vec::new();
                    let mut writer = RateLimitedWriter::from_token_bucket(
                        &mut writer,
                        tb,
                        wake_when_bytes_available,
                        rt.clone(),
                    );

                    // drive time forward from 0 to 10_000 ms in 100 ms intervals
                    let rt_clone = rt.clone();
                    rt.spawn(async move {
                        for _ in 0..100 {
                            rt_clone.progress_until_stalled().await;
                            rt_clone.advance_by(Duration::from_millis(100)).await;
                        }
                    })
                    .unwrap();

                    // ensure that a write returns `Pending`
                    // (`None` here means "not ready", not "wrote nothing")
                    assert_eq!(
                        None,
                        writer.write(&[0; 60]).now_or_never().map(Result::unwrap),
                    );

                    // wait 5 seconds
                    rt.sleep(Duration::from_millis(5000)).await;

                    // ensure that a write still returns `Pending`
                    assert_eq!(
                        None,
                        writer.write(&[0; 60]).now_or_never().map(Result::unwrap),
                    );
                }
            });
        }
    }
}