1
//! [`StreamUnobtrusivePeeker`]
//!
//! The memory tracker needs a way to look at the next item of a stream
//! (if there is one, or there can immediately be one),
//! *without* getting involved with the async tasks.
//
// TODO at some point this should probably be in tor-async-utils
8

            
9
#![allow(dead_code)] // TODO #351
10

            
11
use educe::Educe;
12
use futures::stream::FusedStream;
13
use futures::task::noop_waker_ref;
14
use futures::Stream;
15
use pin_project::pin_project;
16

            
17
use crate::peekable_stream::{PeekableStream, UnobtrusivePeekableStream};
18

            
19
use std::fmt::Debug;
20
use std::future::Future;
21
use std::pin::Pin;
22
use std::task::{Context, Poll, Poll::*, Waker};
23

            
24
/// Wraps [`Stream`] and provides `\[poll_]peek` and `unobtrusive_peek`
25
///
26
/// [`unobtrusive_peek`](StreamUnobtrusivePeeker::unobtrusive_peek)
27
/// is callable in sync contexts, outside the reading task.
28
///
29
/// Like [`futures::stream::Peekable`],
30
/// this has an async `peek` method, and `poll_peek`,
31
/// for use from the task that is also reading (via the [`Stream`] impl).
32
/// But, that type doesn't have `unobtrusive_peek`.
33
///
34
/// One way to conceptualise this is that `StreamUnobtrusivePeeker` is dual-ported:
35
/// the two sets of APIs, while provided on the same type,
36
/// are typically called from different contexts.
37
//
38
// It wasn't particularly easy to think of a good name for this type.
39
// We intend, probably:
40
//     struct StreamUnobtrusivePeeker
41
//     trait StreamUnobtrusivePeekable
42
//     trait StreamPeekable (impl for StreamUnobtrusivePeeker and futures::stream::Peekable)
43
//
44
// Searching a thesaurus produced these suggested words:
45
//     unobtrusive subtle discreet inconspicuous cautious furtive
46
// Asking in MR review also suggested
47
//     quick
48
//
49
// It's awkward because "peek" already has significant connotations of not disturbing things.
50
// That's why it was used in Iterator::peek.
51
//
52
// But when we translate this into async context,
53
// we have the poll_peek method on futures::stream::Peekable,
54
// which doesn't remove items from the stream,
55
// but *does* *wait* for items and therefore engages with the async context,
56
// and therefore involves *mutating* the Peekable (to store the new waker).
57
//
58
// Now we end up needing a word for an *even less disturbing* kind of interaction.
59
//
60
// `quick` (and synonyms) isn't quite right either because it's not necessarily faster,
61
// and certainly not more performant.
62
#[derive(Debug)]
63
31752
#[pin_project(project = PeekerProj)]
64
pub struct StreamUnobtrusivePeeker<S: Stream> {
65
    /// An item that we have peeked.
66
    ///
67
    /// (If we peeked EOF, that's represented by `None` in inner.)
68
    buffered: Option<S::Item>,
69

            
70
    /// The `Waker` from the last time we were polled and returned `Pending`
71
    ///
72
    /// "polled" includes any of our `poll_` methods
73
    /// but *not* `unobtrusive_peek`.
74
    ///
75
    /// `None` if we haven't been polled, or the last poll returned `Ready`.
76
    poll_waker: Option<Waker>,
77

            
78
    /// The inner stream
79
    ///
80
    /// `None if it has yielded `None` meaning EOF.  We don't require S: FusedStream.
81
    #[pin]
82
    inner: Option<S>,
83
}
84

            
85
impl<S: Stream> StreamUnobtrusivePeeker<S> {
    /// Wrap `inner` in a fresh `StreamUnobtrusivePeeker`
    ///
    /// The result starts out with nothing buffered and no stored waker.
    pub fn new(inner: S) -> Self {
        Self {
            inner: Some(inner),
            poll_waker: None,
            buffered: None,
        }
    }
}
95

            
96
impl<S: Stream> UnobtrusivePeekableStream for StreamUnobtrusivePeeker<S> {
97
5264
    fn unobtrusive_peek_mut<'s>(mut self: Pin<&'s mut Self>) -> Option<&'s mut S::Item> {
98
5264
        #[allow(clippy::question_mark)] // We use explicit control flow here for clarity
99
5264
        if self.as_mut().project().buffered.is_none() {
100
            // We don't have a buffered item, but the stream may have an item available.
101
            // We must poll it to find out.
102
            //
103
            // We need to pass a Context to poll_next.
104
            // inner may store this context, replacing one provided via poll_*.
105
            //
106
            // Despite that, we need to make sure that wakeups will happen as expected.
107
            // To achieve this we have retained a copy of the caller's Waker.
108
            //
109
            // When a future or stream returns Pending, it proposes to wake `waker`
110
            // when it wants to be polled again.
111
            //
112
            // We uphold that promise by
113
            // - only returning Pending from our poll methods if inner also returned Pending
114
            // - when one of our poll methods returns Pending, saving the caller-supplied
115
            //   waker, so that we can make the intermediate poll call here.
116
            //
117
            // If the inner poll returns Ready, inner no longer guarantees to wake anyone.
118
            // In principle, if our user is waiting (we returned Pending),
119
            // then inner ought to have called `wake` on the caller's `Waker`.
120
            // But I don't think we can guarantee that an executor won't defer a wakeup,
121
            // and respond to a dropped Waker by cancelling that wakeup;
122
            // or to put it another way, the wakeup might be "in flight" on entry,
123
            // but the call to inner's poll_next returning Ready
124
            // might somehow "cancel" the wakeup.
125
            //
126
            // So just to be sure, if we get a Ready here, we wake the stored waker.
127

            
128
102
            let mut self_ = self.as_mut().project();
129

            
130
102
            let Some(inner) = self_.inner.as_mut().as_pin_mut() else {
131
30
                return None;
132
            };
133

            
134
72
            let waker = if let Some(waker) = self_.poll_waker.as_ref() {
135
                waker
136
            } else {
137
72
                noop_waker_ref()
138
            };
139

            
140
72
            match inner.poll_next(&mut Context::from_waker(waker)) {
141
24
                Pending => {}
142
48
                Ready(item_or_eof) => {
143
48
                    if let Some(waker) = self_.poll_waker.take() {
144
                        waker.wake();
145
48
                    }
146
48
                    match item_or_eof {
147
                        None => self_.inner.set(None),
148
48
                        Some(item) => *self_.buffered = Some(item),
149
                    }
150
                }
151
            };
152
5162
        }
153

            
154
5234
        self.project().buffered.as_mut()
155
5264
    }
156
}
157

            
158
impl<S: Stream> PeekableStream for StreamUnobtrusivePeeker<S> {
159
44
    fn poll_peek<'s>(self: Pin<&'s mut Self>, cx: &mut Context<'_>) -> Poll<Option<&'s S::Item>> {
160
44
        self.impl_poll_next_or_peek(cx, |buffered| buffered.as_ref())
161
44
    }
162

            
163
15662
    fn poll_peek_mut<'s>(
164
15662
        self: Pin<&'s mut Self>,
165
15662
        cx: &mut Context<'_>,
166
15662
    ) -> Poll<Option<&'s mut S::Item>> {
167
15662
        self.impl_poll_next_or_peek(cx, |buffered| buffered.as_mut())
168
15662
    }
169
}
170

            
171
impl<S: Stream> StreamUnobtrusivePeeker<S> {
172
    /// Implementation of `poll_{peek,next}`
173
    ///
174
    /// This takes care of
175
    ///   * examining the state of our buffer, and polling inner if needed
176
    ///   * ensuring that we store a waker, if needed
177
    ///   * dealing with some borrowck awkwardness
178
    ///
179
    /// The `Ready` value is always calculated from `buffer`.
180
    /// `return_value_obtainer` is called only if we are going to return `Ready`.
181
    /// It's given `buffer` and should either:
182
    ///   * [`take`](Option::take) the contained value (for `poll_next`)
183
    ///   * return a reference using [`Option::as_ref`] (for `poll_peek`)
184
21132
    fn impl_poll_next_or_peek<'s, R: 's>(
185
21132
        self: Pin<&'s mut Self>,
186
21132
        cx: &mut Context<'_>,
187
21132
        return_value_obtainer: impl FnOnce(&'s mut Option<S::Item>) -> Option<R>,
188
21132
    ) -> Poll<Option<R>> {
189
21132
        let mut self_ = self.project();
190
21132
        let r = Self::next_or_peek_inner(&mut self_, cx);
191
21132
        let r = r.map(|()| return_value_obtainer(self_.buffered));
192
21132
        Self::return_from_poll(self_.poll_waker, cx, r)
193
21132
    }
194

            
195
    /// Try to populate `buffer`, and calculate if we're `Ready`
196
    ///
197
    /// Returns `Ready` iff `poll_next` or `poll_peek` should return `Ready`.
198
    /// The actual `Ready` value (an `Option`) will be calculated later.
199
21132
    fn next_or_peek_inner(self_: &mut PeekerProj<S>, cx: &mut Context<'_>) -> Poll<()> {
200
21132
        if let Some(_item) = self_.buffered.as_ref() {
201
            // `return_value_obtainer` will find `Some` in `buffered`;
202
            // overall, we'll return `Ready(Some(..))`.
203
15558
            return Ready(());
204
5574
        }
205
5574
        let Some(inner) = self_.inner.as_mut().as_pin_mut() else {
206
            // `return_value_obtainer` will find `None` in `buffered`;
207
            // overall, we'll return `Ready(None)`, ie EOF.
208
4
            return Ready(());
209
        };
210
5570
        match inner.poll_next(cx) {
211
            Ready(None) => {
212
50
                self_.inner.set(None);
213
50
                // `buffered` is `None`, still.
214
50
                // overall, we'll return `Ready(None)`, ie EOF.
215
50
                Ready(())
216
            }
217
5350
            Ready(Some(item)) => {
218
5350
                *self_.buffered = Some(item);
219
5350
                // return_value_obtainer` will find `Some` in `buffered`
220
5350
                Ready(())
221
            }
222
            Pending => {
223
                // `return_value_obtainer` won't be called.
224
                // overall, we'll return Pending
225
170
                Pending
226
            }
227
        }
228
21132
    }
229

            
230
    /// Wait for an item to be ready, and then inspect it
231
    ///
232
    /// Equivalent to [`futures::stream::Peekable::peek`].
233
    ///
234
    /// # Tasks, waking, and calling context
235
    ///
236
    /// This should be called by the task that is reading from the stream.
237
    /// If it is called by another task, the reading task would miss notifications.
238
    //
239
    // This ^ docs section is triplicated for poll_peek, poll_peek_mut, and peek
240
    //
241
    // TODO this should be a method on the `PeekableStream` trait? Or a
242
    // `PeekableStreamExt` trait?
243
    // TODO should there be peek_mut ?
244
    #[allow(dead_code)] // TODO remove this allow if and when we make this module public
245
16
    pub fn peek(self: Pin<&mut Self>) -> PeekFuture<Self> {
246
16
        PeekFuture { peeker: Some(self) }
247
16
    }
248

            
249
    /// Return from a `poll_*` function, setting the stored waker appropriately
250
    ///
251
    /// Our `poll` functions always use this.
252
    /// The rule is that if a future returns `Pending`, it has stored the waker.
253
21132
    fn return_from_poll<R>(
254
21132
        poll_waker: &mut Option<Waker>,
255
21132
        cx: &mut Context<'_>,
256
21132
        r: Poll<R>,
257
21132
    ) -> Poll<R> {
258
21132
        *poll_waker = match &r {
259
            Ready(_) => {
260
                // No need to wake this task up any more.
261
20962
                None
262
            }
263
            Pending => {
264
                // try_peek must use the same waker to poll later
265
170
                Some(cx.waker().clone())
266
            }
267
        };
268
21132
        r
269
21132
    }
270

            
271
    /// Obtain a raw reference to the inner stream
272
    ///
273
    /// ### Correctness!
274
    ///
275
    /// This method must be used with care!
276
    /// Whatever you do mustn't interfere with polling and peeking.
277
    /// Careless use can result in wrong behaviour including deadlocks.
278
20
    pub fn as_raw_inner_pin_mut<'s>(self: Pin<&'s mut Self>) -> Option<Pin<&'s mut S>> {
279
20
        self.project().inner.as_pin_mut()
280
20
    }
281
}
282

            
283
impl<S: Stream> Stream for StreamUnobtrusivePeeker<S> {
284
    type Item = S::Item;
285

            
286
5426
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
287
5426
        self.impl_poll_next_or_peek(cx, |buffered| buffered.take())
288
5426
    }
289

            
290
    fn size_hint(&self) -> (usize, Option<usize>) {
291
        let buf = self.buffered.iter().count();
292
        let (imin, imax) = match &self.inner {
293
            Some(inner) => inner.size_hint(),
294
            None => (0, Some(0)),
295
        };
296
        (imin + buf, imax.and_then(|imap| imap.checked_add(buf)))
297
    }
298
}
299

            
300
impl<S: Stream> FusedStream for StreamUnobtrusivePeeker<S> {
    /// We are terminated once nothing is buffered and the inner stream has hit EOF
    fn is_terminated(&self) -> bool {
        matches!((&self.buffered, &self.inner), (None, None))
    }
}
305

            
306
/// Future from [`StreamUnobtrusivePeeker::peek`]
307
// TODO: Move to tor_async_utils::peekable_stream.
308
#[derive(Educe)]
309
#[educe(Debug(bound("S: Debug")))]
310
#[must_use = "peek() return a Future, which does nothing unless awaited"]
311
pub struct PeekFuture<'s, S> {
312
    /// The underlying stream.
313
    ///
314
    /// `Some` until we have returned `Ready`, then `None`.
315
    /// See comment in `poll`.
316
    peeker: Option<Pin<&'s mut S>>,
317
}
318

            
319
impl<'s, S: PeekableStream> PeekFuture<'s, S> {
320
    /// Create a new `PeekFuture`.
321
    // TODO: replace with a trait method.
322
    pub fn new(stream: Pin<&'s mut S>) -> Self {
323
        Self {
324
            peeker: Some(stream),
325
        }
326
    }
327
}
328

            
329
impl<'s, S: PeekableStream> Future for PeekFuture<'s, S> {
330
    type Output = Option<&'s S::Item>;
331
28
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<&'s S::Item>> {
332
28
        let self_ = self.get_mut();
333
28
        let peeker = self_
334
28
            .peeker
335
28
            .as_mut()
336
28
            .expect("PeekFuture polled after Ready");
337
28
        match peeker.as_mut().poll_peek(cx) {
338
12
            Pending => return Pending,
339
16
            Ready(_y) => {
340
16
                // Ideally we would have returned `y` here, but it's borrowed from PeekFuture
341
16
                // not from the original StreamUnobtrusivePeeker, and there's no way
342
16
                // to get a value with the right lifetime.  (In non-async code,
343
16
                // this is usually handled by the special magic for reborrowing &mut.)
344
16
                //
345
16
                // So we must redo the poll, but this time consuming `peeker`,
346
16
                // which gets us the right lifetime.  That's why it has to be `Option`.
347
16
                // Because we own &mut ... Self, we know that repeating the poll
348
16
                // gives the same answer.
349
16
            }
350
16
        }
351
16
        let peeker = self_.peeker.take().expect("it was Some before!");
352
16
        let r = peeker.poll_peek(cx);
353
16
        assert!(r.is_ready(), "it was Ready before!");
354
16
        r
355
28
    }
356
}
357

            
358
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use futures::channel::mpsc;
    use futures::{SinkExt as _, StreamExt as _};
    use std::pin::pin;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tor_rtcompat::SleepProvider as _;
    use tor_rtmock::MockRuntime;

    /// Shorthand for a [`Duration`] of `millis` milliseconds
    fn ms(millis: u64) -> Duration {
        Duration::from_millis(millis)
    }

    /// A reader alternating unobtrusive peeks with real reads must never miss a wakeup
    #[test]
    fn wakeups() {
        MockRuntime::test_with_various(|rt| async move {
            let (mut tx, rx) = mpsc::unbounded();
            let ended = Arc::new(Mutex::new(false));

            rt.spawn_identified("rxr", {
                let rt = rt.clone();
                let ended = Arc::clone(&ended);

                async move {
                    let mut rx = pin!(StreamUnobtrusivePeeker::new(rx));

                    // The value we expect to see next, whether peeked or read.
                    let mut next = 0;
                    loop {
                        rt.sleep(ms(50)).await;
                        eprintln!("rx peek... ");
                        let peeked = rx.as_mut().unobtrusive_peek_mut();
                        eprintln!("rx peeked {peeked:?}");

                        if let Some(peeked) = peeked {
                            assert_eq!(*peeked, next);
                        }

                        rt.sleep(ms(50)).await;
                        eprintln!("rx next... ");
                        let eaten = rx.next().await;
                        eprintln!("rx eaten {eaten:?}");
                        let Some(eaten) = eaten else {
                            break;
                        };
                        assert_eq!(eaten, next);
                        next += 1;
                    }

                    *ended.lock().unwrap() = true;
                    eprintln!("rx ended");
                }
            });

            rt.spawn_identified("tx", {
                let rt = rt.clone();

                async move {
                    // Each entry is how long to wait before sending the next number.
                    let waits = [125, 1, 125, 45, 1, 1, 1, 1000, 20, 1, 125, 125, 1000];
                    for (num, wait) in (0..).zip(waits) {
                        eprintln!("tx sleep {wait}");
                        rt.sleep(ms(wait)).await;
                        eprintln!("tx sending {num}");
                        tx.send(num).await.unwrap();
                    }

                    // This schedule arranges that, when we send EOF, the rx task
                    // has *peeked* rather than *polled* most recently,
                    // demonstrating that we can wake up the subsequent poll on EOF too.
                    eprintln!("tx final #1");
                    rt.sleep(ms(75)).await;
                    eprintln!("tx EOF");
                    drop(tx);
                    eprintln!("tx final #2");
                    rt.sleep(ms(10)).await;
                    assert!(!*ended.lock().unwrap());
                    eprintln!("tx final #3");
                    rt.sleep(ms(50)).await;
                    eprintln!("tx final #4");
                    assert!(*ended.lock().unwrap());
                }
            });

            rt.advance_until_stalled().await;
        });
    }

    /// Exercise the async `peek` path: each peeked value must match the value then read
    #[test]
    fn poll_peek_paths() {
        MockRuntime::test_with_various(|rt| async move {
            let (mut tx, rx) = mpsc::unbounded();
            let ended = Arc::new(Mutex::new(false));

            rt.spawn_identified("rxr", {
                let rt = rt.clone();
                let ended = Arc::clone(&ended);

                async move {
                    let mut rx = pin!(StreamUnobtrusivePeeker::new(rx));

                    while let Some(peeked) = rx.as_mut().peek().await.copied() {
                        eprintln!("rx peeked {peeked}");
                        let eaten = rx.next().await.unwrap();
                        eprintln!("rx eaten  {eaten}");
                        assert_eq!(peeked, eaten);
                        rt.sleep(ms(10)).await;
                        eprintln!("rx slept, peeking");
                    }
                    *ended.lock().unwrap() = true;
                    eprintln!("rx ended");
                }
            });

            rt.spawn_identified("tx", {
                let rt = rt.clone();

                async move {
                    let mut numbers = 0..;

                    // A macro, because we don't have proper async closures
                    macro_rules! send { {} => {
                        let num = numbers.next().unwrap();
                        eprintln!("tx send   {num}");
                        tx.send(num).await.unwrap();
                    } }

                    eprintln!("tx starting");
                    rt.sleep(ms(100)).await;
                    send!();
                    rt.sleep(ms(100)).await;
                    send!();
                    send!();
                    rt.sleep(ms(100)).await;
                    eprintln!("tx dropping");
                    drop(tx);
                    rt.sleep(ms(5)).await;
                    eprintln!("tx ending");
                    assert!(*ended.lock().unwrap());
                }
            });

            rt.advance_until_stalled().await;
        });
    }
}