1
//! Extension trait for using [`Sink`] more safely.
2

            
3
use std::future::Future;
4
use std::marker::PhantomData;
5
use std::pin::Pin;
6
use std::task::{Context, Poll};
7

            
8
use futures::future::FusedFuture;
9
use futures::ready;
10
use futures::Sink;
11
use pin_project::pin_project;
12

            
13
/// Switch to the nontrivial version of this, to get debugging output on stderr
14
macro_rules! dprintln {
    // Accept the same shape as `eprintln!` (a format literal plus arbitrary trailing
    // tokens), but discard everything and evaluate to the unit value.
    ($f:literal $($a:tt)*) => {
        ()
    };
}
15
//macro_rules! dprintln { { $f:literal $($a:tt)* } => { eprintln!(concat!("    ",$f) $($a)*) } }
16

            
17
/// Extension trait for [`Sink`] to add a method for cancel-safe usage.
18
pub trait SinkPrepareExt<'w, OS, OM>
19
where
20
    OS: Sink<OM>,
21
{
22
    /// For processing an item obtained from a future, avoiding async cancel lossage
23
    ///
24
    /// ```
25
    /// # use futures::channel::mpsc;
26
    /// # use tor_async_utils::SinkPrepareExt as _;
27
    /// #
28
    /// # #[tokio::main]
29
    /// # async fn main() -> Result<(),mpsc::SendError> {
30
    /// #   let (mut sink, sink_r) = mpsc::unbounded::<usize>();
31
    /// #   let message_generator_future = futures::future::ready(42);
32
    /// #   let process_message = |m| Ok::<_,mpsc::SendError>(m);
33
    ///     let (message, sendable) = sink.prepare_send_from(
34
    ///         message_generator_future
35
    ///     ).await?;
36
    ///     let message = process_message(message)?;
37
    ///     sendable.send(message);
38
    /// #   Ok(())
39
    /// # }
40
    /// ```
41
    ///
42
    /// Prepares to send an output message[^terminology] `OM` to an output sink `OS` (`self`),
43
    /// where the `OM` is made from an input message `IM`,
44
    /// and the `IM` is obtained from a future, `generator: IF`.
45
    ///
46
    /// [^terminology]: We sometimes use slightly inconsistent terminology,
47
    /// "item" vs "message".
48
    /// This avoids having to have the generic parameters be named `OI` and `II`
49
    /// where `I` is sometimes "item" and sometimes "input".
50
    ///
51
    /// When successfully run, `prepare_send_from` gives `(IM, SinkSendable)`.
52
    ///
53
    /// After processing `IM` into `OM`,
54
    /// use the [`SinkSendable`] to [`send`](SinkSendable::send) the `OM` to `OS`.
55
    ///
56
    /// # Why use this
57
    ///
58
    /// This avoids an async cancellation hazard
59
    /// which exists with naive use of `select!`
60
    /// followed by `OS.send().await`.  You might write this:
61
    ///
62
    /// ```rust,ignore
63
    /// select!{
64
    ///     message = input_stream.next() => {
65
    ///         if let Some(message) = message {
66
    ///             let message = do_our_processing(message);
67
    ///             output_sink.send(message).await; // <---**BUG**
68
    ///         }
69
    ///     }
70
    ///     control = something_else() => { .. }
71
    /// }
72
    /// ```
73
    ///
74
    /// If, when we reach `BUG`, the output sink is not ready to receive the message,
75
    /// the future for that particular `select!` branch will be suspended.
76
    /// But when `select!` finds that *any one* of the branches is ready,
77
    /// it *drops* the futures for the other branches.
78
    /// That drops all the local variables, including possibly `message`, losing it.
79
    ///
80
    /// For more about cancellation safety, see
81
    /// [Rust for the Polyglot Programmer](https://www.chiark.greenend.org.uk/~ianmdlvl/rust-polyglot/async.html#cancellation-safety)
82
    /// which has a general summary, and
83
    /// Matthias Einwag's
84
    /// [extensive discussion in his gist](https://gist.github.com/Matthias247/ffc0f189742abf6aa41a226fe07398a8#cancellation-in-async-rust)
85
    /// with comparisons to other languages.
86
    ///
87
    /// ## Alternatives
88
    ///
89
    /// Unbounded mpsc channels, and certain other primitives,
90
    /// do not suffer from this problem because they do not block.
91
    /// `UnboundedSender` offers
92
    /// [`unbounded_send`](futures::channel::mpsc::UnboundedSender::unbounded_send)
93
    /// but only as an inherent method, so this does not compose with `Sink` combinators.
94
    /// And of course unbounded channels do not implement any backpressure.
95
    ///
96
    /// The problem can otherwise be avoided by completely eschewing use of `select!`
97
    /// and writing manual implementations of `Future`, `Sink`, and so on.
98
    /// However, such code is typically considerably more complex and involves
99
    /// entangling the primary logic with future machinery.
100
    /// It is normally better to write primary functionality in `async { }`
101
    /// using utilities (often "futures combinators") such as this one.
102
    ///
103
    // Personal note from @Diziet:
104
    // IMO it is generally accepted in the Rust community that
105
    // it is not good practice to write principal code at the manual futures level.
106
    // However, I have not been able to find very clear support for this proposition.
107
    // There are endless articles explaining how futures work internally,
108
    // often by describing how to reimplement standard combinators such as `map`.
109
    // ISTM that these exist to help understanding,
110
    // but it seems to be only rarely stated that doing this is not generally a good idea.
111
    //
112
    // I did find the following:
113
    //
114
    //  https://dev.to/mindflavor/rust-futures-an-uneducated-short-and-hopefully-not-boring-tutorial---part-4---a-real-future-from-scratch-734#conclusion
115
    //
116
    //    Of course you generally do not write a future manually. You use the ones provided by
117
    //    libraries and compose them as needed. It's important to understand how they work
118
    //    nevertheless.
119
    //
120
    // And of course the existence of the `futures` crate is indicative:
121
    // it consists almost entirely of combinators and utilities
122
    // whose purpose is to allow you to write many structures in async code
123
    // without needing to resort to manual future impls.
124
    //
125
    /// # Example
126
    ///
127
    /// This comprehensive example demonstrates how to read from possibly multiple sources
128
    /// and also be able to process other events:
129
    ///
130
    /// ```
131
    /// # #[tokio::main]
132
    /// # async fn main() {
133
    /// use futures::select;
134
    /// use futures::{SinkExt as _, StreamExt as _};
135
    /// use tor_async_utils::SinkPrepareExt as _;
136
    ///
137
    /// let (mut input_w, mut input_r) = futures::channel::mpsc::unbounded::<usize>();
138
    /// let (mut output_w, mut output_r) = futures::channel::mpsc::unbounded::<String>();
139
    /// input_w.send(42).await;
140
    /// select!{
141
    ///     ret = output_w.prepare_send_from(async {
142
    ///         select!{
143
    ///             got_input = input_r.next() => got_input.expect("input stream ended!"),
144
    ///             () = futures::future::pending() => panic!(), // other branches are OK here
145
    ///         }
146
    ///     }) => {
147
    ///         let (input_msg, sendable) = ret.unwrap();
148
    ///         let output_msg = input_msg.to_string();
149
    ///         let () = sendable.send(output_msg).unwrap();
150
    ///     },
151
    ///     () = futures::future::pending() => panic!(), // other branches are OK here
152
    /// }
153
    ///
154
    /// assert_eq!(output_r.next().await.unwrap(), "42");
155
    /// # }
156
    /// ```
157
    ///
158
    /// # Formally
159
    ///
160
    /// [`prepare_send_from`](SinkPrepareExt::prepare_send_from)
161
    /// returns a [`SinkPrepareSendFuture`] which, when awaited:
162
    ///
163
    ///  * Waits for `OS` to be ready to receive an item.
164
    ///  * Runs `message_generator` to obtain an `IM`.
165
    ///  * Returns the `IM` (for processing), and a [`SinkSendable`].
166
    ///
167
    /// The caller should then:
168
    ///
169
    ///  * Check the error from `prepare_send_from`
170
    ///    (which came from the *output* sink).
171
    ///  * Process the `IM`, making an `OM` out of it.
172
    ///  * Call [`sendable.send()`](SinkSendable::send) (and check its error).
173
    ///
174
    /// # Flushing
175
    ///
176
    /// `prepare_send_from` will (when awaited)
177
    /// [`flush`](futures::SinkExt::flush) the output sink
178
    /// when it finds the input is not ready yet.
179
    /// Until then items may be buffered
180
    /// (as if they had been written with [`feed`](futures::SinkExt::feed)).
181
    ///
182
    /// # Errors
183
    ///
184
    /// ## Output sink errors
185
    ///
186
    /// The call site can experience output sink errors in two places,
187
    /// [`prepare_send_from()`](SinkPrepareExt::prepare_send_from) and [`SinkSendable::send()`].
188
    /// The caller should typically handle them the same way regardless of when they occurred.
189
    ///
190
    /// If the error happens at [`SinkSendable::send()`],
191
    /// the call site will usually be forced to discard the item being processed.
192
    /// This will only occur if the sink is actually broken.
193
    ///
194
    /// ## Errors specific to the call site: fallible input, and fallible processing
195
    ///
196
    /// At some call sites, the input future may yield errors
197
    /// (perhaps it is reading from a `Stream` of [`Result`]s).
198
    /// In that case the value from the input future will be a [`Result`].
199
    /// Then `IM` is a `Result`, and is provided in the `.0` element
200
    /// of the "successful" return from `prepare_send_from`.
201
    ///
202
    /// And, at some call sites, the processing of an `IM` into an `OM` is fallible.
203
    ///
204
    /// Handling these latter two error cases is up to the caller,
205
    /// in the code which processes `IM`.
206
    /// The call site will often want to deal with such an error
207
    /// without sending anything into the output sink,
208
    /// and can then just drop the [`SinkSendable`].
209
    ///
210
    /// # Implementations
211
    ///
212
    /// This is an extension trait and you are not expected to need to implement it.
213
    ///
214
    /// There are provided implementations for `Pin<&mut impl Sink>`
215
    /// and `&mut impl Sink + Unpin`, for your convenience.
216
    fn prepare_send_from<IF, IM>(
217
        self,
218
        message_generator: IF,
219
    ) -> SinkPrepareSendFuture<'w, IF, OS, OM>
220
    where
221
        IF: Future<Output = IM>;
222
}
223

            
224
impl<'w, OS, OM> SinkPrepareExt<'w, OS, OM> for Pin<&'w mut OS>
225
where
226
    OS: Sink<OM>,
227
{
228
5762
    fn prepare_send_from<'r, IF, IM>(
229
5762
        self,
230
5762
        message_generator: IF,
231
5762
    ) -> SinkPrepareSendFuture<'w, IF, OS, OM>
232
5762
    where
233
5762
        IF: Future<Output = IM>,
234
5762
    {
235
5762
        SinkPrepareSendFuture {
236
5762
            output: Some(self),
237
5762
            generator: message_generator,
238
5762
            tw: PhantomData,
239
5762
        }
240
5762
    }
241
}
242

            
243
impl<'w, OS, OM> SinkPrepareExt<'w, OS, OM> for &'w mut OS
244
where
245
    OS: Sink<OM> + Unpin,
246
{
247
5762
    fn prepare_send_from<'r, IF, IM>(
248
5762
        self,
249
5762
        message_generator: IF,
250
5762
    ) -> SinkPrepareSendFuture<'w, IF, OS, OM>
251
5762
    where
252
5762
        IF: Future<Output = IM>,
253
5762
    {
254
5762
        Pin::new(self).prepare_send_from(message_generator)
255
5762
    }
256
}
257

            
258
/// Future for `SinkPrepareExt::prepare_send_from`
///
/// When polled, it first waits for the output sink to report readiness and then
/// polls the generator.  It resolves to `Ok((IM, SinkSendable))` once the generator
/// has yielded, or to `Err` if the output sink reports an error
/// (from its readiness check, or from the flush performed while the generator is pending).
259
6674
#[pin_project]
260
#[must_use]
261
pub struct SinkPrepareSendFuture<'w, IF, OS, OM> {
262
    /// Underlying future that will yield a message.
263
    #[pin]
264
    generator: IF,
265

            
266
    /// This Option exists because otherwise SinkPrepareSendFuture::poll()
267
    /// can't move `output` out of this struct to put it into the `SinkSendable`.
268
    /// (The poll() impl cannot borrow from SinkPrepareSendFuture.)
    ///
    /// It is `None` only after the future has returned `Ready(Ok(..))`;
    /// that is also the condition reported by `is_terminated`.
269
    output: Option<Pin<&'w mut OS>>,
270

            
271
    /// `fn(OM)` gives contravariance in OM.
272
    ///
273
    /// Variance is confusing.
274
    /// Loosely, a SinkPrepareSendFuture<..OM> consumes an OM.
275
    /// Actually, we don't really need to add any variance restrictions wrt OM,
276
    /// because the &mut OS already implies the correct variance,
277
    /// so we could have used the PhantomData<fn(*const OM)> trick.
278
    /// Happily there is no unsafe anywhere nearby, so it is not possible for us to write
279
    /// a bug due to getting the variance wrong - only to erroneously prevent some use
280
    /// case.
281
    tw: PhantomData<fn(OM)>,
282
}
283

            
284
/// A [`Sink`] which is ready to receive an item
285
///
286
/// Produced by [`SinkPrepareExt::prepare_send_from`].  See there for the overview docs.
287
///
288
/// This references an output sink `OS`.
289
/// It offers the ability to write into the sink without blocking,
290
/// (and constitutes a proof token that the sink has declared itself ready for that).
291
///
292
/// The only useful method is [`send`](SinkSendable::send).
293
///
294
/// `SinkSendable` has no drop glue and can be freely dropped,
295
/// for example if you prepare to send a message and then
296
/// encounter an error when producing the output message.
297
#[must_use]
298
pub struct SinkSendable<'w, OS, OM> {
299
    /// Reference to underlying output sink.
    ///
    /// Moved here out of the `SinkPrepareSendFuture` when that future completes
    /// successfully, and consumed by [`send`](SinkSendable::send).
300
    output: Pin<&'w mut OS>,
301
    /// Marker to ensure that `OM` is used.
    ///
    /// Same `PhantomData<fn(OM)>` marker type as the `tw` field of `SinkPrepareSendFuture`.
302
    tw: PhantomData<fn(OM)>,
303
}
304

            
305
impl<'w, IF, OS, IM, OM> Future for SinkPrepareSendFuture<'w, IF, OS, OM>
306
where
307
    IF: Future<Output = IM>,
308
    OS: Sink<OM>,
309
{
310
    type Output = Result<(IM, SinkSendable<'w, OS, OM>), OS::Error>;
311

            
312
6674
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
313
6674
        let mut self_ = self.project();
314

            
315
        /// returns `&mut Pin<&'w mut OS>` from self_.output
316
        //
317
        // macro because the closure's type parameters would be unnameable.
318
        macro_rules! get_output {
319
            ($self_:expr) => {
320
                $self_.output.as_mut().expect(BAD_POLL_MSG).as_mut()
321
            };
322
        }
323
        /// Message to give when panicking because of improper extra poll.
324
        const BAD_POLL_MSG: &str =
325
            "future from SinkPrepareExt::prepare_send_from (SinkPrepareSendFuture) \
326
                 polled after returning Ready(Ok)";
327

            
328
6674
        let () = match ready!(get_output!(self_).poll_ready(cx)) {
329
52
            Err(e) => {
330
52
                dprintln!("poll: output poll = IF.Err    SO  IF.Err");
331
52
                // Deliberately don't fuse by `take`ing output.  If we did that, we would expose
332
52
                // our caller to an additional panic risk.  There is no harm in polling the output
333
52
                // sink again: although `Sink` documents that a sink that returns errors will
334
52
                // probably continue to do so, it is not forbidden to try it and see.  This is in
335
52
                // any case better than definitely crashing if the `SinkPrepareSendFuture` is
336
52
                // polled after it gave Ready.
337
52
                return Poll::Ready(Err(e));
338
            }
339
6604
            Ok(()) => {
340
6604
                dprintln!("poll: output poll = IF.Ok     calling generator");
341
6604
            }
342
        };
343

            
344
6604
        let value = match self_.generator.as_mut().poll(cx) {
345
            Poll::Pending => {
346
                // We defer flushing the output until the input stops yielding.
347
                // This allows our caller (which is typically a loop) to transfer multiple
348
                // items from their input to their output between flushes.
349
                //
350
                // But we must not return `Pending` without flushing, or the caller could block
351
                // without flushing output, leading to untimely delivery of buffered data.
352
1342
                dprintln!("poll: generator = Pending     calling output flush");
353
1342
                let flushed = get_output!(self_).poll_flush(cx);
354
1340
                return match flushed {
355
2
                    Poll::Ready(Err(e)) => {
356
2
                        dprintln!("poll: output flush = IF.Err   SO  IF.Err");
357
2
                        Poll::Ready(Err(e))
358
                    }
359
                    Poll::Ready(Ok(())) => {
360
1338
                        dprintln!("poll: output flush = IF.Ok    SO  Pending");
361
1338
                        Poll::Pending
362
                    }
363
                    Poll::Pending => {
364
2
                        dprintln!("poll: output flush = Pending  SO  Pending");
365
2
                        Poll::Pending
366
                    }
367
                };
368
            }
369
5262
            Poll::Ready(v) => {
370
5262
                dprintln!("poll: generator = Ready       SO  IF.Ok");
371
5262
                v
372
5262
            }
373
5262
        };
374
5262

            
375
5262
        let sendable = SinkSendable {
376
5262
            output: self_.output.take().expect(BAD_POLL_MSG),
377
5262
            tw: PhantomData,
378
5262
        };
379
5262

            
380
5262
        Poll::Ready(Ok((value, sendable)))
381
6674
    }
382
}
383

            
384
impl<'w, IF, OS, IM, OM> FusedFuture for SinkPrepareSendFuture<'w, IF, OS, OM>
385
where
386
    IF: Future<Output = IM>,
387
    OS: Sink<OM>,
388
{
389
6666
    fn is_terminated(&self) -> bool {
390
6666
        let r = self.output.is_none();
391
6666
        dprintln!("is_terminated = {}", r);
392
6666
        r
393
6666
    }
394
}
395

            
396
impl<'w, OS, OM> SinkSendable<'w, OS, OM>
397
where
398
    OS: Sink<OM>,
399
{
400
    /// Synchronously send an item into `OS`, which is a [`Sink`]
401
    ///
402
    /// Can fail if the sink `OS` reports an error.
403
    ///
404
    /// (However, the existence of the `SinkSendable` demonstrates that
405
    /// the sink reported itself ready for sending,
406
    /// so this call is synchronous, avoiding cancellation hazards.)
407
5258
    pub fn send(self, item: OM) -> Result<(), OS::Error> {
408
5258
        dprintln!("send ...");
409
5258
        let r = self.output.start_send(item);
410
5258
        dprintln!("send: {:?}", r.as_ref().map_err(|_| (())));
411
5258
        r
412
5258
    }
413
}
414

            
415
#[cfg(test)]
416
mod test {
417
    // @@ begin test lint list maintained by maint/add_warning @@
418
    #![allow(clippy::bool_assert_comparison)]
419
    #![allow(clippy::clone_on_copy)]
420
    #![allow(clippy::dbg_macro)]
421
    #![allow(clippy::mixed_attributes_style)]
422
    #![allow(clippy::print_stderr)]
423
    #![allow(clippy::print_stdout)]
424
    #![allow(clippy::single_char_pattern)]
425
    #![allow(clippy::unwrap_used)]
426
    #![allow(clippy::unchecked_duration_subtraction)]
427
    #![allow(clippy::useless_vec)]
428
    #![allow(clippy::needless_pass_by_value)]
429
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
430

            
431
    use super::*;
432
    use futures::channel::mpsc;
433
    use futures::future::poll_fn;
434
    use futures::select_biased;
435
    use futures::SinkExt as _;
436
    use futures_await_test::async_test;
437
    use std::convert::Infallible;
438
    use std::sync::Arc;
439
    use std::sync::Mutex;
440

            
441
    #[async_test]
442
    async fn prepare_send() {
443
        // Early versions of this used unfold quite a lot more, but it is not really
444
        // convenient for testing.  It buffers one item internally, and is also buggy:
445
        //   https://github.com/rust-lang/futures-rs/issues/2600
446
        // So we use mpsc channels, which (perhaps with buffering) are quite controllable.
447

            
448
        // The eprintln!("FOR ...") calls correspond to the dprintln!() calls in the impl,
449
        // and can check that each code path in the implementation is used,
450
        // by turning on the debug output and using `--nocapture`.
451
        {
452
            eprintln!("-- disconnected ---");
453
            eprintln!("FOR poll: output poll = IF.Err    SO  IF.Err");
454
            let (mut w, r) = mpsc::unbounded::<usize>();
455
            drop(r);
456
            let ret = w.prepare_send_from(async { Ok::<_, Infallible>(12) }).await;
457
            assert!(ret.map(|_| ()).unwrap_err().is_disconnected());
458
        }
459

            
460
        {
461
            eprintln!("-- buffered late disconnect --");
462
            eprintln!("FOR poll: output poll = IF.Ok     calling generator");
463
            eprintln!("FOR poll: output flush = IF.Err   SO  IF.Err");
464
            let (w, r) = mpsc::unbounded::<usize>();
465
            let mut w = w.buffer(10);
466
            let mut r = Some(r);
467
            w.feed(66).await.unwrap();
468
            let ret = w
469
                .prepare_send_from(poll_fn(move |_cx| {
470
                    drop(r.take());
471
                    Poll::Pending::<usize>
472
                }))
473
                .await;
474
            assert!(ret.map(|_| ()).unwrap_err().is_disconnected());
475
        }
476

            
477
        {
478
            eprintln!("-- flushing before wait --");
479
            eprintln!("FOR poll: output flush = IF.Ok    SO  Pending");
480
            let (mut w, _r) = mpsc::unbounded::<usize>();
481
            let () = select_biased! {
482
                _ = w.prepare_send_from(poll_fn(
483
                    move |_cx| {
484
                        Poll::Pending::<usize>
485
                    }
486
                )) => panic!(),
487
                _ = futures::future::ready(()) => { },
488
            };
489
        }
490

            
491
        {
492
            eprintln!("-- flush before wait is pending --");
493
            eprintln!("FOR poll: output flush = Pending  SO  Pending");
494
            let (mut w, _r) = mpsc::channel::<usize>(0);
495
            let () = w.feed(77).await.unwrap();
496
            let mut w = w.buffer(10);
497
            let () = select_biased! {
498
                _ = w.prepare_send_from(poll_fn(
499
                    move |_cx| {
500
                        Poll::Pending::<usize>
501
                    }
502
                )) => panic!(),
503
                _ = futures::future::ready(()) => { },
504
            };
505
        }
506

            
507
        {
508
            eprintln!("-- flush before wait is pending --");
509
            eprintln!("FOR poll: generator = Ready       SO  IF.Ok");
510
            eprintln!("FOR send ...");
511
            eprintln!("ALSO check that bufferinrg works as expected");
512

            
513
            let sunk = Arc::new(Mutex::new(vec![]));
514
            let unfold = futures::sink::unfold((), |(), v| {
515
                let sunk = sunk.clone();
516
                async move {
517
                    dbg!();
518
                    sunk.lock().unwrap().push(v);
519
                    Ok::<_, Infallible>(())
520
                }
521
            });
522
            let mut unfold = Box::pin(unfold.buffer(10));
523
            for v in [42, 43] {
524
                // We can only do two here because that's how many we can actually buffer in Buffer
525
                // and Unfold.  Because our closure is always ready, the buffering isn't actually
526
                // as copious as all that.  This is fine, because the point of this test is to test
527
                // *flushing*.
528
                dbg!(v);
529
                let ret = unfold
530
                    .prepare_send_from(async move { Ok::<_, Infallible>(v) })
531
                    .await;
532
                let (msg, sendable) = ret.unwrap();
533
                let msg = msg.unwrap();
534
                assert_eq!(msg, v);
535
                let () = sendable.send(msg).unwrap();
536
                assert_eq!(*sunk.lock().unwrap(), &[]); // It's still buffered
537
            }
538
            select_biased! {
539
                _ = unfold.prepare_send_from(futures::future::pending::<()>()) => panic!(),
540
                _ = futures::future::ready(()) => { },
541
            };
542
            assert_eq!(*sunk.lock().unwrap(), &[42, 43]);
543
        }
544
    }
545
}