// tor_proto/util/token_bucket/writer.rs
1//! An [`AsyncWrite`] rate limiter.
2
3use std::future::Future;
4use std::num::NonZero;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use std::time::{Duration, Instant};
8
9use futures::io::Error;
10use futures::AsyncWrite;
11use sync_wrapper::SyncFuture;
12use tor_rtcompat::SleepProvider;
13
14use super::bucket::{NeverEnoughTokensError, TokenBucket, TokenBucketConfig};
15
/// A rate-limited async [writer](AsyncWrite).
///
/// This can be used as a wrapper around an existing [`AsyncWrite`] writer.
#[derive(educe::Educe)]
#[educe(Debug)]
#[pin_project::pin_project]
pub(crate) struct RateLimitedWriter<W: AsyncWrite, P: SleepProvider> {
    /// The token bucket.
    ///
    /// One token corresponds to one byte that may be written to `inner`.
    bucket: TokenBucket<Instant>,
    /// The sleep provider, for getting the current time and creating new sleep futures.
    ///
    /// While we use [`Instant`] for the time, we should always get the time from this
    /// [`SleepProvider`].
    /// For example, use [`SleepProvider::now()`], not [`Instant::now()`].
    #[educe(Debug(ignore))]
    sleep_provider: P,
    /// See [`RateLimitedWriterConfig::wake_when_bytes_available`].
    wake_when_bytes_available: NonZero<u64>,
    /// The inner writer.
    #[educe(Debug(ignore))]
    #[pin]
    inner: W,
    /// We need to store the sleep future if [`AsyncWrite::poll_write()`] blocks.
    ///
    /// `None` both when no sleep is in progress and when we are sleeping "forever"
    /// (see `register_sleep`).
    #[educe(Debug(ignore))]
    #[pin]
    sleep_fut: Option<SyncFuture<P::SleepFuture>>,
}
43
impl<W, P> RateLimitedWriter<W, P>
where
    W: AsyncWrite,
    P: SleepProvider,
{
    /// Create a new [`RateLimitedWriter`].
    // We take the rate and bucket max directly rather than a `TokenBucket` to ensure that the token
    // bucket only ever uses times from `sleep_provider`.
    pub(crate) fn new(writer: W, config: &RateLimitedWriterConfig, sleep_provider: P) -> Self {
        let bucket_config = TokenBucketConfig {
            rate: config.rate,
            bucket_max: config.burst,
        };
        Self::from_token_bucket(
            writer,
            // the bucket's initial time comes from the sleep provider, never `Instant::now()`
            TokenBucket::new(&bucket_config, sleep_provider.now()),
            config.wake_when_bytes_available,
            sleep_provider,
        )
    }

    /// Create a new [`RateLimitedWriter`] from a [`TokenBucket`].
    ///
    /// The token bucket must have only been used with times created by `sleep_provider`.
    #[cfg_attr(test, visibility::make(pub(super)))]
    fn from_token_bucket(
        writer: W,
        bucket: TokenBucket<Instant>,
        wake_when_bytes_available: NonZero<u64>,
        sleep_provider: P,
    ) -> Self {
        Self {
            bucket,
            sleep_provider,
            wake_when_bytes_available,
            inner: writer,
            // no sleep is in progress until a `poll_write` runs out of tokens
            sleep_fut: None,
        }
    }

    /// Access the inner [`AsyncWrite`] writer.
    pub(crate) fn inner(&self) -> &W {
        &self.inner
    }

    /// Adjust the refill rate and burst.
    ///
    /// A rate and/or burst of 0 is allowed.
    pub(crate) fn adjust(
        self: &mut Pin<&mut Self>,
        now: Instant,
        config: &RateLimitedWriterConfig,
    ) {
        let self_ = self.as_mut().project();

        // destructuring allows us to make sure we aren't forgetting to handle any fields
        let RateLimitedWriterConfig {
            rate,
            burst,
            wake_when_bytes_available,
        } = *config;

        let bucket_config = TokenBucketConfig {
            rate,
            bucket_max: burst,
        };

        self_.bucket.adjust(now, &bucket_config);
        *self_.wake_when_bytes_available = wake_when_bytes_available;
    }

    /// The sleep provider.
    ///
    /// We don't want this to be generally accessible, only to other token bucket-related modules
    /// like [`DynamicRateLimitedWriter`](super::dynamic_writer::DynamicRateLimitedWriter).
    pub(super) fn sleep_provider(&self) -> &P {
        &self.sleep_provider
    }

    /// Configure this writer to sleep for `duration`.
    ///
    /// A `duration` of `None` is interpreted as "forever".
    ///
    /// Returns the result of polling the newly created sleep future, or `Poll::Pending`
    /// when sleeping forever (in which case no future is stored and nothing will wake us).
    ///
    /// It's considered a bug if asked to sleep for `Duration::ZERO` time.
    fn register_sleep(
        sleep_fut: &mut Pin<&mut Option<SyncFuture<P::SleepFuture>>>,
        sleep_provider: &mut P,
        cx: &mut Context<'_>,
        duration: Option<Duration>,
    ) -> Poll<()> {
        match duration {
            None => {
                // "forever": clear any stored future and stay pending
                sleep_fut.as_mut().set(None);
                Poll::Pending
            }
            Some(duration) => {
                debug_assert_ne!(duration, Duration::ZERO, "asked to sleep for 0 time");
                // replace any existing sleep future with a fresh one...
                sleep_fut
                    .as_mut()
                    .set(Some(SyncFuture::new(sleep_provider.sleep(duration))));
                // ...and poll it so the waker in `cx` is registered
                sleep_fut
                    .as_mut()
                    .as_pin_mut()
                    .expect("but we just set it to `Some`?!")
                    .poll(cx)
            }
        }
    }
}
153
154impl<W, P> AsyncWrite for RateLimitedWriter<W, P>
155where
156 W: AsyncWrite,
157 P: SleepProvider,
158{
159 fn poll_write(
160 mut self: Pin<&mut Self>,
161 cx: &mut Context<'_>,
162 mut buf: &[u8],
163 ) -> Poll<Result<usize, Error>> {
164 let mut self_ = self.as_mut().project();
165
166 // this should be optimized to a no-op on at least x86-64
167 fn to_u64(x: usize) -> u64 {
168 x.try_into().expect("failed usize to u64 conversion")
169 }
170
171 // for an empty buffer, just defer to the inner writer's impl
172 if buf.is_empty() {
173 return self_.inner.poll_write(cx, buf);
174 }
175
176 let now = self_.sleep_provider.now();
177
178 // refill the bucket and attempt to claim all of the bytes
179 self_.bucket.refill(now);
180 let claim = self_.bucket.claim(to_u64(buf.len()));
181
182 let mut claim = match claim {
183 // claim was successful
184 Ok(x) => x,
185 // not enough tokens, so let's use a smaller buffer
186 Err(e) => {
187 let available = e.available_tokens();
188
189 // need to drop the old claim so that we can access the token bucket again
190 drop(claim);
191
192 // if no tokens in bucket, we must sleep
193 if available == 0 {
194 // number of tokens we'll wait for
195 let wake_at_tokens = to_u64(buf.len());
196
197 // If the user wants to write X tokens, we don't necessarily want to sleep until
198 // we have room for X tokens. We also don't want to wake every time that a
199 // single byte can be written. We allow the user to configure this threshold
200 // with `RateLimitedWriterConfig::wake_when_bytes_available`.
201 let wake_at_tokens =
202 std::cmp::min(wake_at_tokens, self_.wake_when_bytes_available.get());
203
204 // max number of tokens the bucket can hold
205 let bucket_max = self_.bucket.max();
206
207 // how long to sleep for; `None` indicates to sleep forever
208 let sleep_for = if bucket_max == 0 {
209 // bucket can't hold any tokens, so sleep forever
210 None
211 } else {
212 // if the bucket has a max of X tokens, we should never try to wait for >X
213 // tokens
214 let wake_at_tokens = std::cmp::min(wake_at_tokens, bucket_max);
215
216 // if we asked for 0 tokens, we'd get a time of ~now, which is not what we
217 // want
218 debug_assert!(wake_at_tokens > 0);
219
220 let wake_at = self_.bucket.tokens_available_at(wake_at_tokens);
221 let sleep_for = wake_at.map(|x| x.saturating_duration_since(now));
222
223 match sleep_for {
224 Ok(x) => Some(x),
225 Err(NeverEnoughTokensError::ExceedsMaxTokens) => {
226 panic!(
227 "exceeds max tokens, but we took the max into account above"
228 );
229 }
230 // we aren't refilling, so sleep forever
231 Err(NeverEnoughTokensError::ZeroRate) => None,
232 // too far in the future to be represented, so sleep forever
233 Err(NeverEnoughTokensError::InstantNotRepresentable) => None,
234 }
235 };
236
237 // configure the sleep future and poll it to register
238 let poll = Self::register_sleep(
239 &mut self_.sleep_fut,
240 self_.sleep_provider,
241 cx,
242 sleep_for,
243 );
244 return match poll {
245 // wait for the sleep to finish
246 Poll::Pending => Poll::Pending,
247 // The sleep is already ready?! A recursive call here isn't great, but
248 // there's not much else we can do here. Hopefully this second `poll_write`
249 // will succeed since we should now have enough tokens.
250 Poll::Ready(()) => self.poll_write(cx, buf),
251 };
252 }
253
254 /// Convert a `u64` to `usize`, saturating if size of `usize` is smaller than `u64`.
255 // This is a separate function to ensure we don't accidentally try to convert a
256 // signed integer into a `usize`, in which case `unwrap_or(MAX)` wouldn't make
257 // sense.
258 fn to_usize_saturating(x: u64) -> usize {
259 x.try_into().unwrap_or(usize::MAX)
260 }
261
262 // There are tokens, so try to write as many as are available.
263 let available_usize = to_usize_saturating(available);
264 buf = &buf[0..available_usize];
265 self_.bucket.claim(to_u64(buf.len())).unwrap_or_else(|_| {
266 panic!(
267 "bucket has {available} tokens available, but can't claim {}?",
268 buf.len(),
269 )
270 })
271 }
272 };
273
274 let rv = self_.inner.poll_write(cx, buf);
275
276 match rv {
277 // no bytes were written, so discard the claim
278 Poll::Pending | Poll::Ready(Err(_)) => claim.discard(),
279 // `x` bytes were written, so only commit those tokens
280 Poll::Ready(Ok(x)) => {
281 if x <= buf.len() {
282 claim
283 .reduce(to_u64(x))
284 .expect("can't commit fewer tokens?!");
285 claim.commit();
286 } else {
287 cfg_if::cfg_if! {
288 if #[cfg(debug_assertions)] {
289 panic!(
290 "Writer is claiming it wrote more bytes {x} than we gave it {}",
291 buf.len(),
292 );
293 } else {
294 // the best we can do is to just claim the original amount
295 claim.commit();
296 }
297 }
298 }
299 }
300 };
301
302 rv
303 }
304
305 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
306 self.project().inner.poll_flush(cx)
307 }
308
309 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
310 // some implementers of `AsyncWrite` (like `Vec`) don't do anything other than flush when
311 // closed and will continue to accept bytes even after being closed, so we must continue to
312 // apply rate limiting even after being closed
313 self.project().inner.poll_close(cx)
314 }
315}
316
/// A module to make it easier to implement tokio traits without putting `cfg()` conditionals
/// everywhere.
#[cfg(feature = "tokio")]
mod tokio_impl {
    use super::*;

    use tokio_crate::io::AsyncWrite as TokioAsyncWrite;
    use tokio_util::compat::FuturesAsyncWriteCompatExt;

    use std::io::Result as IoResult;

    // Each method builds a compat adapter over the pinned writer and forwards the poll to
    // the futures-based `AsyncWrite` impl above.
    impl<W, P> TokioAsyncWrite for RateLimitedWriter<W, P>
    where
        W: AsyncWrite,
        P: SleepProvider,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<IoResult<usize>> {
            let mut adapter = self.compat_write();
            TokioAsyncWrite::poll_write(Pin::new(&mut adapter), cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            let mut adapter = self.compat_write();
            TokioAsyncWrite::poll_flush(Pin::new(&mut adapter), cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            let mut adapter = self.compat_write();
            TokioAsyncWrite::poll_shutdown(Pin::new(&mut adapter), cx)
        }
    }
}
350
/// The refill rate and burst for a [`RateLimitedWriter`].
#[derive(Clone, Debug)]
pub(crate) struct RateLimitedWriterConfig {
    /// How many bytes of write capacity are regained per second (the token refill rate).
    pub(crate) rate: u64,
    /// The maximum number of bytes that may be written in one burst
    /// (the token bucket's capacity).
    pub(crate) burst: u64,
    /// When a write is blocked, wake once at least this many bytes can be written.
    ///
    /// In other words: even if the caller's buffer is larger, we stop sleeping as soon as
    /// this many bytes of capacity are available.
    ///
    /// For example, when a user attempts to write a large buffer, we usually don't want to
    /// block until the whole buffer fits; several partial writes are preferable to one big
    /// delayed write. So we only block until at most this many bytes are available.
    pub(crate) wake_when_bytes_available: NonZero<u64>,
}
369
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]

    use super::*;

    use futures::task::SpawnExt;
    use futures::{AsyncWriteExt, FutureExt};

    /// Writes through a rate-limited writer should block until the configured
    /// `wake_when_bytes_available` threshold is reached, then write partially.
    #[test]
    fn writer() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            let start = rt.now();

            // increases 10 tokens/second (one every 100 ms)
            let config = TokenBucketConfig {
                rate: 10,
                bucket_max: 100,
            };
            let mut tb = TokenBucket::new(&config, start);
            // drain the bucket
            tb.drain(100).unwrap();

            let wake_when_bytes_available = NonZero::new(15).unwrap();

            let mut writer = Vec::new();
            let mut writer = RateLimitedWriter::from_token_bucket(
                &mut writer,
                tb,
                wake_when_bytes_available,
                rt.clone(),
            );

            // drive time forward from 0 to 20_000 ms in 50 ms intervals
            let rt_clone = rt.clone();
            rt.spawn(async move {
                for _ in 0..400 {
                    rt_clone.progress_until_stalled().await;
                    rt_clone.advance_by(Duration::from_millis(50)).await;
                }
            })
            .unwrap();

            // try writing 60 bytes, which sleeps until we can write at least 15 of them
            assert_eq!(15, writer.write(&[0; 60]).await.unwrap());
            // 15 tokens at 10 tokens/second => 1500 ms of mock time elapsed
            assert_eq!(1500, rt.now().duration_since(start).as_millis());

            // wait 2 seconds
            rt.sleep(Duration::from_millis(2000)).await;

            // ensure that we can write immediately, and that we can write
            // 2000 ms / (100 ms/token) = 20 bytes
            assert_eq!(
                Some(20),
                writer.write(&[0; 60]).now_or_never().map(Result::unwrap),
            );
        });
    }

    /// Test that writing to a token bucket which has a rate and/or max of 0 works as expected.
    #[test]
    fn rate_burst_zero() {
        let configs = [
            // non-zero rate, zero max
            TokenBucketConfig {
                rate: 10,
                bucket_max: 0,
            },
            // zero rate, non-zero max
            TokenBucketConfig {
                rate: 0,
                bucket_max: 10,
            },
            // zero rate, zero max
            TokenBucketConfig {
                rate: 0,
                bucket_max: 0,
            },
        ];
        for config in configs {
            tor_rtmock::MockRuntime::test_with_various(|rt| {
                let config = config.clone();
                async move {
                    // an empty token bucket
                    let mut tb = TokenBucket::new(&config, rt.now());
                    tb.drain(tb.max()).unwrap();
                    assert!(tb.is_empty());

                    let wake_when_bytes_available = NonZero::new(2).unwrap();

                    let mut writer = Vec::new();
                    let mut writer = RateLimitedWriter::from_token_bucket(
                        &mut writer,
                        tb,
                        wake_when_bytes_available,
                        rt.clone(),
                    );

                    // drive time forward from 0 to 10_000 ms in 100 ms intervals
                    let rt_clone = rt.clone();
                    rt.spawn(async move {
                        for _ in 0..100 {
                            rt_clone.progress_until_stalled().await;
                            rt_clone.advance_by(Duration::from_millis(100)).await;
                        }
                    })
                    .unwrap();

                    // ensure that a write returns `Pending`
                    // (`now_or_never` maps a pending future to `None`)
                    assert_eq!(
                        None,
                        writer.write(&[0; 60]).now_or_never().map(Result::unwrap),
                    );

                    // wait 5 seconds
                    rt.sleep(Duration::from_millis(5000)).await;

                    // ensure that a write still returns `Pending`; with a zero rate or
                    // zero max, capacity never becomes available
                    assert_eq!(
                        None,
                        writer.write(&[0; 60]).now_or_never().map(Result::unwrap),
                    );
                }
            });
        }
    }
}