1
//! Re-exports of the tokio runtime for use with arti.
2
//!
3
//! This crate helps define a slim API around our async runtime so that we
4
//! can easily swap it out.
5

            
6
/// Types used for networking (tokio implementation)
7
pub(crate) mod net {
8
    use crate::{impls, traits};
9
    use async_trait::async_trait;
10
    #[cfg(unix)]
11
    use tor_general_addr::unix;
12

            
13
    pub(crate) use tokio_crate::net::{
14
        TcpListener as TokioTcpListener, TcpStream as TokioTcpStream, UdpSocket as TokioUdpSocket,
15
    };
16
    #[cfg(unix)]
17
    pub(crate) use tokio_crate::net::{
18
        UnixListener as TokioUnixListener, UnixStream as TokioUnixStream,
19
    };
20

            
21
    use futures::io::{AsyncRead, AsyncWrite};
22
    use paste::paste;
23
    use tokio_util::compat::{Compat, TokioAsyncReadCompatExt as _};
24

            
25
    use std::io::Result as IoResult;
26
    use std::net::SocketAddr;
27
    use std::pin::Pin;
28
    use std::task::{Context, Poll};
29

            
30
    /// Provide a set of network stream wrappers for a single stream type.
31
    macro_rules! stream_impl {
32
        {
33
            $kind:ident,
34
            $addr:ty,
35
            $cvt_addr:ident
36
        } => {paste!{
37
            /// Wrapper for Tokio's
38
            #[doc = stringify!($kind)]
39
            /// streams,
40
            /// that implements the standard
41
            /// AsyncRead and AsyncWrite.
42
            pub struct [<$kind Stream>] {
43
                /// Underlying tokio_util::compat::Compat wrapper.
44
                s: Compat<[<Tokio $kind Stream>]>,
45
            }
46
            impl From<[<Tokio $kind Stream>]> for [<$kind Stream>] {
47
62
                fn from(s: [<Tokio $kind Stream>]) ->  [<$kind Stream>] {
48
62
                    let s = s.compat();
49
62
                    [<$kind Stream>] { s }
50
62
                }
51
            }
52
            impl AsyncRead for  [<$kind Stream>] {
53
178
                fn poll_read(
54
178
                    mut self: Pin<&mut Self>,
55
178
                    cx: &mut Context<'_>,
56
178
                    buf: &mut [u8],
57
178
                ) -> Poll<IoResult<usize>> {
58
178
                    Pin::new(&mut self.s).poll_read(cx, buf)
59
178
                }
60
            }
61
            impl AsyncWrite for  [<$kind Stream>] {
62
56
                fn poll_write(
63
56
                    mut self: Pin<&mut Self>,
64
56
                    cx: &mut Context<'_>,
65
56
                    buf: &[u8],
66
56
                ) -> Poll<IoResult<usize>> {
67
56
                    Pin::new(&mut self.s).poll_write(cx, buf)
68
56
                }
69
52
                fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
70
52
                    Pin::new(&mut self.s).poll_flush(cx)
71
52
                }
72
6
                fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
73
6
                    Pin::new(&mut self.s).poll_close(cx)
74
6
                }
75
            }
76

            
77
            /// Wrap a Tokio
78
            #[doc = stringify!($kind)]
79
            /// Listener to behave as a futures::io::TcpListener.
80
            pub struct [<$kind Listener>] {
81
                /// The underlying listener.
82
                pub(super) lis: [<Tokio $kind Listener>],
83
            }
84

            
85
            /// Asynchronous stream that yields incoming connections from a
86
            #[doc = stringify!($kind)]
87
            /// Listener.
88
            ///
89
            /// This is analogous to async_std::net::Incoming.
90
            pub struct [<Incoming $kind Streams>] {
91
                /// Reference to the underlying listener.
92
                pub(super) lis: [<Tokio $kind Listener>],
93
            }
94

            
95
            impl futures::stream::Stream for [<Incoming $kind Streams>] {
96
                type Item = IoResult<([<$kind Stream>], $addr)>;
97

            
98
36
                fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
99
36
                    match self.lis.poll_accept(cx) {
100
28
                        Poll::Ready(Ok((s, a))) => Poll::Ready(Some(Ok((s.into(), $cvt_addr(a)? )))),
101
                        Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
102
8
                        Poll::Pending => Poll::Pending,
103
                    }
104
36
                }
105
            }
106
            impl traits::NetStreamListener<$addr> for [<$kind Listener>] {
107
                type Stream = [<$kind Stream>];
108
                type Incoming = [<Incoming $kind Streams>];
109
8
                fn incoming(self) -> Self::Incoming {
110
8
                    [<Incoming $kind Streams>] { lis: self.lis }
111
8
                }
112
8
                fn local_addr(&self) -> IoResult<$addr> {
113
8
                    $cvt_addr(self.lis.local_addr()?)
114
8
                }
115
            }
116
        }}
117
    }
118

            
119
    /// Try to convert a tokio `unix::SocketAddr` into a crate::SocketAddr.
120
    ///
121
    /// Frustratingly, this information is _right there_: Tokio's SocketAddr has a
122
    /// std::unix::net::SocketAddr internally, but there appears to be no way to get it out.
123
    #[cfg(unix)]
124
    #[allow(clippy::needless_pass_by_value)]
125
    fn try_cvt_tokio_unix_addr(
126
        addr: tokio_crate::net::unix::SocketAddr,
127
    ) -> IoResult<unix::SocketAddr> {
128
        if addr.is_unnamed() {
129
            crate::unix::new_unnamed_socketaddr()
130
        } else if let Some(p) = addr.as_pathname() {
131
            unix::SocketAddr::from_pathname(p)
132
        } else {
133
            Err(crate::unix::UnsupportedAfUnixAddressType.into())
134
        }
135
    }
136

            
137
    /// Wrapper for (not) converting std::net::SocketAddr to itself.
138
    #[allow(clippy::unnecessary_wraps)]
139
36
    fn identity_fn_socketaddr(addr: std::net::SocketAddr) -> IoResult<std::net::SocketAddr> {
140
36
        Ok(addr)
141
36
    }
142

            
143
    stream_impl! { Tcp, std::net::SocketAddr, identity_fn_socketaddr }
144
    #[cfg(unix)]
145
    stream_impl! { Unix, unix::SocketAddr, try_cvt_tokio_unix_addr }
146

            
147
    /// Wrap a Tokio UdpSocket
148
    pub struct UdpSocket {
149
        /// The underelying UdpSocket
150
        socket: TokioUdpSocket,
151
    }
152

            
153
    impl UdpSocket {
154
        /// Bind a UdpSocket
155
12
        pub async fn bind(addr: SocketAddr) -> IoResult<Self> {
156
8
            TokioUdpSocket::bind(addr)
157
8
                .await
158
8
                .map(|socket| UdpSocket { socket })
159
8
        }
160
    }
161

            
162
    #[async_trait]
163
    impl traits::UdpSocket for UdpSocket {
164
4
        async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
165
            self.socket.recv_from(buf).await
166
4
        }
167

            
168
4
        async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
169
            self.socket.send_to(buf, target).await
170
4
        }
171

            
172
8
        fn local_addr(&self) -> IoResult<SocketAddr> {
173
8
            self.socket.local_addr()
174
8
        }
175
    }
176

            
177
    impl traits::StreamOps for TcpStream {
178
        fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
179
            impls::streamops::set_tcp_notsent_lowat(&self.s, notsent_lowat)
180
        }
181

            
182
        #[cfg(target_os = "linux")]
183
        fn new_handle(&self) -> Box<dyn traits::StreamOps + Send + Unpin> {
184
            Box::new(impls::streamops::TcpSockFd::from_fd(&self.s))
185
        }
186
    }
187

            
188
    #[cfg(unix)]
189
    impl traits::StreamOps for UnixStream {
190
        fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
191
            Err(traits::UnsupportedStreamOp::new(
192
                "set_tcp_notsent_lowat",
193
                "unsupported on Unix streams",
194
            )
195
            .into())
196
        }
197
    }
198
}
199

            
200
// ==============================
201

            
202
use crate::traits::*;
203
use async_trait::async_trait;
204
use futures::Future;
205
use std::io::Result as IoResult;
206
use std::time::Duration;
207
#[cfg(unix)]
208
use tor_general_addr::unix;
209

            
210
impl SleepProvider for TokioRuntimeHandle {
211
    type SleepFuture = tokio_crate::time::Sleep;
212
13108
    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
213
13108
        tokio_crate::time::sleep(duration)
214
13108
    }
215
}
216

            
217
#[async_trait]
218
impl crate::traits::NetStreamProvider for TokioRuntimeHandle {
219
    type Stream = net::TcpStream;
220
    type Listener = net::TcpListener;
221

            
222
34
    async fn connect(&self, addr: &std::net::SocketAddr) -> IoResult<Self::Stream> {
223
        let s = net::TokioTcpStream::connect(addr).await?;
224
        Ok(s.into())
225
34
    }
226
8
    async fn listen(&self, addr: &std::net::SocketAddr) -> IoResult<Self::Listener> {
227
        let lis = net::TokioTcpListener::bind(*addr).await?;
228
        Ok(net::TcpListener { lis })
229
8
    }
230
}
231

            
232
#[cfg(unix)]
233
#[async_trait]
234
impl crate::traits::NetStreamProvider<unix::SocketAddr> for TokioRuntimeHandle {
235
    type Stream = net::UnixStream;
236
    type Listener = net::UnixListener;
237

            
238
    async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
239
        let path = addr
240
            .as_pathname()
241
            .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
242
        let s = net::TokioUnixStream::connect(path).await?;
243
        Ok(s.into())
244
    }
245
    async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
246
        let path = addr
247
            .as_pathname()
248
            .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
249
        let lis = net::TokioUnixListener::bind(path)?;
250
        Ok(net::UnixListener { lis })
251
    }
252
}
253

            
254
// On platforms without AF_UNIX support, the `unix::SocketAddr` provider impl comes
// from the shared `impl_unix_non_provider!` macro in `crate::impls`.
#[cfg(not(unix))]
crate::impls::impl_unix_non_provider!(TokioRuntimeHandle);
256

            
257
#[async_trait]
258
impl crate::traits::UdpProvider for TokioRuntimeHandle {
259
    type UdpSocket = net::UdpSocket;
260

            
261
8
    async fn bind(&self, addr: &std::net::SocketAddr) -> IoResult<Self::UdpSocket> {
262
        net::UdpSocket::bind(*addr).await
263
8
    }
264
}
265

            
266
/// Create and return a new Tokio multithreaded runtime.
267
16504
pub(crate) fn create_runtime() -> IoResult<TokioRuntimeHandle> {
268
16504
    let runtime = async_executors::exec::TokioTp::new().map_err(std::io::Error::other)?;
269
16504
    Ok(runtime.into())
270
16504
}
271

            
272
/// Wrapper around a Handle to a tokio runtime.
273
///
274
/// Ideally, this type would go away, and we would just use
275
/// `tokio::runtime::Handle` directly.  Unfortunately, we can't implement
276
/// `futures::Spawn` on it ourselves because of Rust's orphan rules, so we need
277
/// to define a new type here.
278
///
279
/// # Limitations
280
///
281
/// Note that Arti requires that the runtime should have working implementations
282
/// for Tokio's time, net, and io facilities, but we have no good way to check
283
/// that when creating this object.
284
#[derive(Clone, Debug)]
285
pub struct TokioRuntimeHandle {
286
    /// If present, the tokio executor that we've created (and which we own).
287
    ///
288
    /// We never access this directly; only through `handle`.  We keep it here
289
    /// so that our Runtime types can be agnostic about whether they own the
290
    /// executor.
291
    owned: Option<async_executors::TokioTp>,
292
    /// The underlying Handle.
293
    handle: tokio_crate::runtime::Handle,
294
}
295

            
296
impl TokioRuntimeHandle {
297
    /// Wrap a tokio runtime handle into a format that Arti can use.
298
    ///
299
    /// # Limitations
300
    ///
301
    /// Note that Arti requires that the runtime should have working
302
    /// implementations for Tokio's time, net, and io facilities, but we have
303
    /// no good way to check that when creating this object.
304
80
    pub(crate) fn new(handle: tokio_crate::runtime::Handle) -> Self {
305
80
        handle.into()
306
80
    }
307

            
308
    /// Return true if this handle owns the executor that it points to.
309
    pub fn is_owned(&self) -> bool {
310
        self.owned.is_some()
311
    }
312
}
313

            
314
impl From<tokio_crate::runtime::Handle> for TokioRuntimeHandle {
315
80
    fn from(handle: tokio_crate::runtime::Handle) -> Self {
316
80
        Self {
317
80
            owned: None,
318
80
            handle,
319
80
        }
320
80
    }
321
}
322

            
323
impl From<async_executors::TokioTp> for TokioRuntimeHandle {
324
16504
    fn from(owner: async_executors::TokioTp) -> TokioRuntimeHandle {
325
16764
        let handle = owner.block_on(async { tokio_crate::runtime::Handle::current() });
326
16504
        Self {
327
16504
            owned: Some(owner),
328
16504
            handle,
329
16504
        }
330
16504
    }
331
}
332

            
333
impl ToplevelBlockOn for TokioRuntimeHandle {
334
    #[track_caller]
335
432
    fn block_on<F: Future>(&self, f: F) -> F::Output {
336
432
        self.handle.block_on(f)
337
432
    }
338
}
339

            
340
impl Blocking for TokioRuntimeHandle {
341
    type ThreadHandle<T: Send + 'static> = async_executors::BlockingHandle<T>;
342

            
343
    #[track_caller]
344
    fn spawn_blocking<F, T>(&self, f: F) -> async_executors::BlockingHandle<T>
345
    where
346
        F: FnOnce() -> T + Send + 'static,
347
        T: Send + 'static,
348
    {
349
        async_executors::BlockingHandle::tokio(self.handle.spawn_blocking(f))
350
    }
351

            
352
    #[track_caller]
353
    fn reenter_block_on<F: Future>(&self, future: F) -> F::Output {
354
        self.handle.block_on(future)
355
    }
356

            
357
    #[track_caller]
358
    fn blocking_io<F, T>(&self, f: F) -> impl Future<Output = T>
359
    where
360
        F: FnOnce() -> T + Send + 'static,
361
        T: Send + 'static,
362
    {
363
        let r = tokio_crate::task::block_in_place(f);
364
        std::future::ready(r)
365
    }
366
}
367

            
368
impl futures::task::Spawn for TokioRuntimeHandle {
    /// Spawn `future` onto the wrapped runtime as a detached task; never fails.
    #[track_caller]
    fn spawn_obj(
        &self,
        future: futures::task::FutureObj<'static, ()>,
    ) -> Result<(), futures::task::SpawnError> {
        // Dropping tokio's JoinHandle straight away detaches the task: it keeps
        // running, and nobody waits for it.
        drop(self.handle.spawn(future));
        Ok(())
    }
}