tor_rtcompat/impls/
tokio.rs

//! Re-exports of, and wrappers around, the tokio runtime for use with arti.
//!
//! This module helps define a slim API around our async runtime so that we
//! can easily swap it out.
5
6/// Types used for networking (tokio implementation)
7pub(crate) mod net {
8    use crate::{impls, traits};
9    use async_trait::async_trait;
10    use tor_general_addr::unix;
11
12    pub(crate) use tokio_crate::net::{
13        TcpListener as TokioTcpListener, TcpStream as TokioTcpStream, UdpSocket as TokioUdpSocket,
14    };
15    #[cfg(unix)]
16    pub(crate) use tokio_crate::net::{
17        UnixListener as TokioUnixListener, UnixStream as TokioUnixStream,
18    };
19
20    use futures::io::{AsyncRead, AsyncWrite};
21    use paste::paste;
22    use tokio_util::compat::{Compat, TokioAsyncReadCompatExt as _};
23
24    use std::io::Result as IoResult;
25    use std::net::SocketAddr;
26    use std::pin::Pin;
27    use std::task::{Context, Poll};
28
29    /// Provide a set of network stream wrappers for a single stream type.
30    macro_rules! stream_impl {
31        {
32            $kind:ident,
33            $addr:ty,
34            $cvt_addr:ident
35        } => {paste!{
36            /// Wrapper for Tokio's
37            #[doc = stringify!($kind)]
38            /// streams,
39            /// that implements the standard
40            /// AsyncRead and AsyncWrite.
41            pub struct [<$kind Stream>] {
42                /// Underlying tokio_util::compat::Compat wrapper.
43                s: Compat<[<Tokio $kind Stream>]>,
44            }
45            impl From<[<Tokio $kind Stream>]> for [<$kind Stream>] {
46                fn from(s: [<Tokio $kind Stream>]) ->  [<$kind Stream>] {
47                    let s = s.compat();
48                    [<$kind Stream>] { s }
49                }
50            }
51            impl AsyncRead for  [<$kind Stream>] {
52                fn poll_read(
53                    mut self: Pin<&mut Self>,
54                    cx: &mut Context<'_>,
55                    buf: &mut [u8],
56                ) -> Poll<IoResult<usize>> {
57                    Pin::new(&mut self.s).poll_read(cx, buf)
58                }
59            }
60            impl AsyncWrite for  [<$kind Stream>] {
61                fn poll_write(
62                    mut self: Pin<&mut Self>,
63                    cx: &mut Context<'_>,
64                    buf: &[u8],
65                ) -> Poll<IoResult<usize>> {
66                    Pin::new(&mut self.s).poll_write(cx, buf)
67                }
68                fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
69                    Pin::new(&mut self.s).poll_flush(cx)
70                }
71                fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
72                    Pin::new(&mut self.s).poll_close(cx)
73                }
74            }
75
76            /// Wrap a Tokio
77            #[doc = stringify!($kind)]
78            /// Listener to behave as a futures::io::TcpListener.
79            pub struct [<$kind Listener>] {
80                /// The underlying listener.
81                pub(super) lis: [<Tokio $kind Listener>],
82            }
83
84            /// Asynchronous stream that yields incoming connections from a
85            #[doc = stringify!($kind)]
86            /// Listener.
87            ///
88            /// This is analogous to async_std::net::Incoming.
89            pub struct [<Incoming $kind Streams>] {
90                /// Reference to the underlying listener.
91                pub(super) lis: [<Tokio $kind Listener>],
92            }
93
94            impl futures::stream::Stream for [<Incoming $kind Streams>] {
95                type Item = IoResult<([<$kind Stream>], $addr)>;
96
97                fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
98                    match self.lis.poll_accept(cx) {
99                        Poll::Ready(Ok((s, a))) => Poll::Ready(Some(Ok((s.into(), $cvt_addr(a)? )))),
100                        Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
101                        Poll::Pending => Poll::Pending,
102                    }
103                }
104            }
105            impl traits::NetStreamListener<$addr> for [<$kind Listener>] {
106                type Stream = [<$kind Stream>];
107                type Incoming = [<Incoming $kind Streams>];
108                fn incoming(self) -> Self::Incoming {
109                    [<Incoming $kind Streams>] { lis: self.lis }
110                }
111                fn local_addr(&self) -> IoResult<$addr> {
112                    $cvt_addr(self.lis.local_addr()?)
113                }
114            }
115        }}
116    }
117
118    /// Try to convert a tokio `unix::SocketAddr` into a crate::SocketAddr.
119    ///
120    /// Frustratingly, this information is _right there_: Tokio's SocketAddr has a
121    /// std::unix::net::SocketAddr internally, but there appears to be no way to get it out.
122    #[cfg(unix)]
123    #[allow(clippy::needless_pass_by_value)]
124    fn try_cvt_tokio_unix_addr(
125        addr: tokio_crate::net::unix::SocketAddr,
126    ) -> IoResult<unix::SocketAddr> {
127        if addr.is_unnamed() {
128            crate::unix::new_unnamed_socketaddr()
129        } else if let Some(p) = addr.as_pathname() {
130            unix::SocketAddr::from_pathname(p)
131        } else {
132            Err(crate::unix::UnsupportedAfUnixAddressType.into())
133        }
134    }
135
136    /// Wrapper for (not) converting std::net::SocketAddr to itself.
137    #[allow(clippy::unnecessary_wraps)]
138    fn identity_fn_socketaddr(addr: std::net::SocketAddr) -> IoResult<std::net::SocketAddr> {
139        Ok(addr)
140    }
141
142    stream_impl! { Tcp, std::net::SocketAddr, identity_fn_socketaddr }
143    #[cfg(unix)]
144    stream_impl! { Unix, unix::SocketAddr, try_cvt_tokio_unix_addr }
145
146    /// Wrap a Tokio UdpSocket
147    pub struct UdpSocket {
148        /// The underelying UdpSocket
149        socket: TokioUdpSocket,
150    }
151
152    impl UdpSocket {
153        /// Bind a UdpSocket
154        pub async fn bind(addr: SocketAddr) -> IoResult<Self> {
155            TokioUdpSocket::bind(addr)
156                .await
157                .map(|socket| UdpSocket { socket })
158        }
159    }
160
161    #[async_trait]
162    impl traits::UdpSocket for UdpSocket {
163        async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
164            self.socket.recv_from(buf).await
165        }
166
167        async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
168            self.socket.send_to(buf, target).await
169        }
170
171        fn local_addr(&self) -> IoResult<SocketAddr> {
172            self.socket.local_addr()
173        }
174    }
175
176    impl traits::StreamOps for TcpStream {
177        fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
178            impls::streamops::set_tcp_notsent_lowat(&self.s, notsent_lowat)
179        }
180
181        #[cfg(target_os = "linux")]
182        fn new_handle(&self) -> Box<dyn traits::StreamOps + Send + Unpin> {
183            Box::new(impls::streamops::TcpSockFd::from_fd(&self.s))
184        }
185    }
186
187    #[cfg(unix)]
188    impl traits::StreamOps for UnixStream {
189        fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
190            Err(traits::UnsupportedStreamOp::new(
191                "set_tcp_notsent_lowat",
192                "unsupported on Unix streams",
193            )
194            .into())
195        }
196    }
197}
198
199// ==============================
200
201use crate::traits::*;
202use async_trait::async_trait;
203use futures::Future;
204use std::io::Result as IoResult;
205use std::time::Duration;
206use tor_general_addr::unix;
207
208impl SleepProvider for TokioRuntimeHandle {
209    type SleepFuture = tokio_crate::time::Sleep;
210    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
211        tokio_crate::time::sleep(duration)
212    }
213}
214
215#[async_trait]
216impl crate::traits::NetStreamProvider for TokioRuntimeHandle {
217    type Stream = net::TcpStream;
218    type Listener = net::TcpListener;
219
220    async fn connect(&self, addr: &std::net::SocketAddr) -> IoResult<Self::Stream> {
221        let s = net::TokioTcpStream::connect(addr).await?;
222        Ok(s.into())
223    }
224    async fn listen(&self, addr: &std::net::SocketAddr) -> IoResult<Self::Listener> {
225        let lis = net::TokioTcpListener::bind(*addr).await?;
226        Ok(net::TcpListener { lis })
227    }
228}
229
230#[cfg(unix)]
231#[async_trait]
232impl crate::traits::NetStreamProvider<unix::SocketAddr> for TokioRuntimeHandle {
233    type Stream = net::UnixStream;
234    type Listener = net::UnixListener;
235
236    async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
237        let path = addr
238            .as_pathname()
239            .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
240        let s = net::TokioUnixStream::connect(path).await?;
241        Ok(s.into())
242    }
243    async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
244        let path = addr
245            .as_pathname()
246            .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
247        let lis = net::TokioUnixListener::bind(path)?;
248        Ok(net::UnixListener { lis })
249    }
250}
251
252#[cfg(not(unix))]
253crate::impls::impl_unix_non_provider! { TokioRuntimeHandle }
254
255#[async_trait]
256impl crate::traits::UdpProvider for TokioRuntimeHandle {
257    type UdpSocket = net::UdpSocket;
258
259    async fn bind(&self, addr: &std::net::SocketAddr) -> IoResult<Self::UdpSocket> {
260        net::UdpSocket::bind(*addr).await
261    }
262}
263
264/// Create and return a new Tokio multithreaded runtime.
265pub(crate) fn create_runtime() -> IoResult<TokioRuntimeHandle> {
266    let runtime = async_executors::exec::TokioTp::new().map_err(std::io::Error::other)?;
267    Ok(runtime.into())
268}
269
270/// Wrapper around a Handle to a tokio runtime.
271///
272/// Ideally, this type would go away, and we would just use
273/// `tokio::runtime::Handle` directly.  Unfortunately, we can't implement
274/// `futures::Spawn` on it ourselves because of Rust's orphan rules, so we need
275/// to define a new type here.
276///
277/// # Limitations
278///
279/// Note that Arti requires that the runtime should have working implementations
280/// for Tokio's time, net, and io facilities, but we have no good way to check
281/// that when creating this object.
282#[derive(Clone, Debug)]
283pub struct TokioRuntimeHandle {
284    /// If present, the tokio executor that we've created (and which we own).
285    ///
286    /// We never access this directly; only through `handle`.  We keep it here
287    /// so that our Runtime types can be agnostic about whether they own the
288    /// executor.
289    owned: Option<async_executors::TokioTp>,
290    /// The underlying Handle.
291    handle: tokio_crate::runtime::Handle,
292}
293
294impl TokioRuntimeHandle {
295    /// Wrap a tokio runtime handle into a format that Arti can use.
296    ///
297    /// # Limitations
298    ///
299    /// Note that Arti requires that the runtime should have working
300    /// implementations for Tokio's time, net, and io facilities, but we have
301    /// no good way to check that when creating this object.
302    pub(crate) fn new(handle: tokio_crate::runtime::Handle) -> Self {
303        handle.into()
304    }
305
306    /// Return true if this handle owns the executor that it points to.
307    pub fn is_owned(&self) -> bool {
308        self.owned.is_some()
309    }
310}
311
312impl From<tokio_crate::runtime::Handle> for TokioRuntimeHandle {
313    fn from(handle: tokio_crate::runtime::Handle) -> Self {
314        Self {
315            owned: None,
316            handle,
317        }
318    }
319}
320
321impl From<async_executors::TokioTp> for TokioRuntimeHandle {
322    fn from(owner: async_executors::TokioTp) -> TokioRuntimeHandle {
323        let handle = owner.block_on(async { tokio_crate::runtime::Handle::current() });
324        Self {
325            owned: Some(owner),
326            handle,
327        }
328    }
329}
330
331impl ToplevelBlockOn for TokioRuntimeHandle {
332    #[track_caller]
333    fn block_on<F: Future>(&self, f: F) -> F::Output {
334        self.handle.block_on(f)
335    }
336}
337
338impl Blocking for TokioRuntimeHandle {
339    type ThreadHandle<T: Send + 'static> = async_executors::BlockingHandle<T>;
340
341    #[track_caller]
342    fn spawn_blocking<F, T>(&self, f: F) -> async_executors::BlockingHandle<T>
343    where
344        F: FnOnce() -> T + Send + 'static,
345        T: Send + 'static,
346    {
347        async_executors::BlockingHandle::tokio(self.handle.spawn_blocking(f))
348    }
349
350    #[track_caller]
351    fn reenter_block_on<F: Future>(&self, future: F) -> F::Output {
352        self.handle.block_on(future)
353    }
354
355    #[track_caller]
356    fn blocking_io<F, T>(&self, f: F) -> impl Future<Output = T>
357    where
358        F: FnOnce() -> T + Send + 'static,
359        T: Send + 'static,
360    {
361        let r = tokio_crate::task::block_in_place(f);
362        std::future::ready(r)
363    }
364}
365
366impl futures::task::Spawn for TokioRuntimeHandle {
367    #[track_caller]
368    fn spawn_obj(
369        &self,
370        future: futures::task::FutureObj<'static, ()>,
371    ) -> Result<(), futures::task::SpawnError> {
372        let join_handle = self.handle.spawn(future);
373        drop(join_handle); // this makes the task detached.
374        Ok(())
375    }
376}