tor_async_utils/stream_peek.rs
//! [`StreamUnobtrusivePeeker`]
//!
//! The memory tracker needs a way to look at the next item of a stream
//! (if there is one already, or if one is available immediately),
//! *without* getting involved with the async tasks.
5//! *without* getting involved with the async tasks.
6
7use educe::Educe;
8use futures::stream::FusedStream;
9use futures::task::noop_waker_ref;
10use futures::Stream;
11use pin_project::pin_project;
12
13use crate::peekable_stream::{PeekableStream, UnobtrusivePeekableStream};
14
15use std::fmt::Debug;
16use std::future::Future;
17use std::pin::Pin;
18use std::task::{Context, Poll, Poll::*, Waker};
19
20/// Wraps [`Stream`] and provides `\[poll_]peek` and `unobtrusive_peek`
21///
22/// [`unobtrusive_peek`](StreamUnobtrusivePeeker::unobtrusive_peek)
23/// is callable in sync contexts, outside the reading task.
24///
25/// Like [`futures::stream::Peekable`],
26/// this has an async `peek` method, and `poll_peek`,
27/// for use from the task that is also reading (via the [`Stream`] impl).
28/// But, that type doesn't have `unobtrusive_peek`.
29///
30/// One way to conceptualise this is that `StreamUnobtrusivePeeker` is dual-ported:
31/// the two sets of APIs, while provided on the same type,
32/// are typically called from different contexts.
33//
34// It wasn't particularly easy to think of a good name for this type.
35// We intend, probably:
36// struct StreamUnobtrusivePeeker
37// trait StreamUnobtrusivePeekable
38// trait StreamPeekable (impl for StreamUnobtrusivePeeker and futures::stream::Peekable)
39//
40// Searching a thesaurus produced these suggested words:
41// unobtrusive subtle discreet inconspicuous cautious furtive
42// Asking in MR review also suggested
43// quick
44//
45// It's awkward because "peek" already has significant connotations of not disturbing things.
46// That's why it was used in Iterator::peek.
47//
48// But when we translate this into async context,
49// we have the poll_peek method on futures::stream::Peekable,
50// which doesn't remove items from the stream,
51// but *does* *wait* for items and therefore engages with the async context,
52// and therefore involves *mutating* the Peekable (to store the new waker).
53//
54// Now we end up needing a word for an *even less disturbing* kind of interaction.
55//
56// `quick` (and synonyms) isn't quite right either because it's not necessarily faster,
57// and certainly not more performant.
58#[derive(Debug)]
59#[pin_project(project = PeekerProj)]
60pub struct StreamUnobtrusivePeeker<S: Stream> {
61 /// An item that we have peeked.
62 ///
63 /// (If we peeked EOF, that's represented by `None` in inner.)
64 buffered: Option<S::Item>,
65
66 /// The `Waker` from the last time we were polled and returned `Pending`
67 ///
68 /// "polled" includes any of our `poll_` methods
69 /// but *not* `unobtrusive_peek`.
70 ///
71 /// `None` if we haven't been polled, or the last poll returned `Ready`.
72 poll_waker: Option<Waker>,
73
74 /// The inner stream
75 ///
76 /// `None if it has yielded `None` meaning EOF. We don't require S: FusedStream.
77 #[pin]
78 inner: Option<S>,
79}
80
81impl<S: Stream> StreamUnobtrusivePeeker<S> {
82 /// Create a new `StreamUnobtrusivePeeker` from a `Stream`
83 pub fn new(inner: S) -> Self {
84 StreamUnobtrusivePeeker {
85 buffered: None,
86 poll_waker: None,
87 inner: Some(inner),
88 }
89 }
90}
91
92impl<S: Stream> UnobtrusivePeekableStream for StreamUnobtrusivePeeker<S> {
93 fn unobtrusive_peek_mut<'s>(mut self: Pin<&'s mut Self>) -> Option<&'s mut S::Item> {
94 #[allow(clippy::question_mark)] // We use explicit control flow here for clarity
95 if self.as_mut().project().buffered.is_none() {
96 // We don't have a buffered item, but the stream may have an item available.
97 // We must poll it to find out.
98 //
99 // We need to pass a Context to poll_next.
100 // inner may store this context, replacing one provided via poll_*.
101 //
102 // Despite that, we need to make sure that wakeups will happen as expected.
103 // To achieve this we have retained a copy of the caller's Waker.
104 //
105 // When a future or stream returns Pending, it proposes to wake `waker`
106 // when it wants to be polled again.
107 //
108 // We uphold that promise by
109 // - only returning Pending from our poll methods if inner also returned Pending
110 // - when one of our poll methods returns Pending, saving the caller-supplied
111 // waker, so that we can make the intermediate poll call here.
112 //
113 // If the inner poll returns Ready, inner no longer guarantees to wake anyone.
114 // In principle, if our user is waiting (we returned Pending),
115 // then inner ought to have called `wake` on the caller's `Waker`.
116 // But I don't think we can guarantee that an executor won't defer a wakeup,
117 // and respond to a dropped Waker by cancelling that wakeup;
118 // or to put it another way, the wakeup might be "in flight" on entry,
119 // but the call to inner's poll_next returning Ready
120 // might somehow "cancel" the wakeup.
121 //
122 // So just to be sure, if we get a Ready here, we wake the stored waker.
123
124 let mut self_ = self.as_mut().project();
125
126 let Some(inner) = self_.inner.as_mut().as_pin_mut() else {
127 return None;
128 };
129
130 let waker = if let Some(waker) = self_.poll_waker.as_ref() {
131 waker
132 } else {
133 noop_waker_ref()
134 };
135
136 match inner.poll_next(&mut Context::from_waker(waker)) {
137 Pending => {}
138 Ready(item_or_eof) => {
139 if let Some(waker) = self_.poll_waker.take() {
140 waker.wake();
141 }
142 match item_or_eof {
143 None => self_.inner.set(None),
144 Some(item) => *self_.buffered = Some(item),
145 }
146 }
147 };
148 }
149
150 self.project().buffered.as_mut()
151 }
152}
153
154impl<S: Stream> PeekableStream for StreamUnobtrusivePeeker<S> {
155 fn poll_peek<'s>(self: Pin<&'s mut Self>, cx: &mut Context<'_>) -> Poll<Option<&'s S::Item>> {
156 self.impl_poll_next_or_peek(cx, |buffered| buffered.as_ref())
157 }
158
159 fn poll_peek_mut<'s>(
160 self: Pin<&'s mut Self>,
161 cx: &mut Context<'_>,
162 ) -> Poll<Option<&'s mut S::Item>> {
163 self.impl_poll_next_or_peek(cx, |buffered| buffered.as_mut())
164 }
165}
166
167impl<S: Stream> StreamUnobtrusivePeeker<S> {
168 /// Implementation of `poll_{peek,next}`
169 ///
170 /// This takes care of
171 /// * examining the state of our buffer, and polling inner if needed
172 /// * ensuring that we store a waker, if needed
173 /// * dealing with some borrowck awkwardness
174 ///
175 /// The `Ready` value is always calculated from `buffer`.
176 /// `return_value_obtainer` is called only if we are going to return `Ready`.
177 /// It's given `buffer` and should either:
178 /// * [`take`](Option::take) the contained value (for `poll_next`)
179 /// * return a reference using [`Option::as_ref`] (for `poll_peek`)
180 fn impl_poll_next_or_peek<'s, R: 's>(
181 self: Pin<&'s mut Self>,
182 cx: &mut Context<'_>,
183 return_value_obtainer: impl FnOnce(&'s mut Option<S::Item>) -> Option<R>,
184 ) -> Poll<Option<R>> {
185 let mut self_ = self.project();
186 let r = Self::next_or_peek_inner(&mut self_, cx);
187 let r = r.map(|()| return_value_obtainer(self_.buffered));
188 Self::return_from_poll(self_.poll_waker, cx, r)
189 }
190
191 /// Try to populate `buffer`, and calculate if we're `Ready`
192 ///
193 /// Returns `Ready` iff `poll_next` or `poll_peek` should return `Ready`.
194 /// The actual `Ready` value (an `Option`) will be calculated later.
195 fn next_or_peek_inner(self_: &mut PeekerProj<S>, cx: &mut Context<'_>) -> Poll<()> {
196 if let Some(_item) = self_.buffered.as_ref() {
197 // `return_value_obtainer` will find `Some` in `buffered`;
198 // overall, we'll return `Ready(Some(..))`.
199 return Ready(());
200 }
201 let Some(inner) = self_.inner.as_mut().as_pin_mut() else {
202 // `return_value_obtainer` will find `None` in `buffered`;
203 // overall, we'll return `Ready(None)`, ie EOF.
204 return Ready(());
205 };
206 match inner.poll_next(cx) {
207 Ready(None) => {
208 self_.inner.set(None);
209 // `buffered` is `None`, still.
210 // overall, we'll return `Ready(None)`, ie EOF.
211 Ready(())
212 }
213 Ready(Some(item)) => {
214 *self_.buffered = Some(item);
215 // return_value_obtainer` will find `Some` in `buffered`
216 Ready(())
217 }
218 Pending => {
219 // `return_value_obtainer` won't be called.
220 // overall, we'll return Pending
221 Pending
222 }
223 }
224 }
225
226 /// Wait for an item to be ready, and then inspect it
227 ///
228 /// Equivalent to [`futures::stream::Peekable::peek`].
229 ///
230 /// # Tasks, waking, and calling context
231 ///
232 /// This should be called by the task that is reading from the stream.
233 /// If it is called by another task, the reading task would miss notifications.
234 //
235 // This ^ docs section is triplicated for poll_peek, poll_peek_mut, and peek
236 //
237 // TODO this should be a method on the `PeekableStream` trait? Or a
238 // `PeekableStreamExt` trait?
239 // TODO should there be peek_mut ?
240 #[allow(dead_code)] // TODO remove this allow if and when we make this module public
241 pub fn peek(self: Pin<&mut Self>) -> PeekFuture<Self> {
242 PeekFuture { peeker: Some(self) }
243 }
244
245 /// Return from a `poll_*` function, setting the stored waker appropriately
246 ///
247 /// Our `poll` functions always use this.
248 /// The rule is that if a future returns `Pending`, it has stored the waker.
249 fn return_from_poll<R>(
250 poll_waker: &mut Option<Waker>,
251 cx: &mut Context<'_>,
252 r: Poll<R>,
253 ) -> Poll<R> {
254 *poll_waker = match &r {
255 Ready(_) => {
256 // No need to wake this task up any more.
257 None
258 }
259 Pending => {
260 // try_peek must use the same waker to poll later
261 Some(cx.waker().clone())
262 }
263 };
264 r
265 }
266
267 /// Obtain a raw reference to the inner stream
268 ///
269 /// ### Correctness!
270 ///
271 /// This method must be used with care!
272 /// Whatever you do mustn't interfere with polling and peeking.
273 /// Careless use can result in wrong behaviour including deadlocks.
274 pub fn as_raw_inner_pin_mut<'s>(self: Pin<&'s mut Self>) -> Option<Pin<&'s mut S>> {
275 self.project().inner.as_pin_mut()
276 }
277}
278
279impl<S: Stream> Stream for StreamUnobtrusivePeeker<S> {
280 type Item = S::Item;
281
282 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
283 self.impl_poll_next_or_peek(cx, |buffered| buffered.take())
284 }
285
286 fn size_hint(&self) -> (usize, Option<usize>) {
287 let buf = self.buffered.iter().count();
288 let (imin, imax) = match &self.inner {
289 Some(inner) => inner.size_hint(),
290 None => (0, Some(0)),
291 };
292 (imin + buf, imax.and_then(|imap| imap.checked_add(buf)))
293 }
294}
295
296impl<S: Stream> FusedStream for StreamUnobtrusivePeeker<S> {
297 fn is_terminated(&self) -> bool {
298 self.buffered.is_none() && self.inner.is_none()
299 }
300}
301
302/// Future from [`StreamUnobtrusivePeeker::peek`]
303// TODO: Move to tor_async_utils::peekable_stream.
304#[derive(Educe)]
305#[educe(Debug(bound("S: Debug")))]
306#[must_use = "peek() return a Future, which does nothing unless awaited"]
307pub struct PeekFuture<'s, S> {
308 /// The underlying stream.
309 ///
310 /// `Some` until we have returned `Ready`, then `None`.
311 /// See comment in `poll`.
312 peeker: Option<Pin<&'s mut S>>,
313}
314
315impl<'s, S: PeekableStream> PeekFuture<'s, S> {
316 /// Create a new `PeekFuture`.
317 // TODO: replace with a trait method.
318 pub fn new(stream: Pin<&'s mut S>) -> Self {
319 Self {
320 peeker: Some(stream),
321 }
322 }
323}
324
325impl<'s, S: PeekableStream> Future for PeekFuture<'s, S> {
326 type Output = Option<&'s S::Item>;
327 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<&'s S::Item>> {
328 let self_ = self.get_mut();
329 let peeker = self_
330 .peeker
331 .as_mut()
332 .expect("PeekFuture polled after Ready");
333 match peeker.as_mut().poll_peek(cx) {
334 Pending => return Pending,
335 Ready(_y) => {
336 // Ideally we would have returned `y` here, but it's borrowed from PeekFuture
337 // not from the original StreamUnobtrusivePeeker, and there's no way
338 // to get a value with the right lifetime. (In non-async code,
339 // this is usually handled by the special magic for reborrowing &mut.)
340 //
341 // So we must redo the poll, but this time consuming `peeker`,
342 // which gets us the right lifetime. That's why it has to be `Option`.
343 // Because we own &mut ... Self, we know that repeating the poll
344 // gives the same answer.
345 }
346 }
347 let peeker = self_.peeker.take().expect("it was Some before!");
348 let r = peeker.poll_peek(cx);
349 assert!(r.is_ready(), "it was Ready before!");
350 r
351 }
352}
353
354#[cfg(test)]
355mod test {
356 // @@ begin test lint list maintained by maint/add_warning @@
357 #![allow(clippy::bool_assert_comparison)]
358 #![allow(clippy::clone_on_copy)]
359 #![allow(clippy::dbg_macro)]
360 #![allow(clippy::mixed_attributes_style)]
361 #![allow(clippy::print_stderr)]
362 #![allow(clippy::print_stdout)]
363 #![allow(clippy::single_char_pattern)]
364 #![allow(clippy::unwrap_used)]
365 #![allow(clippy::unchecked_duration_subtraction)]
366 #![allow(clippy::useless_vec)]
367 #![allow(clippy::needless_pass_by_value)]
368 //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
369
370 use super::*;
371 use futures::channel::mpsc;
372 use futures::{SinkExt as _, StreamExt as _};
373 use std::pin::pin;
374 use std::sync::{Arc, Mutex};
375 use std::time::Duration;
376 use tor_rtcompat::SleepProvider as _;
377 use tor_rtmock::MockRuntime;
378
379 fn ms(ms: u64) -> Duration {
380 Duration::from_millis(ms)
381 }
382
383 #[test]
384 fn wakeups() {
385 MockRuntime::test_with_various(|rt| async move {
386 let (mut tx, rx) = mpsc::unbounded();
387 let ended = Arc::new(Mutex::new(false));
388
389 rt.spawn_identified("rxr", {
390 let rt = rt.clone();
391 let ended = ended.clone();
392
393 async move {
394 let rx = StreamUnobtrusivePeeker::new(rx);
395 let mut rx = pin!(rx);
396
397 let mut next = 0;
398 loop {
399 rt.sleep(ms(50)).await;
400 eprintln!("rx peek... ");
401 let peeked = rx.as_mut().unobtrusive_peek_mut();
402 eprintln!("rx peeked {peeked:?}");
403
404 if let Some(peeked) = peeked {
405 assert_eq!(*peeked, next);
406 }
407
408 rt.sleep(ms(50)).await;
409 eprintln!("rx next... ");
410 let eaten = rx.next().await;
411 eprintln!("rx eaten {eaten:?}");
412 if let Some(eaten) = eaten {
413 assert_eq!(eaten, next);
414 next += 1;
415 } else {
416 break;
417 }
418 }
419
420 *ended.lock().unwrap() = true;
421 eprintln!("rx ended");
422 }
423 });
424
425 rt.spawn_identified("tx", {
426 let rt = rt.clone();
427
428 async move {
429 let mut numbers = 0..;
430 for wait in [125, 1, 125, 45, 1, 1, 1, 1000, 20, 1, 125, 125, 1000] {
431 eprintln!("tx sleep {wait}");
432 rt.sleep(ms(wait)).await;
433 let num = numbers.next().unwrap();
434 eprintln!("tx sending {num}");
435 tx.send(num).await.unwrap();
436 }
437
438 // This schedule arranges that, when we send EOF, the rx task
439 // has *peeked* rather than *polled* most recently,
440 // demonstrating that we can wake up the subsequent poll on EOF too.
441 eprintln!("tx final #1");
442 rt.sleep(ms(75)).await;
443 eprintln!("tx EOF");
444 drop(tx);
445 eprintln!("tx final #2");
446 rt.sleep(ms(10)).await;
447 assert!(!*ended.lock().unwrap());
448 eprintln!("tx final #3");
449 rt.sleep(ms(50)).await;
450 eprintln!("tx final #4");
451 assert!(*ended.lock().unwrap());
452 }
453 });
454
455 rt.advance_until_stalled().await;
456 });
457 }
458
459 #[test]
460 fn poll_peek_paths() {
461 MockRuntime::test_with_various(|rt| async move {
462 let (mut tx, rx) = mpsc::unbounded();
463 let ended = Arc::new(Mutex::new(false));
464
465 rt.spawn_identified("rxr", {
466 let rt = rt.clone();
467 let ended = ended.clone();
468
469 async move {
470 let rx = StreamUnobtrusivePeeker::new(rx);
471 let mut rx = pin!(rx);
472
473 while let Some(peeked) = rx.as_mut().peek().await.copied() {
474 eprintln!("rx peeked {peeked}");
475 let eaten = rx.next().await.unwrap();
476 eprintln!("rx eaten {eaten}");
477 assert_eq!(peeked, eaten);
478 rt.sleep(ms(10)).await;
479 eprintln!("rx slept, peeking");
480 }
481 *ended.lock().unwrap() = true;
482 eprintln!("rx ended");
483 }
484 });
485
486 rt.spawn_identified("tx", {
487 let rt = rt.clone();
488
489 async move {
490 let mut numbers = 0..;
491
492 // macro because we don't have proper async closures
493 macro_rules! send { {} => {
494 let num = numbers.next().unwrap();
495 eprintln!("tx send {num}");
496 tx.send(num).await.unwrap();
497 } }
498
499 eprintln!("tx starting");
500 rt.sleep(ms(100)).await;
501 send!();
502 rt.sleep(ms(100)).await;
503 send!();
504 send!();
505 rt.sleep(ms(100)).await;
506 eprintln!("tx dropping");
507 drop(tx);
508 rt.sleep(ms(5)).await;
509 eprintln!("tx ending");
510 assert!(*ended.lock().unwrap());
511 }
512 });
513
514 rt.advance_until_stalled().await;
515 });
516 }
517}