tor_async_utils/
stream_peek.rs

1//! [`StreamUnobtrusivePeeker`]
2//!
//! The memory tracker needs a way to look at the next item of a stream
//! (if there is one, or if one can be obtained immediately),
//! *without* getting involved with the async tasks.
6
7use educe::Educe;
8use futures::stream::FusedStream;
9use futures::task::noop_waker_ref;
10use futures::Stream;
11use pin_project::pin_project;
12
13use crate::peekable_stream::{PeekableStream, UnobtrusivePeekableStream};
14
15use std::fmt::Debug;
16use std::future::Future;
17use std::pin::Pin;
18use std::task::{Context, Poll, Poll::*, Waker};
19
/// Wraps [`Stream`] and provides `peek`, `poll_peek` and `unobtrusive_peek`
///
/// [`unobtrusive_peek`](StreamUnobtrusivePeeker::unobtrusive_peek)
/// is callable in sync contexts, outside the reading task.
///
/// Like [`futures::stream::Peekable`],
/// this has an async `peek` method, and `poll_peek`,
/// for use from the task that is also reading (via the [`Stream`] impl).
/// But, that type doesn't have `unobtrusive_peek`.
///
/// One way to conceptualise this is that `StreamUnobtrusivePeeker` is dual-ported:
/// the two sets of APIs, while provided on the same type,
/// are typically called from different contexts.
//
// It wasn't particularly easy to think of a good name for this type.
// We intend, probably:
//     struct StreamUnobtrusivePeeker
//     trait StreamUnobtrusivePeekable
//     trait StreamPeekable (impl for StreamUnobtrusivePeeker and futures::stream::Peekable)
//
// Searching a thesaurus produced these suggested words:
//     unobtrusive subtle discreet inconspicuous cautious furtive
// Asking in MR review also suggested
//     quick
//
// It's awkward because "peek" already has significant connotations of not disturbing things.
// That's why it was used in Iterator::peek.
//
// But when we translate this into async context,
// we have the poll_peek method on futures::stream::Peekable,
// which doesn't remove items from the stream,
// but *does* *wait* for items and therefore engages with the async context,
// and therefore involves *mutating* the Peekable (to store the new waker).
//
// Now we end up needing a word for an *even less disturbing* kind of interaction.
//
// `quick` (and synonyms) isn't quite right either because it's not necessarily faster,
// and certainly not more performant.
#[derive(Debug)]
#[pin_project(project = PeekerProj)]
pub struct StreamUnobtrusivePeeker<S: Stream> {
    /// An item that we have peeked, but not yet handed out from `poll_next`.
    ///
    /// (If we peeked EOF, that is represented by `inner` being `None`, not here.)
    buffered: Option<S::Item>,

    /// The `Waker` from the last time we were polled and returned `Pending`
    ///
    /// "polled" includes any of our `poll_` methods
    /// but *not* `unobtrusive_peek`.
    ///
    /// `unobtrusive_peek_mut` passes this waker to the inner stream when it polls it,
    /// and wakes it if the inner stream turns out to be `Ready`.
    ///
    /// `None` if we haven't been polled, or the last poll returned `Ready`.
    poll_waker: Option<Waker>,

    /// The inner stream
    ///
    /// `None` if it has yielded `None`, meaning EOF.  We don't require `S: FusedStream`.
    #[pin]
    inner: Option<S>,
}
80
81impl<S: Stream> StreamUnobtrusivePeeker<S> {
82    /// Create a new `StreamUnobtrusivePeeker` from a `Stream`
83    pub fn new(inner: S) -> Self {
84        StreamUnobtrusivePeeker {
85            buffered: None,
86            poll_waker: None,
87            inner: Some(inner),
88        }
89    }
90}
91
92impl<S: Stream> UnobtrusivePeekableStream for StreamUnobtrusivePeeker<S> {
93    fn unobtrusive_peek_mut<'s>(mut self: Pin<&'s mut Self>) -> Option<&'s mut S::Item> {
94        #[allow(clippy::question_mark)] // We use explicit control flow here for clarity
95        if self.as_mut().project().buffered.is_none() {
96            // We don't have a buffered item, but the stream may have an item available.
97            // We must poll it to find out.
98            //
99            // We need to pass a Context to poll_next.
100            // inner may store this context, replacing one provided via poll_*.
101            //
102            // Despite that, we need to make sure that wakeups will happen as expected.
103            // To achieve this we have retained a copy of the caller's Waker.
104            //
105            // When a future or stream returns Pending, it proposes to wake `waker`
106            // when it wants to be polled again.
107            //
108            // We uphold that promise by
109            // - only returning Pending from our poll methods if inner also returned Pending
110            // - when one of our poll methods returns Pending, saving the caller-supplied
111            //   waker, so that we can make the intermediate poll call here.
112            //
113            // If the inner poll returns Ready, inner no longer guarantees to wake anyone.
114            // In principle, if our user is waiting (we returned Pending),
115            // then inner ought to have called `wake` on the caller's `Waker`.
116            // But I don't think we can guarantee that an executor won't defer a wakeup,
117            // and respond to a dropped Waker by cancelling that wakeup;
118            // or to put it another way, the wakeup might be "in flight" on entry,
119            // but the call to inner's poll_next returning Ready
120            // might somehow "cancel" the wakeup.
121            //
122            // So just to be sure, if we get a Ready here, we wake the stored waker.
123
124            let mut self_ = self.as_mut().project();
125
126            let Some(inner) = self_.inner.as_mut().as_pin_mut() else {
127                return None;
128            };
129
130            let waker = if let Some(waker) = self_.poll_waker.as_ref() {
131                waker
132            } else {
133                noop_waker_ref()
134            };
135
136            match inner.poll_next(&mut Context::from_waker(waker)) {
137                Pending => {}
138                Ready(item_or_eof) => {
139                    if let Some(waker) = self_.poll_waker.take() {
140                        waker.wake();
141                    }
142                    match item_or_eof {
143                        None => self_.inner.set(None),
144                        Some(item) => *self_.buffered = Some(item),
145                    }
146                }
147            };
148        }
149
150        self.project().buffered.as_mut()
151    }
152}
153
154impl<S: Stream> PeekableStream for StreamUnobtrusivePeeker<S> {
155    fn poll_peek<'s>(self: Pin<&'s mut Self>, cx: &mut Context<'_>) -> Poll<Option<&'s S::Item>> {
156        self.impl_poll_next_or_peek(cx, |buffered| buffered.as_ref())
157    }
158
159    fn poll_peek_mut<'s>(
160        self: Pin<&'s mut Self>,
161        cx: &mut Context<'_>,
162    ) -> Poll<Option<&'s mut S::Item>> {
163        self.impl_poll_next_or_peek(cx, |buffered| buffered.as_mut())
164    }
165}
166
167impl<S: Stream> StreamUnobtrusivePeeker<S> {
168    /// Implementation of `poll_{peek,next}`
169    ///
170    /// This takes care of
171    ///   * examining the state of our buffer, and polling inner if needed
172    ///   * ensuring that we store a waker, if needed
173    ///   * dealing with some borrowck awkwardness
174    ///
175    /// The `Ready` value is always calculated from `buffer`.
176    /// `return_value_obtainer` is called only if we are going to return `Ready`.
177    /// It's given `buffer` and should either:
178    ///   * [`take`](Option::take) the contained value (for `poll_next`)
179    ///   * return a reference using [`Option::as_ref`] (for `poll_peek`)
180    fn impl_poll_next_or_peek<'s, R: 's>(
181        self: Pin<&'s mut Self>,
182        cx: &mut Context<'_>,
183        return_value_obtainer: impl FnOnce(&'s mut Option<S::Item>) -> Option<R>,
184    ) -> Poll<Option<R>> {
185        let mut self_ = self.project();
186        let r = Self::next_or_peek_inner(&mut self_, cx);
187        let r = r.map(|()| return_value_obtainer(self_.buffered));
188        Self::return_from_poll(self_.poll_waker, cx, r)
189    }
190
191    /// Try to populate `buffer`, and calculate if we're `Ready`
192    ///
193    /// Returns `Ready` iff `poll_next` or `poll_peek` should return `Ready`.
194    /// The actual `Ready` value (an `Option`) will be calculated later.
195    fn next_or_peek_inner(self_: &mut PeekerProj<S>, cx: &mut Context<'_>) -> Poll<()> {
196        if let Some(_item) = self_.buffered.as_ref() {
197            // `return_value_obtainer` will find `Some` in `buffered`;
198            // overall, we'll return `Ready(Some(..))`.
199            return Ready(());
200        }
201        let Some(inner) = self_.inner.as_mut().as_pin_mut() else {
202            // `return_value_obtainer` will find `None` in `buffered`;
203            // overall, we'll return `Ready(None)`, ie EOF.
204            return Ready(());
205        };
206        match inner.poll_next(cx) {
207            Ready(None) => {
208                self_.inner.set(None);
209                // `buffered` is `None`, still.
210                // overall, we'll return `Ready(None)`, ie EOF.
211                Ready(())
212            }
213            Ready(Some(item)) => {
214                *self_.buffered = Some(item);
215                // return_value_obtainer` will find `Some` in `buffered`
216                Ready(())
217            }
218            Pending => {
219                // `return_value_obtainer` won't be called.
220                // overall, we'll return Pending
221                Pending
222            }
223        }
224    }
225
226    /// Wait for an item to be ready, and then inspect it
227    ///
228    /// Equivalent to [`futures::stream::Peekable::peek`].
229    ///
230    /// # Tasks, waking, and calling context
231    ///
232    /// This should be called by the task that is reading from the stream.
233    /// If it is called by another task, the reading task would miss notifications.
234    //
235    // This ^ docs section is triplicated for poll_peek, poll_peek_mut, and peek
236    //
237    // TODO this should be a method on the `PeekableStream` trait? Or a
238    // `PeekableStreamExt` trait?
239    // TODO should there be peek_mut ?
240    #[allow(dead_code)] // TODO remove this allow if and when we make this module public
241    pub fn peek(self: Pin<&mut Self>) -> PeekFuture<Self> {
242        PeekFuture { peeker: Some(self) }
243    }
244
245    /// Return from a `poll_*` function, setting the stored waker appropriately
246    ///
247    /// Our `poll` functions always use this.
248    /// The rule is that if a future returns `Pending`, it has stored the waker.
249    fn return_from_poll<R>(
250        poll_waker: &mut Option<Waker>,
251        cx: &mut Context<'_>,
252        r: Poll<R>,
253    ) -> Poll<R> {
254        *poll_waker = match &r {
255            Ready(_) => {
256                // No need to wake this task up any more.
257                None
258            }
259            Pending => {
260                // try_peek must use the same waker to poll later
261                Some(cx.waker().clone())
262            }
263        };
264        r
265    }
266
267    /// Obtain a raw reference to the inner stream
268    ///
269    /// ### Correctness!
270    ///
271    /// This method must be used with care!
272    /// Whatever you do mustn't interfere with polling and peeking.
273    /// Careless use can result in wrong behaviour including deadlocks.
274    pub fn as_raw_inner_pin_mut<'s>(self: Pin<&'s mut Self>) -> Option<Pin<&'s mut S>> {
275        self.project().inner.as_pin_mut()
276    }
277}
278
279impl<S: Stream> Stream for StreamUnobtrusivePeeker<S> {
280    type Item = S::Item;
281
282    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
283        self.impl_poll_next_or_peek(cx, |buffered| buffered.take())
284    }
285
286    fn size_hint(&self) -> (usize, Option<usize>) {
287        let buf = self.buffered.iter().count();
288        let (imin, imax) = match &self.inner {
289            Some(inner) => inner.size_hint(),
290            None => (0, Some(0)),
291        };
292        (imin + buf, imax.and_then(|imap| imap.checked_add(buf)))
293    }
294}
295
296impl<S: Stream> FusedStream for StreamUnobtrusivePeeker<S> {
297    fn is_terminated(&self) -> bool {
298        self.buffered.is_none() && self.inner.is_none()
299    }
300}
301
302/// Future from [`StreamUnobtrusivePeeker::peek`]
303// TODO: Move to tor_async_utils::peekable_stream.
304#[derive(Educe)]
305#[educe(Debug(bound("S: Debug")))]
306#[must_use = "peek() return a Future, which does nothing unless awaited"]
307pub struct PeekFuture<'s, S> {
308    /// The underlying stream.
309    ///
310    /// `Some` until we have returned `Ready`, then `None`.
311    /// See comment in `poll`.
312    peeker: Option<Pin<&'s mut S>>,
313}
314
315impl<'s, S: PeekableStream> PeekFuture<'s, S> {
316    /// Create a new `PeekFuture`.
317    // TODO: replace with a trait method.
318    pub fn new(stream: Pin<&'s mut S>) -> Self {
319        Self {
320            peeker: Some(stream),
321        }
322    }
323}
324
325impl<'s, S: PeekableStream> Future for PeekFuture<'s, S> {
326    type Output = Option<&'s S::Item>;
327    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<&'s S::Item>> {
328        let self_ = self.get_mut();
329        let peeker = self_
330            .peeker
331            .as_mut()
332            .expect("PeekFuture polled after Ready");
333        match peeker.as_mut().poll_peek(cx) {
334            Pending => return Pending,
335            Ready(_y) => {
336                // Ideally we would have returned `y` here, but it's borrowed from PeekFuture
337                // not from the original StreamUnobtrusivePeeker, and there's no way
338                // to get a value with the right lifetime.  (In non-async code,
339                // this is usually handled by the special magic for reborrowing &mut.)
340                //
341                // So we must redo the poll, but this time consuming `peeker`,
342                // which gets us the right lifetime.  That's why it has to be `Option`.
343                // Because we own &mut ... Self, we know that repeating the poll
344                // gives the same answer.
345            }
346        }
347        let peeker = self_.peeker.take().expect("it was Some before!");
348        let r = peeker.poll_peek(cx);
349        assert!(r.is_ready(), "it was Ready before!");
350        r
351    }
352}
353
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use futures::channel::mpsc;
    use futures::{SinkExt as _, StreamExt as _};
    use std::pin::pin;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tor_rtcompat::SleepProvider as _;
    use tor_rtmock::MockRuntime;

    /// Shorthand for a `Duration` of `ms` milliseconds.
    fn ms(ms: u64) -> Duration {
        Duration::from_millis(ms)
    }

    /// Interleave `unobtrusive_peek_mut` with ordinary `next()` polling.
    ///
    /// The reader task repeatedly sleeps, calls `unobtrusive_peek_mut`, sleeps
    /// again, and then awaits `next()`.  The sender sends consecutive integers
    /// after irregular delays, then drops the sender (EOF).
    /// The test checks the items arrive in order, and that the reader notices EOF
    /// within a bounded time -- i.e. the wakeup was not lost even though
    /// `unobtrusive_peek_mut` polled the inner stream in between.
    #[test]
    fn wakeups() {
        MockRuntime::test_with_various(|rt| async move {
            let (mut tx, rx) = mpsc::unbounded();
            // Set by the reader task once it has seen EOF.
            let ended = Arc::new(Mutex::new(false));

            rt.spawn_identified("rxr", {
                let rt = rt.clone();
                let ended = ended.clone();

                async move {
                    let rx = StreamUnobtrusivePeeker::new(rx);
                    let mut rx = pin!(rx);

                    // The value we expect to see next.
                    let mut next = 0;
                    loop {
                        rt.sleep(ms(50)).await;
                        eprintln!("rx peek... ");
                        let peeked = rx.as_mut().unobtrusive_peek_mut();
                        eprintln!("rx peeked {peeked:?}");

                        // An unobtrusive peek may find nothing yet;
                        // if it does find an item, it must be the next one in sequence.
                        if let Some(peeked) = peeked {
                            assert_eq!(*peeked, next);
                        }

                        rt.sleep(ms(50)).await;
                        eprintln!("rx next... ");
                        let eaten = rx.next().await;
                        eprintln!("rx eaten {eaten:?}");
                        if let Some(eaten) = eaten {
                            assert_eq!(eaten, next);
                            next += 1;
                        } else {
                            // EOF.
                            break;
                        }
                    }

                    *ended.lock().unwrap() = true;
                    eprintln!("rx ended");
                }
            });

            rt.spawn_identified("tx", {
                let rt = rt.clone();

                async move {
                    let mut numbers = 0..;
                    // Irregular delays (in ms) before sending each successive number.
                    for wait in [125, 1, 125, 45, 1, 1, 1, 1000, 20, 1, 125, 125, 1000] {
                        eprintln!("tx sleep {wait}");
                        rt.sleep(ms(wait)).await;
                        let num = numbers.next().unwrap();
                        eprintln!("tx sending {num}");
                        tx.send(num).await.unwrap();
                    }

                    // This schedule arranges that, when we send EOF, the rx task
                    // has *peeked* rather than *polled* most recently,
                    // demonstrating that we can wake up the subsequent poll on EOF too.
                    eprintln!("tx final #1");
                    rt.sleep(ms(75)).await;
                    eprintln!("tx EOF");
                    drop(tx);
                    eprintln!("tx final #2");
                    rt.sleep(ms(10)).await;
                    // 10ms after EOF, the reader must not have finished yet...
                    assert!(!*ended.lock().unwrap());
                    eprintln!("tx final #3");
                    rt.sleep(ms(50)).await;
                    eprintln!("tx final #4");
                    // ...but 50ms later it must have seen EOF and finished.
                    assert!(*ended.lock().unwrap());
                }
            });

            rt.advance_until_stalled().await;
        });
    }

    /// Exercise `peek()` (and hence `poll_peek`) followed by `next()`.
    ///
    /// Each value observed by `peek` must be returned unchanged by the following
    /// `next`.  After the sender is dropped, `peek` must resolve to `None`, and the
    /// reader must have finished within 5ms of the drop.
    #[test]
    fn poll_peek_paths() {
        MockRuntime::test_with_various(|rt| async move {
            let (mut tx, rx) = mpsc::unbounded();
            // Set by the reader task once `peek` has reported EOF.
            let ended = Arc::new(Mutex::new(false));

            rt.spawn_identified("rxr", {
                let rt = rt.clone();
                let ended = ended.clone();

                async move {
                    let rx = StreamUnobtrusivePeeker::new(rx);
                    let mut rx = pin!(rx);

                    while let Some(peeked) = rx.as_mut().peek().await.copied() {
                        eprintln!("rx peeked {peeked}");
                        let eaten = rx.next().await.unwrap();
                        eprintln!("rx eaten  {eaten}");
                        assert_eq!(peeked, eaten);
                        rt.sleep(ms(10)).await;
                        eprintln!("rx slept, peeking");
                    }
                    *ended.lock().unwrap() = true;
                    eprintln!("rx ended");
                }
            });

            rt.spawn_identified("tx", {
                let rt = rt.clone();

                async move {
                    let mut numbers = 0..;

                    // macro because we don't have proper async closures
                    macro_rules! send { {} => {
                        let num = numbers.next().unwrap();
                        eprintln!("tx send   {num}");
                        tx.send(num).await.unwrap();
                    } }

                    // Send one item on its own, then two back-to-back, then EOF.
                    eprintln!("tx starting");
                    rt.sleep(ms(100)).await;
                    send!();
                    rt.sleep(ms(100)).await;
                    send!();
                    send!();
                    rt.sleep(ms(100)).await;
                    eprintln!("tx dropping");
                    drop(tx);
                    rt.sleep(ms(5)).await;
                    eprintln!("tx ending");
                    // The reader must have seen EOF within 5ms of the drop.
                    assert!(*ended.lock().unwrap());
                }
            });

            rt.advance_until_stalled().await;
        });
    }
}