tor_async_utils/
sink_try_send.rs

1//! [`SinkTrySend`]
2
3use std::error::Error;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use futures::channel::mpsc;
8use futures::Sink;
9
10use derive_deftly::{define_derive_deftly, Deftly};
11use thiserror::Error;
12
13//---------- principal API ----------
14
15/// A [`Sink`] with a `try_send` method like [`futures::channel::mpsc::Sender`'s]
16pub trait SinkTrySend<T>: Sink<T> {
17    /// Errors that is not disconnected, or full
18    type Error: SinkTrySendError;
19
20    /// Try to send a message `msg`
21    ///
22    /// If this returns with an error indicating that the stream is full,
23    /// *No* arrangements will have been made for a wakeup when space becomes available.
24    ///
25    /// If the send fails, `item` is dropped.
26    /// If you need it back, use [`try_send_or_return`](SinkTrySend::try_send_or_return),
27    ///
28    /// (When implementing the trait, implement `try_send_or_return`, *not* this method.)
29    fn try_send(self: Pin<&mut Self>, item: T) -> Result<(), <Self as SinkTrySend<T>>::Error> {
30        self.try_send_or_return(item)
31            .map_err(|(error, _item)| error)
32    }
33
34    /// Try to send a message `msg`
35    ///
36    /// Like [`try_send`](SinkTrySend::try_send),
37    /// but if the send fails, the item is returned.
38    ///
39    /// (When implementing the trait, implement this method.)
40    fn try_send_or_return(
41        self: Pin<&mut Self>,
42        item: T,
43    ) -> Result<(), (<Self as SinkTrySend<T>>::Error, T)>;
44}
45
/// Error from [`SinkTrySend::try_send`]
///
/// Reports whether the failed send was due to the stream being full, or being
/// disconnected. Both predicates may return `false`, meaning the send failed
/// for some other reason.
///
/// See also [`ErasedSinkTrySendError`] which can often
/// be usefully used when an implementation of `SinkTrySendError` is needed.
pub trait SinkTrySendError: Error + 'static {
    /// Returns `true` if the send failed because the stream was full.
    ///
    /// *No* arrangements will have been made for a wakeup when space becomes available.
    ///
    /// Corresponds to [`futures::channel::mpsc::TrySendError::is_full`]
    fn is_full(&self) -> bool;

    /// Returns `true` if the send failed because the stream has disconnected.
    ///
    /// Corresponds to [`futures::channel::mpsc::TrySendError::is_disconnected`]
    fn is_disconnected(&self) -> bool;
}
63
64//---------- macrology - this has to come here, ideally all in one go ----------
65
#[rustfmt::skip] // rustfmt makes a complete hash of this
define_derive_deftly! {
    /// Implements various things which handle `full` and `disconnected`
    ///
    /// # Generates
    ///
    ///  * `SinkTrySendError` for `ErasedSinkTrySendError`
    ///  * `ErasedSinkTrySendError::from`, an inherent method accepting any
    ///    `SinkTrySendError` (not a `From` impl: that would conflict with the
    ///    identity `From<T> for T`)
    ///  * [`handle_mpsc_error`]
    ///
    /// Each of these expands to one check per variant marked `#[deftly(predicate)]`,
    /// using the predicate method named `is_` + the variant name in snake case.
    ///
    /// Use of macros avoids copypaste errors like
    /// `fn is_full(..) { self.is_disconnected() }`.
    ErasedSinkTrySendError expect items:

    // Condition: this variant is marked `#[deftly(predicate)]`.
    ${defcond PREDICATE vmeta(predicate)}
    // Expansion: the predicate method name for this variant (`Full` -> `is_full`).
    ${define PREDICATE { $<is_ ${snake_case $vname}> }}

    impl SinkTrySendError for ErasedSinkTrySendError {
        $(
            ${when PREDICATE}

            // `is_<variant>` is true exactly when `self` is that unit variant.
            fn $PREDICATE(&self) -> bool {
                matches!(self, $vtype)
            }
        )
    }

    impl ErasedSinkTrySendError {
        /// Obtain an `ErasedSinkTrySendError` from a concrete `SinkTrySendError`
        //
        // (Can't be a `From` impl because it conflicts with the identity `From<T> for T`.)
        //
        // Expands to an `if e.is_full() { .. } else if e.is_disconnected() { .. } else { .. }`
        // chain; errors that match no predicate are wrapped in `Other`.
        pub fn from<E>(e: E) -> ErasedSinkTrySendError
        where E: SinkTrySendError + Send + Sync
        {
            $(
                ${when PREDICATE}
                if e.$PREDICATE() {
                    $vtype
                } else
            )
                /* else */ {
                    let e = Arc::new(e);
                    // Avoid generating a nested ErasedSinkTrySendError.
                    // Is it *already* an ESTSE? If so it must be an `Other`,
                    // since every predicate variant was handled above.
                    //
                    // TODO replace this with a call to `downcast_value` from arti!2460
                    let e2 = e.clone();
                    match Arc::downcast(e2) {
                        Ok::<Arc<ErasedSinkTrySendError>, _>(y2) => {
                            // Dropping `e` leaves `y2` as the sole owner, so
                            // `into_inner` can move the value out.
                            drop(e); // Drop the original
                            let inner: ErasedSinkTrySendError =
                                Arc::into_inner(y2).expect(
              "somehow we weren't the only owner, despite us just having made an Arc!"
                                );
                            return inner;
                        }
                        Err(other_e2) => {
                            drop(other_e2);
                            // We need to use e, not other_e2, because Arc::downcast
                            // returns dyn Any, but `Other` needs dyn Error + Send + Sync.
                            ErasedSinkTrySendError::Other(e)
                        },
                    }
                }
        }
    }

    // Classify a failed `mpsc::Sender::try_send` using `TrySendError`'s own
    // predicates, returning the erased error together with the unsent item.
    fn handle_mpsc_error<T>(me: mpsc::TrySendError<T>) -> (ErasedSinkTrySendError, T) {
        let error = $(
            ${when PREDICATE}

            if me.$PREDICATE() {
                $vtype
            } else
        )
            /* else */ {
                // We can't keep both `me`'s inner `SendError` and its item
                // (`into_inner` consumes it), so report a fixed placeholder error.
                $ttype::Other(Arc::new(MpscOtherSinkTrySendError {}))
            };
        (error, me.into_inner())
    }
}
147
148//---------- helper - erased error ----------
149
/// Type-erased error for [`SinkTrySend::try_send`]
///
/// Provided for situations where providing a concrete error type is awkward.
///
/// `futures::channel::mpsc::Sender` wants this because when its `try_send` method fails,
/// it is not possible to extract both the unsent item and the error!
///
/// `tor_memquota::mq_queue::Sender` wants this because the types of the error return
/// from its `try_send` would otherwise be tainted by complex generics,
/// including its private `Entry` type.
///
/// Convert any other [`SinkTrySendError`] into this type with
/// [`ErasedSinkTrySendError::from`].
// Variants marked `#[deftly(predicate)]` are expanded by the
// `ErasedSinkTrySendError` derive-deftly template into the `is_*` methods of
// `SinkTrySendError`, and into the checks made by `from` and `handle_mpsc_error`.
#[derive(Debug, Error, Clone, Deftly)]
#[derive_deftly(ErasedSinkTrySendError)]
#[allow(clippy::exhaustive_enums)] // Adding other variants would be a breaking change anyway
pub enum ErasedSinkTrySendError {
    /// The stream was full.
    ///
    /// *No* arrangements will have been made for a wakeup when space becomes available.
    ///
    /// Corresponds to [`SinkTrySendError::is_full`]
    #[error("stream full (backpressure)")]
    #[deftly(predicate)]
    Full,

    /// The stream has disconnected
    ///
    /// Corresponds to [`SinkTrySendError::is_disconnected`]
    #[error("stream disconnected")]
    #[deftly(predicate)]
    Disconnected,

    /// Something else went wrong
    ///
    /// [`ErasedSinkTrySendError::from`] and the `mpsc::Sender` implementation
    /// produce this when the error was neither full nor disconnected.
    #[error("failed to convey data")]
    Other(#[source] Arc<dyn Error + Send + Sync + 'static>),
}
184
185//---------- impl for futures::channel::mpsc ----------
186
187/// [`mpsc::Sender::try_send`] returned an uncategorisable error
188///
189/// Both `.full()` and `.disconnected()` returned `false`.
190/// We could call [`mpsc::TrySendError::into_send_error`] but then we don't get the payload.
191/// In the future, we might replace this type with a type alias for [`mpsc::SendError`].
192///
193/// When returned from `<mpsc::Sender::SinkTrySend::try_send`,
194/// this is wrapped in [`ErasedSinkTrySendError::Other`].
195#[derive(Debug, Error)]
196#[error("mpsc::Sender::try_send returned an error which is neither .full() nor .disconnected()")]
197#[non_exhaustive]
198pub struct MpscOtherSinkTrySendError {}
199
200impl<T> SinkTrySend<T> for mpsc::Sender<T> {
201    // Ideally we would just use [`mpsc::SendError`].
202    // But `mpsc::TrySendError` lacks an `into_parts` method that gives both `SendError` and `T`.
203    type Error = ErasedSinkTrySendError;
204
205    fn try_send_or_return(
206        self: Pin<&mut Self>,
207        item: T,
208    ) -> Result<(), (ErasedSinkTrySendError, T)> {
209        let self_: &mut Self = Pin::into_inner(self);
210        mpsc::Sender::try_send(self_, item).map_err(handle_mpsc_error)
211    }
212}
213
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    #![allow(clippy::arithmetic_side_effects)] // don't mind potential panicking ops in tests
    #![allow(clippy::useless_format)] // srsly

    use super::*;
    use derive_deftly::derive_deftly_adhoc;
    use tor_error::ErrorReport as _;

    #[test]
    fn chk_erased_sink() {
        #[derive(Error, Clone, Debug, Deftly)]
        #[error("concrete {is_full} {is_disconnected}")]
        #[derive_deftly_adhoc]
        struct Concrete {
            is_full: bool,
            is_disconnected: bool,
        }

        derive_deftly_adhoc! {
            Concrete:

            impl SinkTrySendError for Concrete { $(
                fn $fname(&self) -> bool { self.$fname }
            ) }
        }

        for is_full in [false, true] {
            for is_disconnected in [false, true] {
                let c = Concrete {
                    is_full,
                    is_disconnected,
                };
                let e = ErasedSinkTrySendError::from(c.clone());
                let e2 = ErasedSinkTrySendError::from(e.clone());

                let cs = format!("concrete {is_full} {is_disconnected}");

                let es = if is_full {
                    format!("stream full (backpressure)")
                } else if is_disconnected {
                    format!("stream disconnected")
                } else {
                    format!("failed to convey data: {cs}")
                };

                assert_eq!(c.report().to_string(), format!("error: {cs}"));
                assert_eq!(e.report().to_string(), format!("error: {es}"));
                assert_eq!(e2.report().to_string(), format!("error: {es}"));
            }
        }
    }

    /// Check that `mpsc::Sender`'s `SinkTrySend` impl classifies errors
    /// correctly and hands back the unsent item.
    #[test]
    fn chk_mpsc_sender() {
        // With a buffer of zero, the sender's one guaranteed slot is used by the
        // first message, so a second one is rejected as full.
        let (mut tx, _rx) = mpsc::channel::<u32>(0);
        SinkTrySend::try_send(Pin::new(&mut tx), 1).unwrap();
        let (e, item) = SinkTrySend::try_send_or_return(Pin::new(&mut tx), 2).unwrap_err();
        assert!(e.is_full());
        assert!(!e.is_disconnected());
        assert_eq!(e.to_string(), "stream full (backpressure)");
        assert_eq!(item, 2);

        // Once the receiver has gone away, sending reports disconnection.
        let (mut tx, rx) = mpsc::channel::<u32>(10);
        drop(rx);
        let (e, item) = SinkTrySend::try_send_or_return(Pin::new(&mut tx), 3).unwrap_err();
        assert!(e.is_disconnected());
        assert!(!e.is_full());
        assert_eq!(e.to_string(), "stream disconnected");
        assert_eq!(item, 3);
    }
}