1
//! [`StreamUnobtrusivePeeker`]
2
//!
3
//! The memory tracker needs a way to look at the next item of a stream
4
//! (if there is one, or there can immediately be one),
5
//! *without* getting involved with the async tasks.
6

            
7
use educe::Educe;
8
use futures::stream::FusedStream;
9
use futures::task::noop_waker_ref;
10
use futures::Stream;
11
use pin_project::pin_project;
12

            
13
use crate::peekable_stream::{PeekableStream, UnobtrusivePeekableStream};
14

            
15
use std::fmt::Debug;
16
use std::future::Future;
17
use std::pin::Pin;
18
use std::task::{Context, Poll, Poll::*, Waker};
19

            
20
/// Wraps [`Stream`] and provides `\[poll_]peek` and `unobtrusive_peek`
21
///
22
/// [`unobtrusive_peek`](StreamUnobtrusivePeeker::unobtrusive_peek)
23
/// is callable in sync contexts, outside the reading task.
24
///
25
/// Like [`futures::stream::Peekable`],
26
/// this has an async `peek` method, and `poll_peek`,
27
/// for use from the task that is also reading (via the [`Stream`] impl).
28
/// But, that type doesn't have `unobtrusive_peek`.
29
///
30
/// One way to conceptualise this is that `StreamUnobtrusivePeeker` is dual-ported:
31
/// the two sets of APIs, while provided on the same type,
32
/// are typically called from different contexts.
33
//
34
// It wasn't particularly easy to think of a good name for this type.
35
// We intend, probably:
36
//     struct StreamUnobtrusivePeeker
37
//     trait StreamUnobtrusivePeekable
38
//     trait StreamPeekable (impl for StreamUnobtrusivePeeker and futures::stream::Peekable)
39
//
40
// Searching a thesaurus produced these suggested words:
41
//     unobtrusive subtle discreet inconspicuous cautious furtive
42
// Asking in MR review also suggested
43
//     quick
44
//
45
// It's awkward because "peek" already has significant connotations of not disturbing things.
46
// That's why it was used in Iterator::peek.
47
//
48
// But when we translate this into async context,
49
// we have the poll_peek method on futures::stream::Peekable,
50
// which doesn't remove items from the stream,
51
// but *does* *wait* for items and therefore engages with the async context,
52
// and therefore involves *mutating* the Peekable (to store the new waker).
53
//
54
// Now we end up needing a word for an *even less disturbing* kind of interaction.
55
//
56
// `quick` (and synonyms) isn't quite right either because it's not necessarily faster,
57
// and certainly not more performant.
58
#[derive(Debug)]
59
31720
#[pin_project(project = PeekerProj)]
60
pub struct StreamUnobtrusivePeeker<S: Stream> {
61
    /// An item that we have peeked.
62
    ///
63
    /// (If we peeked EOF, that's represented by `None` in inner.)
64
    buffered: Option<S::Item>,
65

            
66
    /// The `Waker` from the last time we were polled and returned `Pending`
67
    ///
68
    /// "polled" includes any of our `poll_` methods
69
    /// but *not* `unobtrusive_peek`.
70
    ///
71
    /// `None` if we haven't been polled, or the last poll returned `Ready`.
72
    poll_waker: Option<Waker>,
73

            
74
    /// The inner stream
75
    ///
76
    /// `None if it has yielded `None` meaning EOF.  We don't require S: FusedStream.
77
    #[pin]
78
    inner: Option<S>,
79
}
80

            
81
impl<S: Stream> StreamUnobtrusivePeeker<S> {
    /// Wrap `inner` in a fresh `StreamUnobtrusivePeeker`
    ///
    /// Nothing is buffered yet, and no waker has been stored.
    pub fn new(inner: S) -> Self {
        let inner = Some(inner);
        Self {
            inner,
            poll_waker: None,
            buffered: None,
        }
    }
}
91

            
92
impl<S: Stream> UnobtrusivePeekableStream for StreamUnobtrusivePeeker<S> {
93
2848
    fn unobtrusive_peek_mut<'s>(mut self: Pin<&'s mut Self>) -> Option<&'s mut S::Item> {
94
2848
        #[allow(clippy::question_mark)] // We use explicit control flow here for clarity
95
2848
        if self.as_mut().project().buffered.is_none() {
96
            // We don't have a buffered item, but the stream may have an item available.
97
            // We must poll it to find out.
98
            //
99
            // We need to pass a Context to poll_next.
100
            // inner may store this context, replacing one provided via poll_*.
101
            //
102
            // Despite that, we need to make sure that wakeups will happen as expected.
103
            // To achieve this we have retained a copy of the caller's Waker.
104
            //
105
            // When a future or stream returns Pending, it proposes to wake `waker`
106
            // when it wants to be polled again.
107
            //
108
            // We uphold that promise by
109
            // - only returning Pending from our poll methods if inner also returned Pending
110
            // - when one of our poll methods returns Pending, saving the caller-supplied
111
            //   waker, so that we can make the intermediate poll call here.
112
            //
113
            // If the inner poll returns Ready, inner no longer guarantees to wake anyone.
114
            // In principle, if our user is waiting (we returned Pending),
115
            // then inner ought to have called `wake` on the caller's `Waker`.
116
            // But I don't think we can guarantee that an executor won't defer a wakeup,
117
            // and respond to a dropped Waker by cancelling that wakeup;
118
            // or to put it another way, the wakeup might be "in flight" on entry,
119
            // but the call to inner's poll_next returning Ready
120
            // might somehow "cancel" the wakeup.
121
            //
122
            // So just to be sure, if we get a Ready here, we wake the stored waker.
123

            
124
104
            let mut self_ = self.as_mut().project();
125

            
126
104
            let Some(inner) = self_.inner.as_mut().as_pin_mut() else {
127
32
                return None;
128
            };
129

            
130
72
            let waker = if let Some(waker) = self_.poll_waker.as_ref() {
131
                waker
132
            } else {
133
72
                noop_waker_ref()
134
            };
135

            
136
72
            match inner.poll_next(&mut Context::from_waker(waker)) {
137
24
                Pending => {}
138
48
                Ready(item_or_eof) => {
139
48
                    if let Some(waker) = self_.poll_waker.take() {
140
                        waker.wake();
141
48
                    }
142
48
                    match item_or_eof {
143
                        None => self_.inner.set(None),
144
48
                        Some(item) => *self_.buffered = Some(item),
145
                    }
146
                }
147
            };
148
2744
        }
149

            
150
2816
        self.project().buffered.as_mut()
151
2848
    }
152
}
153

            
154
impl<S: Stream> PeekableStream for StreamUnobtrusivePeeker<S> {
155
44
    fn poll_peek<'s>(self: Pin<&'s mut Self>, cx: &mut Context<'_>) -> Poll<Option<&'s S::Item>> {
156
44
        self.impl_poll_next_or_peek(cx, |buffered| buffered.as_ref())
157
44
    }
158

            
159
8370
    fn poll_peek_mut<'s>(
160
8370
        self: Pin<&'s mut Self>,
161
8370
        cx: &mut Context<'_>,
162
8370
    ) -> Poll<Option<&'s mut S::Item>> {
163
8370
        self.impl_poll_next_or_peek(cx, |buffered| buffered.as_mut())
164
8370
    }
165
}
166

            
167
impl<S: Stream> StreamUnobtrusivePeeker<S> {
168
    /// Implementation of `poll_{peek,next}`
169
    ///
170
    /// This takes care of
171
    ///   * examining the state of our buffer, and polling inner if needed
172
    ///   * ensuring that we store a waker, if needed
173
    ///   * dealing with some borrowck awkwardness
174
    ///
175
    /// The `Ready` value is always calculated from `buffer`.
176
    /// `return_value_obtainer` is called only if we are going to return `Ready`.
177
    /// It's given `buffer` and should either:
178
    ///   * [`take`](Option::take) the contained value (for `poll_next`)
179
    ///   * return a reference using [`Option::as_ref`] (for `poll_peek`)
180
24039
    fn impl_poll_next_or_peek<'s, R: 's>(
181
24039
        self: Pin<&'s mut Self>,
182
24039
        cx: &mut Context<'_>,
183
24039
        return_value_obtainer: impl FnOnce(&'s mut Option<S::Item>) -> Option<R>,
184
24039
    ) -> Poll<Option<R>> {
185
24039
        let mut self_ = self.project();
186
24039
        let r = Self::next_or_peek_inner(&mut self_, cx);
187
24039
        let r = r.map(|()| return_value_obtainer(self_.buffered));
188
24039
        Self::return_from_poll(self_.poll_waker, cx, r)
189
24039
    }
190

            
191
    /// Try to populate `buffer`, and calculate if we're `Ready`
192
    ///
193
    /// Returns `Ready` iff `poll_next` or `poll_peek` should return `Ready`.
194
    /// The actual `Ready` value (an `Option`) will be calculated later.
195
24039
    fn next_or_peek_inner(self_: &mut PeekerProj<S>, cx: &mut Context<'_>) -> Poll<()> {
196
24039
        if let Some(_item) = self_.buffered.as_ref() {
197
            // `return_value_obtainer` will find `Some` in `buffered`;
198
            // overall, we'll return `Ready(Some(..))`.
199
8304
            return Ready(());
200
15735
        }
201
15735
        let Some(inner) = self_.inner.as_mut().as_pin_mut() else {
202
            // `return_value_obtainer` will find `None` in `buffered`;
203
            // overall, we'll return `Ready(None)`, ie EOF.
204
66
            return Ready(());
205
        };
206
15669
        match inner.poll_next(cx) {
207
            Ready(None) => {
208
1945
                self_.inner.set(None);
209
1945
                // `buffered` is `None`, still.
210
1945
                // overall, we'll return `Ready(None)`, ie EOF.
211
1945
                Ready(())
212
            }
213
9042
            Ready(Some(item)) => {
214
9042
                *self_.buffered = Some(item);
215
9042
                // return_value_obtainer` will find `Some` in `buffered`
216
9042
                Ready(())
217
            }
218
            Pending => {
219
                // `return_value_obtainer` won't be called.
220
                // overall, we'll return Pending
221
4682
                Pending
222
            }
223
        }
224
24039
    }
225

            
226
    /// Wait for an item to be ready, and then inspect it
227
    ///
228
    /// Equivalent to [`futures::stream::Peekable::peek`].
229
    ///
230
    /// # Tasks, waking, and calling context
231
    ///
232
    /// This should be called by the task that is reading from the stream.
233
    /// If it is called by another task, the reading task would miss notifications.
234
    //
235
    // This ^ docs section is triplicated for poll_peek, poll_peek_mut, and peek
236
    //
237
    // TODO this should be a method on the `PeekableStream` trait? Or a
238
    // `PeekableStreamExt` trait?
239
    // TODO should there be peek_mut ?
240
    #[allow(dead_code)] // TODO remove this allow if and when we make this module public
241
16
    pub fn peek(self: Pin<&mut Self>) -> PeekFuture<Self> {
242
16
        PeekFuture { peeker: Some(self) }
243
16
    }
244

            
245
    /// Return from a `poll_*` function, setting the stored waker appropriately
246
    ///
247
    /// Our `poll` functions always use this.
248
    /// The rule is that if a future returns `Pending`, it has stored the waker.
249
24039
    fn return_from_poll<R>(
250
24039
        poll_waker: &mut Option<Waker>,
251
24039
        cx: &mut Context<'_>,
252
24039
        r: Poll<R>,
253
24039
    ) -> Poll<R> {
254
24039
        *poll_waker = match &r {
255
            Ready(_) => {
256
                // No need to wake this task up any more.
257
19357
                None
258
            }
259
            Pending => {
260
                // try_peek must use the same waker to poll later
261
4682
                Some(cx.waker().clone())
262
            }
263
        };
264
24039
        r
265
24039
    }
266

            
267
    /// Obtain a raw reference to the inner stream
268
    ///
269
    /// ### Correctness!
270
    ///
271
    /// This method must be used with care!
272
    /// Whatever you do mustn't interfere with polling and peeking.
273
    /// Careless use can result in wrong behaviour including deadlocks.
274
1913
    pub fn as_raw_inner_pin_mut<'s>(self: Pin<&'s mut Self>) -> Option<Pin<&'s mut S>> {
275
1913
        self.project().inner.as_pin_mut()
276
1913
    }
277
}
278

            
279
impl<S: Stream> Stream for StreamUnobtrusivePeeker<S> {
280
    type Item = S::Item;
281

            
282
15625
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
283
15625
        self.impl_poll_next_or_peek(cx, |buffered| buffered.take())
284
15625
    }
285

            
286
    fn size_hint(&self) -> (usize, Option<usize>) {
287
        let buf = self.buffered.iter().count();
288
        let (imin, imax) = match &self.inner {
289
            Some(inner) => inner.size_hint(),
290
            None => (0, Some(0)),
291
        };
292
        (imin + buf, imax.and_then(|imap| imap.checked_add(buf)))
293
    }
294
}
295

            
296
impl<S: Stream> FusedStream for StreamUnobtrusivePeeker<S> {
    /// Terminated once EOF has been seen and no peeked item remains buffered
    fn is_terminated(&self) -> bool {
        matches!((&self.buffered, &self.inner), (None, None))
    }
}
301

            
302
/// Future from [`StreamUnobtrusivePeeker::peek`]
303
// TODO: Move to tor_async_utils::peekable_stream.
304
#[derive(Educe)]
305
#[educe(Debug(bound("S: Debug")))]
306
#[must_use = "peek() return a Future, which does nothing unless awaited"]
307
pub struct PeekFuture<'s, S> {
308
    /// The underlying stream.
309
    ///
310
    /// `Some` until we have returned `Ready`, then `None`.
311
    /// See comment in `poll`.
312
    peeker: Option<Pin<&'s mut S>>,
313
}
314

            
315
impl<'s, S: PeekableStream> PeekFuture<'s, S> {
316
    /// Create a new `PeekFuture`.
317
    // TODO: replace with a trait method.
318
    pub fn new(stream: Pin<&'s mut S>) -> Self {
319
        Self {
320
            peeker: Some(stream),
321
        }
322
    }
323
}
324

            
325
impl<'s, S: PeekableStream> Future for PeekFuture<'s, S> {
326
    type Output = Option<&'s S::Item>;
327
28
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<&'s S::Item>> {
328
28
        let self_ = self.get_mut();
329
28
        let peeker = self_
330
28
            .peeker
331
28
            .as_mut()
332
28
            .expect("PeekFuture polled after Ready");
333
28
        match peeker.as_mut().poll_peek(cx) {
334
12
            Pending => return Pending,
335
16
            Ready(_y) => {
336
16
                // Ideally we would have returned `y` here, but it's borrowed from PeekFuture
337
16
                // not from the original StreamUnobtrusivePeeker, and there's no way
338
16
                // to get a value with the right lifetime.  (In non-async code,
339
16
                // this is usually handled by the special magic for reborrowing &mut.)
340
16
                //
341
16
                // So we must redo the poll, but this time consuming `peeker`,
342
16
                // which gets us the right lifetime.  That's why it has to be `Option`.
343
16
                // Because we own &mut ... Self, we know that repeating the poll
344
16
                // gives the same answer.
345
16
            }
346
16
        }
347
16
        let peeker = self_.peeker.take().expect("it was Some before!");
348
16
        let r = peeker.poll_peek(cx);
349
16
        assert!(r.is_ready(), "it was Ready before!");
350
16
        r
351
28
    }
352
}
353

            
354
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use futures::channel::mpsc;
    use futures::{SinkExt as _, StreamExt as _};
    use std::pin::pin;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tor_rtcompat::SleepProvider as _;
    use tor_rtmock::MockRuntime;

    /// Shorthand for a `Duration` of `ms` milliseconds.
    fn ms(ms: u64) -> Duration {
        Duration::from_millis(ms)
    }

    /// The reader alternates `unobtrusive_peek_mut` and `next`, against an
    /// irregular send schedule; items must arrive in order, and no wakeup
    /// (including the one for EOF) may be lost.
    #[test]
    fn wakeups() {
        MockRuntime::test_with_various(|rt| async move {
            let (mut tx, rx) = mpsc::unbounded();
            let rx_done = Arc::new(Mutex::new(false));

            rt.spawn_identified("rxr", {
                let rt = rt.clone();
                let rx_done = Arc::clone(&rx_done);

                async move {
                    let mut rx = pin!(StreamUnobtrusivePeeker::new(rx));

                    for expected in 0.. {
                        rt.sleep(ms(50)).await;
                        eprintln!("rx peek... ");
                        let peeked = rx.as_mut().unobtrusive_peek_mut();
                        eprintln!("rx peeked {peeked:?}");
                        if let Some(peeked) = peeked {
                            assert_eq!(*peeked, expected);
                        }

                        rt.sleep(ms(50)).await;
                        eprintln!("rx next... ");
                        let eaten = rx.next().await;
                        eprintln!("rx eaten {eaten:?}");
                        let Some(eaten) = eaten else { break };
                        assert_eq!(eaten, expected);
                    }

                    *rx_done.lock().unwrap() = true;
                    eprintln!("rx ended");
                }
            });

            rt.spawn_identified("tx", {
                let rt = rt.clone();

                async move {
                    let schedule = [125, 1, 125, 45, 1, 1, 1, 1000, 20, 1, 125, 125, 1000];
                    for (num, wait) in (0..).zip(schedule) {
                        eprintln!("tx sleep {wait}");
                        rt.sleep(ms(wait)).await;
                        eprintln!("tx sending {num}");
                        tx.send(num).await.unwrap();
                    }

                    // The timings below ensure that, when EOF is sent, the rx task
                    // most recently *peeked* rather than *polled*, which shows that
                    // the following poll is woken on EOF as well.
                    eprintln!("tx final #1");
                    rt.sleep(ms(75)).await;
                    eprintln!("tx EOF");
                    drop(tx);
                    eprintln!("tx final #2");
                    rt.sleep(ms(10)).await;
                    assert!(!*rx_done.lock().unwrap());
                    eprintln!("tx final #3");
                    rt.sleep(ms(50)).await;
                    eprintln!("tx final #4");
                    assert!(*rx_done.lock().unwrap());
                }
            });

            rt.advance_until_stalled().await;
        });
    }

    /// The reader uses the async `peek` and `next` in lockstep, while items
    /// arrive both singly and back-to-back, followed by EOF.
    #[test]
    fn poll_peek_paths() {
        MockRuntime::test_with_various(|rt| async move {
            let (mut tx, rx) = mpsc::unbounded();
            let rx_done = Arc::new(Mutex::new(false));

            rt.spawn_identified("rxr", {
                let rt = rt.clone();
                let rx_done = Arc::clone(&rx_done);

                async move {
                    let mut rx = pin!(StreamUnobtrusivePeeker::new(rx));

                    while let Some(peeked) = rx.as_mut().peek().await.copied() {
                        eprintln!("rx peeked {peeked}");
                        let eaten = rx.next().await.unwrap();
                        eprintln!("rx eaten  {eaten}");
                        assert_eq!(peeked, eaten);
                        rt.sleep(ms(10)).await;
                        eprintln!("rx slept, peeking");
                    }

                    *rx_done.lock().unwrap() = true;
                    eprintln!("rx ended");
                }
            });

            rt.spawn_identified("tx", {
                let rt = rt.clone();

                async move {
                    let mut counter = 0..;

                    // A macro, since a closure body can't `.await`.
                    macro_rules! send {
                        () => {
                            let num = counter.next().unwrap();
                            eprintln!("tx send   {num}");
                            tx.send(num).await.unwrap();
                        };
                    }

                    eprintln!("tx starting");
                    rt.sleep(ms(100)).await;
                    send!();
                    rt.sleep(ms(100)).await;
                    send!();
                    send!();
                    rt.sleep(ms(100)).await;
                    eprintln!("tx dropping");
                    drop(tx);
                    rt.sleep(ms(5)).await;
                    eprintln!("tx ending");
                    assert!(*rx_done.lock().unwrap());
                }
            });

            rt.advance_until_stalled().await;
        });
    }
}