1
//! [`SinkTrySend`]
2

            
3
use std::error::Error;
4
use std::pin::Pin;
5
use std::sync::Arc;
6

            
7
use futures::channel::mpsc;
8
use futures::Sink;
9

            
10
use derive_deftly::{define_derive_deftly, Deftly};
11
use thiserror::Error;
12

            
13
//---------- principal API ----------
14

            
15
/// A [`Sink`] with a `try_send` method like [`futures::channel::mpsc::Sender`'s]
16
pub trait SinkTrySend<T>: Sink<T> {
17
    /// Errors that is not disconnected, or full
18
    type Error: SinkTrySendError;
19

            
20
    /// Try to send a message `msg`
21
    ///
22
    /// If this returns with an error indicating that the stream is full,
23
    /// *No* arrangements will have been made for a wakeup when space becomes available.
24
    ///
25
    /// If the send fails, `item` is dropped.
26
    /// If you need it back, use [`try_send_or_return`](SinkTrySend::try_send_or_return),
27
    ///
28
    /// (When implementing the trait, implement `try_send_or_return`, *not* this method.)
29
112
    fn try_send(self: Pin<&mut Self>, item: T) -> Result<(), <Self as SinkTrySend<T>>::Error> {
30
112
        self.try_send_or_return(item)
31
112
            .map_err(|(error, _item)| error)
32
112
    }
33

            
34
    /// Try to send a message `msg`
35
    ///
36
    /// Like [`try_send`](SinkTrySend::try_send),
37
    /// but if the send fails, the item is returned.
38
    ///
39
    /// (When implementing the trait, implement this method.)
40
    fn try_send_or_return(
41
        self: Pin<&mut Self>,
42
        item: T,
43
    ) -> Result<(), (<Self as SinkTrySend<T>>::Error, T)>;
44
}
45

            
46
/// Error produced by a failed [`SinkTrySend::try_send`] (or `try_send_or_return`)
///
/// Two predicates classify the failure: the stream was full, or it had disconnected.
/// When neither predicate holds, the failure was of some other kind.
///
/// If writing a concrete implementation is inconvenient, [`ErasedSinkTrySendError`]
/// is often a suitable ready-made one.
pub trait SinkTrySendError: Error + 'static {
    /// Whether the send failed because the stream was full
    ///
    /// When this is `true`, *no* wakeup will have been arranged for when
    /// space becomes available.
    ///
    /// This is the analogue of [`futures::channel::mpsc::TrySendError::is_full`].
    fn is_full(&self) -> bool;

    /// Whether the send failed because the stream has disconnected
    ///
    /// This is the analogue of [`futures::channel::mpsc::TrySendError::is_disconnected`].
    fn is_disconnected(&self) -> bool;
}
63

            
64
//---------- macrology - this has to come here, ideally all in one go ----------
65

            
66
#[rustfmt::skip] // rustfmt makes a complete hash of this
67
define_derive_deftly! {
68
    /// Implements various things which handle `full` and `disconnected`
69
    ///
70
    /// # Generates
71
    ///
72
    ///  * `SinkTrySendError for`ErasedSinkTrySendError`
73
    ///  * `From<E: SinkTrySendError> for`ErasedSinkTrySendError`
74
    ///  * [`handle_mpsc_error`]
75
    ///
76
    /// Use of macros avoids copypaste errors like
77
    /// `fn is_full(..) { self.is_disconnected() }`.
78
    ErasedSinkTrySendError expect items:
79

            
80
    ${defcond PREDICATE vmeta(predicate)}
81
    ${define PREDICATE { $<is_ ${snake_case $vname}> }}
82

            
83
    impl SinkTrySendError for ErasedSinkTrySendError {
84
        $(
85
            ${when PREDICATE}
86

            
87
12
            fn $PREDICATE(&self) -> bool {
88
                matches!(self, $vtype)
89
            }
90
        )
91
    }
92

            
93
    impl ErasedSinkTrySendError {
94
        /// Obtain an `ErasedSinkTrySendError` from a concrete `SinkTrySendError`
95
        //
96
        // (Can't be a `From` impl because it conflicts with the identity `From<T> for T`.)
97
24
        pub fn from<E>(e: E) -> ErasedSinkTrySendError
98
24
        where E: SinkTrySendError + Send + Sync
99
24
        {
100
            $(
101
                ${when PREDICATE}
102
                if e.$PREDICATE() {
103
                    $vtype
104
                } else
105
            )
106
                /* else */ {
107
                    let e = Arc::new(e);
108
                    // Avoid generating a nested ErasedSinkTrySendError.
109
                    // Is it *already* an ESTSE (necessarily, then, an `Other`?)
110
                    //
111
                    // TODO replace this with a call to `downcast_value` from arti!2460
112
                    let e2 = e.clone();
113
                    match Arc::downcast(e2) {
114
                        Ok::<Arc<ErasedSinkTrySendError>, _>(y2) => {
115
                            drop(e); // Drop the original
116
                            let inner: ErasedSinkTrySendError =
117
                                Arc::into_inner(y2).expect(
118
              "somehow we weren't the only owner, despite us just having made an Arc!"
119
                                );
120
                            return inner;
121
                        }
122
                        Err(other_e2) => {
123
                            drop(other_e2);
124
                            // We need to use e, not other_e2, because Arc::downcast
125
                            // returns dyn Any but we need dyn SinkTrySendError.
126
                            ErasedSinkTrySendError::Other(e)
127
                        },
128
                    }
129
                }
130
        }
131
    }
132

            
133
    fn handle_mpsc_error<T>(me: mpsc::TrySendError<T>) -> (ErasedSinkTrySendError, T) {
134
        let error = $(
135
            ${when PREDICATE}
136

            
137
            if me.$PREDICATE() {
138
                $vtype
139
            } else
140
        )
141
            /* else */ {
142
                $ttype::Other(Arc::new(MpscOtherSinkTrySendError {}))
143
            };
144
        (error, me.into_inner())
145
    }
146
}
147

            
148
//---------- helper - erased error ----------
149

            
150
/// Type-erased error for [`SinkTrySend::try_send`]
151
///
152
/// Provided for situations where providing a concrete error type is awkward.
153
///
154
/// `futures::channel::mpsc::Sender` wants this because when its `try_send` method fails,
155
/// it is not possible to extract both the sent item, and the error!
156
///
157
/// `tor_memquota::mq_queue::Sender` wants this because the types of the error return
158
/// from `its `try_send` would otherwise be tainted by complex generics,
159
/// including its private `Entry` type.
160
#[derive(Debug, Error, Clone, Deftly)]
161
#[derive_deftly(ErasedSinkTrySendError)]
162
#[allow(clippy::exhaustive_enums)] // Adding other variants would be a breaking change anyway
163
pub enum ErasedSinkTrySendError {
164
    /// The stream was full.
165
    ///
166
    /// *No* arrangements will have been made for a wakeup when space becomes available.
167
    ///
168
    /// Corresponds to [`SinkTrySendError::is_full`]
169
    #[error("stream full (backpressure)")]
170
    #[deftly(predicate)]
171
    Full,
172

            
173
    /// The stream has disconnected
174
    ///
175
    /// Corresponds to [`SinkTrySendError::is_disconnected`]
176
    #[error("stream disconnected")]
177
    #[deftly(predicate)]
178
    Disconnected,
179

            
180
    /// Something else went wrong
181
    #[error("failed to convey data")]
182
    Other(#[source] Arc<dyn Error + Send + Sync + 'static>),
183
}
184

            
185
//---------- impl for futures::channel::mpsc ----------
186

            
187
/// [`mpsc::Sender::try_send`] returned an uncategorisable error
188
///
189
/// Both `.full()` and `.disconnected()` returned `false`.
190
/// We could call [`mpsc::TrySendError::into_send_error`] but then we don't get the payload.
191
/// In the future, we might replace this type with a type alias for [`mpsc::SendError`].
192
///
193
/// When returned from `<mpsc::Sender::SinkTrySend::try_send`,
194
/// this is wrapped in [`ErasedSinkTrySendError::Other`].
195
#[derive(Debug, Error)]
196
#[error("mpsc::Sender::try_send returned an error which is neither .full() nor .disconnected()")]
197
#[non_exhaustive]
198
pub struct MpscOtherSinkTrySendError {}
199

            
200
impl<T> SinkTrySend<T> for mpsc::Sender<T> {
201
    // Ideally we would just use [`mpsc::SendError`].
202
    // But `mpsc::TrySendError` lacks an `into_parts` method that gives both `SendError` and `T`.
203
    type Error = ErasedSinkTrySendError;
204

            
205
112
    fn try_send_or_return(
206
112
        self: Pin<&mut Self>,
207
112
        item: T,
208
112
    ) -> Result<(), (ErasedSinkTrySendError, T)> {
209
112
        let self_: &mut Self = Pin::into_inner(self);
210
112
        mpsc::Sender::try_send(self_, item).map_err(handle_mpsc_error)
211
112
    }
212
}
213

            
214
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    #![allow(clippy::arithmetic_side_effects)] // don't mind potential panicking ops in tests
    #![allow(clippy::useless_format)] // srsly

    use super::*;
    use derive_deftly::derive_deftly_adhoc;
    use tor_error::ErrorReport as _;

    /// Check `ErasedSinkTrySendError::from` for every combination of predicates
    ///
    /// Covers the erased error's report (`Display` plus `source` chain), the
    /// macro-generated predicate methods, and that re-erasing an already-erased
    /// error does not nest it.
    #[test]
    fn chk_erased_sink() {
        /// Test error whose predicates simply return the corresponding fields
        #[derive(Error, Clone, Debug, Deftly)]
        #[error("concrete {is_full} {is_disconnected}")]
        #[derive_deftly_adhoc]
        struct Concrete {
            is_full: bool,
            is_disconnected: bool,
        }

        derive_deftly_adhoc! {
            Concrete:

            impl SinkTrySendError for Concrete { $(
                fn $fname(&self) -> bool { self.$fname }
            ) }
        }

        for is_full in [false, true] {
            for is_disconnected in [false, true] {
                let c = Concrete {
                    is_full,
                    is_disconnected,
                };
                let e = ErasedSinkTrySendError::from(c.clone());
                let e2 = ErasedSinkTrySendError::from(e.clone());

                let cs = format!("concrete {is_full} {is_disconnected}");

                let es = if is_full {
                    format!("stream full (backpressure)")
                } else if is_disconnected {
                    format!("stream disconnected")
                } else {
                    format!("failed to convey data: {cs}")
                };

                assert_eq!(c.report().to_string(), format!("error: {cs}"));
                assert_eq!(e.report().to_string(), format!("error: {es}"));
                assert_eq!(e2.report().to_string(), format!("error: {es}"));

                // Predicates are tried in variant order (`Full` first), so an error
                // claiming to be both full and disconnected is erased to `Full`.
                for erased in [&e, &e2] {
                    assert_eq!(erased.is_full(), is_full);
                    assert_eq!(erased.is_disconnected(), is_disconnected && !is_full);
                }
            }
        }
    }
}