1
//! [`SinkCloseChannel`]
2

            
3
use std::pin::Pin;
4

            
5
use futures::channel::mpsc;
6
use futures::Sink;
7

            
8
//---------- principal API ----------
9

            
10
/// A [`Sink`] with a `close_channel` method like [`futures::channel::mpsc::Sender`'s]
11
pub trait SinkCloseChannel<T>: Sink<T> {
12
    /// Close the channel from the sending end, giving EOF at the receiver
13
    ///
14
    /// Future attempts to send will get a disconnected error.
15
    ///
16
    /// This closes *all* equivalent senders for the underlying data sink.
17
    /// For example, if `Self` is `Clone`, all clones are affected.
18
    ///
19
    /// If the Sink is for a channel,
20
    /// the receiver will see EOF after reading the messages that were successfully sent so far.
21
    fn close_channel(self: Pin<&mut Self>);
22
}
23

            
24
//---------- impl for futures::channel::mpsc ----------
25

            
26
impl<T> SinkCloseChannel<T> for mpsc::Sender<T> {
27
8
    fn close_channel(self: Pin<&mut Self>) {
28
8
        let self_: &mut Self = Pin::into_inner(self);
29
8
        self_.close_channel();
30
8
    }
31
}
32

            
33
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    #![allow(clippy::arithmetic_side_effects)] // don't mind potential panicking ops in tests
    #![allow(clippy::useless_format)] // sorely

    use super::*;
    use futures::{SinkExt as _, StreamExt as _};

    /// Check that `SinkCloseChannel::close_channel` on an `mpsc::Sender`
    /// closes every clone of that sender, and that the receiver still gets the
    /// messages sent before the close, followed by EOF.
    #[test]
    fn close_channel() {
        tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
            let (mut tx, mut rx) = mpsc::channel::<i32>(20);
            tx.send(0).await.unwrap();
            let mut tx2 = tx.clone();
            tx2.send(1).await.unwrap();
            // Call through the trait explicitly.  Plain `tx2.close_channel()`
            // resolves to the inherent `mpsc::Sender::close_channel(&mut self)`,
            // which would leave the `SinkCloseChannel` impl untested.
            SinkCloseChannel::<i32>::close_channel(Pin::new(&mut tx2));
            // Closing through one clone must also close the other sender.
            let _: mpsc::SendError = tx.send(66).await.unwrap_err();
            // Messages sent before the close are still delivered, in order...
            for i in 0..=1 {
                assert_eq!(rx.next().await.unwrap(), i);
            }
            // ...and then the receiver sees EOF.
            assert_eq!(rx.next().await, None);
        });
    }
}