//! Re-exports of the tokio runtime for use with arti.
//!
//! This crate helps define a slim API around our async runtime so that we
//! can easily swap it out.

/// Types used for networking (tokio implementation)
7
pub(crate) mod net {
8
    use crate::{impls, traits};
9
    use async_trait::async_trait;
10
    use tor_general_addr::unix;
11

            
12
    pub(crate) use tokio_crate::net::{
13
        TcpListener as TokioTcpListener, TcpStream as TokioTcpStream, UdpSocket as TokioUdpSocket,
14
    };
15
    #[cfg(unix)]
16
    pub(crate) use tokio_crate::net::{
17
        UnixListener as TokioUnixListener, UnixStream as TokioUnixStream,
18
    };
19

            
20
    use futures::io::{AsyncRead, AsyncWrite};
21
    use paste::paste;
22
    use tokio_util::compat::{Compat, TokioAsyncReadCompatExt as _};
23

            
24
    use std::io::Result as IoResult;
25
    use std::net::SocketAddr;
26
    use std::pin::Pin;
27
    use std::task::{Context, Poll};
28

            
29
    /// Provide a set of network stream wrappers for a single stream type.
30
    macro_rules! stream_impl {
31
        {
32
            $kind:ident,
33
            $addr:ty,
34
            $cvt_addr:ident
35
        } => {paste!{
36
            /// Wrapper for Tokio's
37
            #[doc = stringify!($kind)]
38
            /// streams,
39
            /// that implements the standard
40
            /// AsyncRead and AsyncWrite.
41
            pub struct [<$kind Stream>] {
42
                /// Underlying tokio_util::compat::Compat wrapper.
43
                s: Compat<[<Tokio $kind Stream>]>,
44
            }
45
            impl From<[<Tokio $kind Stream>]> for [<$kind Stream>] {
46
62
                fn from(s: [<Tokio $kind Stream>]) ->  [<$kind Stream>] {
47
62
                    let s = s.compat();
48
62
                    [<$kind Stream>] { s }
49
62
                }
50
            }
51
            impl AsyncRead for  [<$kind Stream>] {
52
180
                fn poll_read(
53
180
                    mut self: Pin<&mut Self>,
54
180
                    cx: &mut Context<'_>,
55
180
                    buf: &mut [u8],
56
180
                ) -> Poll<IoResult<usize>> {
57
180
                    Pin::new(&mut self.s).poll_read(cx, buf)
58
180
                }
59
            }
60
            impl AsyncWrite for  [<$kind Stream>] {
61
56
                fn poll_write(
62
56
                    mut self: Pin<&mut Self>,
63
56
                    cx: &mut Context<'_>,
64
56
                    buf: &[u8],
65
56
                ) -> Poll<IoResult<usize>> {
66
56
                    Pin::new(&mut self.s).poll_write(cx, buf)
67
56
                }
68
52
                fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
69
52
                    Pin::new(&mut self.s).poll_flush(cx)
70
52
                }
71
6
                fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
72
6
                    Pin::new(&mut self.s).poll_close(cx)
73
6
                }
74
            }
75

            
76
            /// Wrap a Tokio
77
            #[doc = stringify!($kind)]
78
            /// Listener to behave as a futures::io::TcpListener.
79
            pub struct [<$kind Listener>] {
80
                /// The underlying listener.
81
                pub(super) lis: [<Tokio $kind Listener>],
82
            }
83

            
84
            /// Asynchronous stream that yields incoming connections from a
85
            #[doc = stringify!($kind)]
86
            /// Listener.
87
            ///
88
            /// This is analogous to async_std::net::Incoming.
89
            pub struct [<Incoming $kind Streams>] {
90
                /// Reference to the underlying listener.
91
                pub(super) lis: [<Tokio $kind Listener>],
92
            }
93

            
94
            impl futures::stream::Stream for [<Incoming $kind Streams>] {
95
                type Item = IoResult<([<$kind Stream>], $addr)>;
96

            
97
36
                fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
98
36
                    match self.lis.poll_accept(cx) {
99
28
                        Poll::Ready(Ok((s, a))) => Poll::Ready(Some(Ok((s.into(), $cvt_addr(a)? )))),
100
                        Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
101
8
                        Poll::Pending => Poll::Pending,
102
                    }
103
36
                }
104
            }
105
            impl traits::NetStreamListener<$addr> for [<$kind Listener>] {
106
                type Stream = [<$kind Stream>];
107
                type Incoming = [<Incoming $kind Streams>];
108
8
                fn incoming(self) -> Self::Incoming {
109
8
                    [<Incoming $kind Streams>] { lis: self.lis }
110
8
                }
111
8
                fn local_addr(&self) -> IoResult<$addr> {
112
8
                    $cvt_addr(self.lis.local_addr()?)
113
8
                }
114
            }
115
        }}
116
    }
117

            
118
    /// Try to convert a tokio `unix::SocketAddr` into a crate::SocketAddr.
119
    ///
120
    /// Frustratingly, this information is _right there_: Tokio's SocketAddr has a
121
    /// std::unix::net::SocketAddr internally, but there appears to be no way to get it out.
122
    #[cfg(unix)]
123
    #[allow(clippy::needless_pass_by_value)]
124
    fn try_cvt_tokio_unix_addr(
125
        addr: tokio_crate::net::unix::SocketAddr,
126
    ) -> IoResult<unix::SocketAddr> {
127
        if addr.is_unnamed() {
128
            crate::unix::new_unnamed_socketaddr()
129
        } else if let Some(p) = addr.as_pathname() {
130
            unix::SocketAddr::from_pathname(p)
131
        } else {
132
            Err(crate::unix::UnsupportedAfUnixAddressType.into())
133
        }
134
    }
135

            
136
    /// Wrapper for (not) converting std::net::SocketAddr to itself.
137
    #[allow(clippy::unnecessary_wraps)]
138
36
    fn identity_fn_socketaddr(addr: std::net::SocketAddr) -> IoResult<std::net::SocketAddr> {
139
36
        Ok(addr)
140
36
    }
141

            
142
    stream_impl! { Tcp, std::net::SocketAddr, identity_fn_socketaddr }
143
    #[cfg(unix)]
144
    stream_impl! { Unix, unix::SocketAddr, try_cvt_tokio_unix_addr }
145

            
146
    /// Wrap a Tokio UdpSocket
147
    pub struct UdpSocket {
148
        /// The underelying UdpSocket
149
        socket: TokioUdpSocket,
150
    }
151

            
152
    impl UdpSocket {
153
        /// Bind a UdpSocket
154
12
        pub async fn bind(addr: SocketAddr) -> IoResult<Self> {
155
8
            TokioUdpSocket::bind(addr)
156
8
                .await
157
8
                .map(|socket| UdpSocket { socket })
158
8
        }
159
    }
160

            
161
    #[async_trait]
162
    impl traits::UdpSocket for UdpSocket {
163
4
        async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
164
4
            self.socket.recv_from(buf).await
165
8
        }
166

            
167
4
        async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
168
4
            self.socket.send_to(buf, target).await
169
8
        }
170

            
171
8
        fn local_addr(&self) -> IoResult<SocketAddr> {
172
8
            self.socket.local_addr()
173
8
        }
174
    }
175

            
176
    impl traits::StreamOps for TcpStream {
177
        fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
178
            impls::streamops::set_tcp_notsent_lowat(&self.s, notsent_lowat)
179
        }
180

            
181
        #[cfg(target_os = "linux")]
182
        fn new_handle(&self) -> Box<dyn traits::StreamOps + Send + Unpin> {
183
            Box::new(impls::streamops::TcpSockFd::from_fd(&self.s))
184
        }
185
    }
186

            
187
    #[cfg(unix)]
188
    impl traits::StreamOps for UnixStream {
189
        fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
190
            Err(traits::UnsupportedStreamOp::new(
191
                "set_tcp_notsent_lowat",
192
                "unsupported on Unix streams",
193
            )
194
            .into())
195
        }
196
    }
197
}
198

            
199
// ==============================
200

            
201
use crate::traits::*;
202
use async_trait::async_trait;
203
use futures::Future;
204
use std::io::Result as IoResult;
205
use std::time::Duration;
206
use tor_general_addr::unix;
207

            
208
impl SleepProvider for TokioRuntimeHandle {
209
    type SleepFuture = tokio_crate::time::Sleep;
210
7234
    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
211
7234
        tokio_crate::time::sleep(duration)
212
7234
    }
213
}
214

            
215
#[async_trait]
216
impl crate::traits::NetStreamProvider for TokioRuntimeHandle {
217
    type Stream = net::TcpStream;
218
    type Listener = net::TcpListener;
219

            
220
34
    async fn connect(&self, addr: &std::net::SocketAddr) -> IoResult<Self::Stream> {
221
34
        let s = net::TokioTcpStream::connect(addr).await?;
222
34
        Ok(s.into())
223
68
    }
224
8
    async fn listen(&self, addr: &std::net::SocketAddr) -> IoResult<Self::Listener> {
225
8
        let lis = net::TokioTcpListener::bind(*addr).await?;
226
8
        Ok(net::TcpListener { lis })
227
16
    }
228
}
229

            
230
#[cfg(unix)]
231
#[async_trait]
232
impl crate::traits::NetStreamProvider<unix::SocketAddr> for TokioRuntimeHandle {
233
    type Stream = net::UnixStream;
234
    type Listener = net::UnixListener;
235

            
236
    async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
237
        let path = addr
238
            .as_pathname()
239
            .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
240
        let s = net::TokioUnixStream::connect(path).await?;
241
        Ok(s.into())
242
    }
243
    async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
244
        let path = addr
245
            .as_pathname()
246
            .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
247
        let lis = net::TokioUnixListener::bind(path)?;
248
        Ok(net::UnixListener { lis })
249
    }
250
}
251

            
252
// On non-Unix targets, the `impl_unix_non_provider!` macro from `crate::impls`
// supplies this runtime's AF_UNIX provider.
// NOTE(review): presumably a stub that reports AF_UNIX as unsupported; confirm
// against the macro definition.
#[cfg(not(unix))]
crate::impls::impl_unix_non_provider!(TokioRuntimeHandle);
254

            
255
#[async_trait]
256
impl crate::traits::UdpProvider for TokioRuntimeHandle {
257
    type UdpSocket = net::UdpSocket;
258

            
259
8
    async fn bind(&self, addr: &std::net::SocketAddr) -> IoResult<Self::UdpSocket> {
260
8
        net::UdpSocket::bind(*addr).await
261
16
    }
262
}
263

            
264
/// Create and return a new Tokio multithreaded runtime.
265
10883
pub(crate) fn create_runtime() -> IoResult<TokioRuntimeHandle> {
266
10883
    let runtime = async_executors::exec::TokioTp::new()
267
10883
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
268
10883
    Ok(runtime.into())
269
10883
}
270

            
271
/// Wrapper around a Handle to a tokio runtime.
272
///
273
/// Ideally, this type would go away, and we would just use
274
/// `tokio::runtime::Handle` directly.  Unfortunately, we can't implement
275
/// `futures::Spawn` on it ourselves because of Rust's orphan rules, so we need
276
/// to define a new type here.
277
///
278
/// # Limitations
279
///
280
/// Note that Arti requires that the runtime should have working implementations
281
/// for Tokio's time, net, and io facilities, but we have no good way to check
282
/// that when creating this object.
283
#[derive(Clone, Debug)]
284
pub struct TokioRuntimeHandle {
285
    /// If present, the tokio executor that we've created (and which we own).
286
    ///
287
    /// We never access this directly; only through `handle`.  We keep it here
288
    /// so that our Runtime types can be agnostic about whether they own the
289
    /// executor.
290
    owned: Option<async_executors::TokioTp>,
291
    /// The underlying Handle.
292
    handle: tokio_crate::runtime::Handle,
293
}
294

            
295
impl TokioRuntimeHandle {
296
    /// Wrap a tokio runtime handle into a format that Arti can use.
297
    ///
298
    /// # Limitations
299
    ///
300
    /// Note that Arti requires that the runtime should have working
301
    /// implementations for Tokio's time, net, and io facilities, but we have
302
    /// no good way to check that when creating this object.
303
65
    pub(crate) fn new(handle: tokio_crate::runtime::Handle) -> Self {
304
65
        handle.into()
305
65
    }
306

            
307
    /// Return true if this handle owns the executor that it points to.
308
    pub fn is_owned(&self) -> bool {
309
        self.owned.is_some()
310
    }
311
}
312

            
313
impl From<tokio_crate::runtime::Handle> for TokioRuntimeHandle {
314
65
    fn from(handle: tokio_crate::runtime::Handle) -> Self {
315
65
        Self {
316
65
            owned: None,
317
65
            handle,
318
65
        }
319
65
    }
320
}
321

            
322
impl From<async_executors::TokioTp> for TokioRuntimeHandle {
323
10883
    fn from(owner: async_executors::TokioTp) -> TokioRuntimeHandle {
324
11103
        let handle = owner.block_on(async { tokio_crate::runtime::Handle::current() });
325
10883
        Self {
326
10883
            owned: Some(owner),
327
10883
            handle,
328
10883
        }
329
10883
    }
330
}
331

            
332
impl ToplevelBlockOn for TokioRuntimeHandle {
333
    #[track_caller]
334
416
    fn block_on<F: Future>(&self, f: F) -> F::Output {
335
416
        self.handle.block_on(f)
336
416
    }
337
}
338

            
339
impl Blocking for TokioRuntimeHandle {
340
    type ThreadHandle<T: Send + 'static> = async_executors::BlockingHandle<T>;
341

            
342
    #[track_caller]
343
    fn spawn_blocking<F, T>(&self, f: F) -> async_executors::BlockingHandle<T>
344
    where
345
        F: FnOnce() -> T + Send + 'static,
346
        T: Send + 'static,
347
    {
348
        async_executors::BlockingHandle::tokio(self.handle.spawn_blocking(f))
349
    }
350

            
351
    #[track_caller]
352
    fn reenter_block_on<F: Future>(&self, future: F) -> F::Output {
353
        self.handle.block_on(future)
354
    }
355

            
356
    #[track_caller]
357
    fn blocking_io<F, T>(&self, f: F) -> impl Future<Output = T>
358
    where
359
        F: FnOnce() -> T + Send + 'static,
360
        T: Send + 'static,
361
    {
362
        let r = tokio_crate::task::block_in_place(f);
363
        std::future::ready(r)
364
    }
365
}
366

            
367
impl futures::task::Spawn for TokioRuntimeHandle {
    /// Spawn `future` onto this runtime as a detached task.  This never fails.
    #[track_caller]
    fn spawn_obj(
        &self,
        future: futures::task::FutureObj<'static, ()>,
    ) -> Result<(), futures::task::SpawnError> {
        // Dropping tokio's `JoinHandle` detaches the task; it keeps running.
        drop(self.handle.spawn(future));
        Ok(())
    }
}