tor_rtcompat/impls/
tokio.rs

//! Glue for running arti on the tokio runtime: re-exports and wrappers of
//! tokio's types, plus implementations of our runtime traits for them.
//!
//! This module helps define a slim API around our async runtime so that we
//! can easily swap it out.
5
6/// Types used for networking (tokio implementation)
7pub(crate) mod net {
8    use crate::{impls, traits};
9    use async_trait::async_trait;
10    use tor_general_addr::unix;
11
12    pub(crate) use tokio_crate::net::{
13        TcpListener as TokioTcpListener, TcpStream as TokioTcpStream, UdpSocket as TokioUdpSocket,
14    };
15    #[cfg(unix)]
16    pub(crate) use tokio_crate::net::{
17        UnixListener as TokioUnixListener, UnixStream as TokioUnixStream,
18    };
19
20    use futures::io::{AsyncRead, AsyncWrite};
21    use paste::paste;
22    use tokio_util::compat::{Compat, TokioAsyncReadCompatExt as _};
23
24    use std::io::Result as IoResult;
25    use std::net::SocketAddr;
26    use std::pin::Pin;
27    use std::task::{Context, Poll};
28
29    /// Provide a set of network stream wrappers for a single stream type.
30    macro_rules! stream_impl {
31        {
32            $kind:ident,
33            $addr:ty,
34            $cvt_addr:ident
35        } => {paste!{
36            /// Wrapper for Tokio's
37            #[doc = stringify!($kind)]
38            /// streams,
39            /// that implements the standard
40            /// AsyncRead and AsyncWrite.
41            pub struct [<$kind Stream>] {
42                /// Underlying tokio_util::compat::Compat wrapper.
43                s: Compat<[<Tokio $kind Stream>]>,
44            }
45            impl From<[<Tokio $kind Stream>]> for [<$kind Stream>] {
46                fn from(s: [<Tokio $kind Stream>]) ->  [<$kind Stream>] {
47                    let s = s.compat();
48                    [<$kind Stream>] { s }
49                }
50            }
51            impl AsyncRead for  [<$kind Stream>] {
52                fn poll_read(
53                    mut self: Pin<&mut Self>,
54                    cx: &mut Context<'_>,
55                    buf: &mut [u8],
56                ) -> Poll<IoResult<usize>> {
57                    Pin::new(&mut self.s).poll_read(cx, buf)
58                }
59            }
60            impl AsyncWrite for  [<$kind Stream>] {
61                fn poll_write(
62                    mut self: Pin<&mut Self>,
63                    cx: &mut Context<'_>,
64                    buf: &[u8],
65                ) -> Poll<IoResult<usize>> {
66                    Pin::new(&mut self.s).poll_write(cx, buf)
67                }
68                fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
69                    Pin::new(&mut self.s).poll_flush(cx)
70                }
71                fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
72                    Pin::new(&mut self.s).poll_close(cx)
73                }
74            }
75
76            /// Wrap a Tokio
77            #[doc = stringify!($kind)]
78            /// Listener to behave as a futures::io::TcpListener.
79            pub struct [<$kind Listener>] {
80                /// The underlying listener.
81                pub(super) lis: [<Tokio $kind Listener>],
82            }
83
84            /// Asynchronous stream that yields incoming connections from a
85            #[doc = stringify!($kind)]
86            /// Listener.
87            ///
88            /// This is analogous to async_std::net::Incoming.
89            pub struct [<Incoming $kind Streams>] {
90                /// Reference to the underlying listener.
91                pub(super) lis: [<Tokio $kind Listener>],
92            }
93
94            impl futures::stream::Stream for [<Incoming $kind Streams>] {
95                type Item = IoResult<([<$kind Stream>], $addr)>;
96
97                fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
98                    match self.lis.poll_accept(cx) {
99                        Poll::Ready(Ok((s, a))) => Poll::Ready(Some(Ok((s.into(), $cvt_addr(a)? )))),
100                        Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
101                        Poll::Pending => Poll::Pending,
102                    }
103                }
104            }
105            impl traits::NetStreamListener<$addr> for [<$kind Listener>] {
106                type Stream = [<$kind Stream>];
107                type Incoming = [<Incoming $kind Streams>];
108                fn incoming(self) -> Self::Incoming {
109                    [<Incoming $kind Streams>] { lis: self.lis }
110                }
111                fn local_addr(&self) -> IoResult<$addr> {
112                    $cvt_addr(self.lis.local_addr()?)
113                }
114            }
115        }}
116    }
117
118    /// Try to convert a tokio `unix::SocketAddr` into a crate::SocketAddr.
119    ///
120    /// Frustratingly, this information is _right there_: Tokio's SocketAddr has a
121    /// std::unix::net::SocketAddr internally, but there appears to be no way to get it out.
122    #[cfg(unix)]
123    #[allow(clippy::needless_pass_by_value)]
124    fn try_cvt_tokio_unix_addr(
125        addr: tokio_crate::net::unix::SocketAddr,
126    ) -> IoResult<unix::SocketAddr> {
127        if addr.is_unnamed() {
128            crate::unix::new_unnamed_socketaddr()
129        } else if let Some(p) = addr.as_pathname() {
130            unix::SocketAddr::from_pathname(p)
131        } else {
132            Err(crate::unix::UnsupportedAfUnixAddressType.into())
133        }
134    }
135
136    /// Wrapper for (not) converting std::net::SocketAddr to itself.
137    #[allow(clippy::unnecessary_wraps)]
138    fn identity_fn_socketaddr(addr: std::net::SocketAddr) -> IoResult<std::net::SocketAddr> {
139        Ok(addr)
140    }
141
142    stream_impl! { Tcp, std::net::SocketAddr, identity_fn_socketaddr }
143    #[cfg(unix)]
144    stream_impl! { Unix, unix::SocketAddr, try_cvt_tokio_unix_addr }
145
146    /// Wrap a Tokio UdpSocket
147    pub struct UdpSocket {
148        /// The underelying UdpSocket
149        socket: TokioUdpSocket,
150    }
151
152    impl UdpSocket {
153        /// Bind a UdpSocket
154        pub async fn bind(addr: SocketAddr) -> IoResult<Self> {
155            TokioUdpSocket::bind(addr)
156                .await
157                .map(|socket| UdpSocket { socket })
158        }
159    }
160
161    #[async_trait]
162    impl traits::UdpSocket for UdpSocket {
163        async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
164            self.socket.recv_from(buf).await
165        }
166
167        async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
168            self.socket.send_to(buf, target).await
169        }
170
171        fn local_addr(&self) -> IoResult<SocketAddr> {
172            self.socket.local_addr()
173        }
174    }
175
176    impl traits::StreamOps for TcpStream {
177        fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
178            impls::streamops::set_tcp_notsent_lowat(&self.s, notsent_lowat)
179        }
180
181        #[cfg(target_os = "linux")]
182        fn new_handle(&self) -> Box<dyn traits::StreamOps + Send + Unpin> {
183            Box::new(impls::streamops::TcpSockFd::from_fd(&self.s))
184        }
185    }
186
187    #[cfg(unix)]
188    impl traits::StreamOps for UnixStream {
189        fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
190            Err(traits::UnsupportedStreamOp::new(
191                "set_tcp_notsent_lowat",
192                "unsupported on Unix streams",
193            )
194            .into())
195        }
196    }
197}
198
199// ==============================
200
201use crate::traits::*;
202use async_trait::async_trait;
203use futures::Future;
204use std::io::Result as IoResult;
205use std::time::Duration;
206use tor_general_addr::unix;
207
208impl SleepProvider for TokioRuntimeHandle {
209    type SleepFuture = tokio_crate::time::Sleep;
210    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
211        tokio_crate::time::sleep(duration)
212    }
213}
214
215#[async_trait]
216impl crate::traits::NetStreamProvider for TokioRuntimeHandle {
217    type Stream = net::TcpStream;
218    type Listener = net::TcpListener;
219
220    async fn connect(&self, addr: &std::net::SocketAddr) -> IoResult<Self::Stream> {
221        let s = net::TokioTcpStream::connect(addr).await?;
222        Ok(s.into())
223    }
224    async fn listen(&self, addr: &std::net::SocketAddr) -> IoResult<Self::Listener> {
225        let lis = net::TokioTcpListener::bind(*addr).await?;
226        Ok(net::TcpListener { lis })
227    }
228}
229
230#[cfg(unix)]
231#[async_trait]
232impl crate::traits::NetStreamProvider<unix::SocketAddr> for TokioRuntimeHandle {
233    type Stream = net::UnixStream;
234    type Listener = net::UnixListener;
235
236    async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
237        let path = addr
238            .as_pathname()
239            .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
240        let s = net::TokioUnixStream::connect(path).await?;
241        Ok(s.into())
242    }
243    async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
244        let path = addr
245            .as_pathname()
246            .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
247        let lis = net::TokioUnixListener::bind(path)?;
248        Ok(net::UnixListener { lis })
249    }
250}
251
// On non-Unix platforms the `NetStreamProvider<unix::SocketAddr>` impl for
// this type is compiled out; the crate-level `impl_unix_non_provider!` macro
// is invoked for `TokioRuntimeHandle` instead.
// NOTE(review): presumably the macro supplies a stub AF_UNIX provider that
// rejects every address as unsupported; confirm against its definition.
#[cfg(not(unix))]
crate::impls::impl_unix_non_provider! { TokioRuntimeHandle }
254
255#[async_trait]
256impl crate::traits::UdpProvider for TokioRuntimeHandle {
257    type UdpSocket = net::UdpSocket;
258
259    async fn bind(&self, addr: &std::net::SocketAddr) -> IoResult<Self::UdpSocket> {
260        net::UdpSocket::bind(*addr).await
261    }
262}
263
264/// Create and return a new Tokio multithreaded runtime.
265pub(crate) fn create_runtime() -> IoResult<TokioRuntimeHandle> {
266    let runtime = async_executors::exec::TokioTp::new()
267        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
268    Ok(runtime.into())
269}
270
271/// Wrapper around a Handle to a tokio runtime.
272///
273/// Ideally, this type would go away, and we would just use
274/// `tokio::runtime::Handle` directly.  Unfortunately, we can't implement
275/// `futures::Spawn` on it ourselves because of Rust's orphan rules, so we need
276/// to define a new type here.
277///
278/// # Limitations
279///
280/// Note that Arti requires that the runtime should have working implementations
281/// for Tokio's time, net, and io facilities, but we have no good way to check
282/// that when creating this object.
283#[derive(Clone, Debug)]
284pub struct TokioRuntimeHandle {
285    /// If present, the tokio executor that we've created (and which we own).
286    ///
287    /// We never access this directly; only through `handle`.  We keep it here
288    /// so that our Runtime types can be agnostic about whether they own the
289    /// executor.
290    owned: Option<async_executors::TokioTp>,
291    /// The underlying Handle.
292    handle: tokio_crate::runtime::Handle,
293}
294
295impl TokioRuntimeHandle {
296    /// Wrap a tokio runtime handle into a format that Arti can use.
297    ///
298    /// # Limitations
299    ///
300    /// Note that Arti requires that the runtime should have working
301    /// implementations for Tokio's time, net, and io facilities, but we have
302    /// no good way to check that when creating this object.
303    pub(crate) fn new(handle: tokio_crate::runtime::Handle) -> Self {
304        handle.into()
305    }
306
307    /// Return true if this handle owns the executor that it points to.
308    pub fn is_owned(&self) -> bool {
309        self.owned.is_some()
310    }
311}
312
313impl From<tokio_crate::runtime::Handle> for TokioRuntimeHandle {
314    fn from(handle: tokio_crate::runtime::Handle) -> Self {
315        Self {
316            owned: None,
317            handle,
318        }
319    }
320}
321
322impl From<async_executors::TokioTp> for TokioRuntimeHandle {
323    fn from(owner: async_executors::TokioTp) -> TokioRuntimeHandle {
324        let handle = owner.block_on(async { tokio_crate::runtime::Handle::current() });
325        Self {
326            owned: Some(owner),
327            handle,
328        }
329    }
330}
331
332impl ToplevelBlockOn for TokioRuntimeHandle {
333    #[track_caller]
334    fn block_on<F: Future>(&self, f: F) -> F::Output {
335        self.handle.block_on(f)
336    }
337}
338
339impl Blocking for TokioRuntimeHandle {
340    type ThreadHandle<T: Send + 'static> = async_executors::BlockingHandle<T>;
341
342    #[track_caller]
343    fn spawn_blocking<F, T>(&self, f: F) -> async_executors::BlockingHandle<T>
344    where
345        F: FnOnce() -> T + Send + 'static,
346        T: Send + 'static,
347    {
348        async_executors::BlockingHandle::tokio(self.handle.spawn_blocking(f))
349    }
350
351    #[track_caller]
352    fn reenter_block_on<F: Future>(&self, future: F) -> F::Output {
353        self.handle.block_on(future)
354    }
355
356    #[track_caller]
357    fn blocking_io<F, T>(&self, f: F) -> impl Future<Output = T>
358    where
359        F: FnOnce() -> T + Send + 'static,
360        T: Send + 'static,
361    {
362        let r = tokio_crate::task::block_in_place(f);
363        std::future::ready(r)
364    }
365}
366
367impl futures::task::Spawn for TokioRuntimeHandle {
368    #[track_caller]
369    fn spawn_obj(
370        &self,
371        future: futures::task::FutureObj<'static, ()>,
372    ) -> Result<(), futures::task::SpawnError> {
373        let join_handle = self.handle.spawn(future);
374        drop(join_handle); // this makes the task detached.
375        Ok(())
376    }
377}