tor_async_utils/
sinkext.rs

1//! Extension trait for `Sink`.
2
3use std::{
4    marker::PhantomData,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use futures::{ready, sink::Sink};
10use pin_project::pin_project;
11
12/// Extension trait for `Sink`
13pub trait SinkExt<Item>: Sink<Item> {
14    /// As `Sink::with`, but takes a function that returns an `Item` rather
15    /// than `Future<Output=Item>`.
16    fn with_fn<F, T, E>(self, func: F) -> WithFn<Self, F, T, E>
17    // or error?
18    where
19        Self: Sized,
20        F: FnMut(T) -> Result<Item, E>,
21        E: From<Self::Error>;
22}
23
24impl<Item, S> SinkExt<Item> for S
25where
26    S: Sink<Item>,
27{
28    fn with_fn<F, T, E>(self, func: F) -> WithFn<Self, F, T, E>
29    where
30        Self: Sized,
31        F: FnMut(T) -> Result<Item, E>,
32        E: From<Self::Error>,
33    {
34        WithFn {
35            sink: self,
36            func,
37            _phantom: PhantomData,
38        }
39    }
40}
41
42/// Sink returned by [`SinkExt::with_fn`].
43#[pin_project]
44pub struct WithFn<S, F, T, E> {
45    /// The underlying sink
46    #[pin]
47    sink: S,
48    /// The user-provided function.
49    func: F,
50    /// Phantom data to ensure type consistency.
51    _phantom: PhantomData<fn() -> Result<T, E>>,
52}
53
54impl<S, Item, F, T, E> Sink<T> for WithFn<S, F, T, E>
55where
56    S: Sink<Item>,
57    F: FnMut(T) -> Result<Item, E>,
58    E: From<S::Error>,
59{
60    type Error = E;
61
62    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
63        ready!(self.project().sink.poll_ready(cx))?;
64        Poll::Ready(Ok(()))
65    }
66
67    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
68        ready!(self.project().sink.poll_flush(cx))?;
69        Poll::Ready(Ok(()))
70    }
71
72    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
73        ready!(self.project().sink.poll_close(cx))?;
74        Poll::Ready(Ok(()))
75    }
76
77    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
78        let this = self.project();
79        let item = (this.func)(item)?;
80        this.sink.start_send(item).map_err(E::from)
81    }
82}