tor_rtcompat/impls/
tokio.rs
1pub(crate) mod net {
8 use crate::{impls, traits};
9 use async_trait::async_trait;
10 use tor_general_addr::unix;
11
12 pub(crate) use tokio_crate::net::{
13 TcpListener as TokioTcpListener, TcpStream as TokioTcpStream, UdpSocket as TokioUdpSocket,
14 };
15 #[cfg(unix)]
16 pub(crate) use tokio_crate::net::{
17 UnixListener as TokioUnixListener, UnixStream as TokioUnixStream,
18 };
19
20 use futures::io::{AsyncRead, AsyncWrite};
21 use paste::paste;
22 use tokio_util::compat::{Compat, TokioAsyncReadCompatExt as _};
23
24 use std::io::Result as IoResult;
25 use std::net::SocketAddr;
26 use std::pin::Pin;
27 use std::task::{Context, Poll};
28
/// Generate the Tokio-backed stream, listener, and incoming-connection types for
/// one socket family.
///
/// Invoked as `stream_impl! { Kind, AddrType, converter }`. `converter` must take
/// Tokio's address type *by value* and return `IoResult<AddrType>`. The expansion
/// defines `KindStream`, `KindListener` and `IncomingKindStreams`, wrapping the
/// `TokioKindStream` / `TokioKindListener` types imported in this module.
macro_rules! stream_impl {
    {
        $kind:ident,
        $addr:ty,
        $cvt_addr:ident
    } => {paste!{
        /// Wrapper for Tokio's
        #[doc = stringify!($kind)]
        /// streams, implementing the `futures` `AsyncRead` and `AsyncWrite` traits.
        pub struct [<$kind Stream>] {
            /// The Tokio stream, adapted to the `futures` I/O traits by `tokio_util::compat`.
            s: Compat<[<Tokio $kind Stream>]>,
        }

        impl From<[<Tokio $kind Stream>]> for [<$kind Stream>] {
            fn from(inner: [<Tokio $kind Stream>]) -> [<$kind Stream>] {
                [<$kind Stream>] { s: inner.compat() }
            }
        }

        // Both this wrapper and the inner `Compat` are `Unpin` (required by
        // `Pin::new`), so the field can simply be re-pinned; no projection needed.
        impl AsyncRead for [<$kind Stream>] {
            fn poll_read(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut [u8],
            ) -> Poll<IoResult<usize>> {
                Pin::new(&mut self.get_mut().s).poll_read(cx, buf)
            }
        }

        impl AsyncWrite for [<$kind Stream>] {
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<IoResult<usize>> {
                Pin::new(&mut self.get_mut().s).poll_write(cx, buf)
            }

            fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
                Pin::new(&mut self.get_mut().s).poll_flush(cx)
            }

            fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
                Pin::new(&mut self.get_mut().s).poll_close(cx)
            }
        }

        /// Wrapper for Tokio's
        #[doc = stringify!($kind)]
        /// listeners.
        pub struct [<$kind Listener>] {
            /// The underlying Tokio listener.
            pub(super) lis: [<Tokio $kind Listener>],
        }

        /// Stream of incoming
        #[doc = stringify!($kind)]
        /// connections, as returned by the listener's `incoming()` method.
        pub struct [<Incoming $kind Streams>] {
            /// The Tokio listener that connections are accepted from.
            pub(super) lis: [<Tokio $kind Listener>],
        }

        impl futures::stream::Stream for [<Incoming $kind Streams>] {
            type Item = IoResult<([<$kind Stream>], $addr)>;

            fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                // An accept error is yielded as `Some(Err(_))` via `?`; the stream never ends.
                let (stream, peer) = match self.lis.poll_accept(cx) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(accepted) => accepted?,
                };
                // Address-conversion failures are likewise yielded as `Some(Err(_))`.
                let peer = $cvt_addr(peer)?;
                Poll::Ready(Some(Ok((stream.into(), peer))))
            }
        }

        impl traits::NetStreamListener<$addr> for [<$kind Listener>] {
            type Stream = [<$kind Stream>];
            type Incoming = [<Incoming $kind Streams>];

            fn incoming(self) -> Self::Incoming {
                let Self { lis } = self;
                [<Incoming $kind Streams>] { lis }
            }

            fn local_addr(&self) -> IoResult<$addr> {
                let raw = self.lis.local_addr()?;
                $cvt_addr(raw)
            }
        }
    }}
}
117
118 #[cfg(unix)]
123 #[allow(clippy::needless_pass_by_value)]
124 fn try_cvt_tokio_unix_addr(
125 addr: tokio_crate::net::unix::SocketAddr,
126 ) -> IoResult<unix::SocketAddr> {
127 if addr.is_unnamed() {
128 crate::unix::new_unnamed_socketaddr()
129 } else if let Some(p) = addr.as_pathname() {
130 unix::SocketAddr::from_pathname(p)
131 } else {
132 Err(crate::unix::UnsupportedAfUnixAddressType.into())
133 }
134 }
135
/// Address "converter" for TCP, used by `stream_impl!`.
///
/// Tokio already reports `std::net::SocketAddr`, so the address is returned
/// unchanged.
// The `IoResult` wrapper is never `Err`. It is kept so this function matches the
// fallible converter signature that `stream_impl!` expects.
#[allow(clippy::unnecessary_wraps)]
fn identity_fn_socketaddr(addr: std::net::SocketAddr) -> IoResult<std::net::SocketAddr> {
    IoResult::Ok(addr)
}
141
142 stream_impl! { Tcp, std::net::SocketAddr, identity_fn_socketaddr }
143 #[cfg(unix)]
144 stream_impl! { Unix, unix::SocketAddr, try_cvt_tokio_unix_addr }
145
146 pub struct UdpSocket {
148 socket: TokioUdpSocket,
150 }
151
152 impl UdpSocket {
153 pub async fn bind(addr: SocketAddr) -> IoResult<Self> {
155 TokioUdpSocket::bind(addr)
156 .await
157 .map(|socket| UdpSocket { socket })
158 }
159 }
160
161 #[async_trait]
162 impl traits::UdpSocket for UdpSocket {
163 async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
164 self.socket.recv_from(buf).await
165 }
166
167 async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
168 self.socket.send_to(buf, target).await
169 }
170
171 fn local_addr(&self) -> IoResult<SocketAddr> {
172 self.socket.local_addr()
173 }
174 }
175
176 impl traits::StreamOps for TcpStream {
177 fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
178 impls::streamops::set_tcp_notsent_lowat(&self.s, notsent_lowat)
179 }
180
181 #[cfg(target_os = "linux")]
182 fn new_handle(&self) -> Box<dyn traits::StreamOps + Send + Unpin> {
183 Box::new(impls::streamops::TcpSockFd::from_fd(&self.s))
184 }
185 }
186
187 #[cfg(unix)]
188 impl traits::StreamOps for UnixStream {
189 fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
190 Err(traits::UnsupportedStreamOp::new(
191 "set_tcp_notsent_lowat",
192 "unsupported on Unix streams",
193 )
194 .into())
195 }
196 }
197}
198
199use crate::traits::*;
202use async_trait::async_trait;
203use futures::Future;
204use std::io::Result as IoResult;
205use std::time::Duration;
206use tor_general_addr::unix;
207
208impl SleepProvider for TokioRuntimeHandle {
209 type SleepFuture = tokio_crate::time::Sleep;
210 fn sleep(&self, duration: Duration) -> Self::SleepFuture {
211 tokio_crate::time::sleep(duration)
212 }
213}
214
215#[async_trait]
216impl crate::traits::NetStreamProvider for TokioRuntimeHandle {
217 type Stream = net::TcpStream;
218 type Listener = net::TcpListener;
219
220 async fn connect(&self, addr: &std::net::SocketAddr) -> IoResult<Self::Stream> {
221 let s = net::TokioTcpStream::connect(addr).await?;
222 Ok(s.into())
223 }
224 async fn listen(&self, addr: &std::net::SocketAddr) -> IoResult<Self::Listener> {
225 let lis = net::TokioTcpListener::bind(*addr).await?;
226 Ok(net::TcpListener { lis })
227 }
228}
229
230#[cfg(unix)]
231#[async_trait]
232impl crate::traits::NetStreamProvider<unix::SocketAddr> for TokioRuntimeHandle {
233 type Stream = net::UnixStream;
234 type Listener = net::UnixListener;
235
236 async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
237 let path = addr
238 .as_pathname()
239 .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
240 let s = net::TokioUnixStream::connect(path).await?;
241 Ok(s.into())
242 }
243 async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
244 let path = addr
245 .as_pathname()
246 .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
247 let lis = net::TokioUnixListener::bind(path)?;
248 Ok(net::UnixListener { lis })
249 }
250}
251
// On non-Unix platforms, supply the placeholder Unix-socket provider impl that
// `crate::impls::impl_unix_non_provider!` generates for this runtime.
#[cfg(not(unix))]
crate::impls::impl_unix_non_provider!(TokioRuntimeHandle);
254
255#[async_trait]
256impl crate::traits::UdpProvider for TokioRuntimeHandle {
257 type UdpSocket = net::UdpSocket;
258
259 async fn bind(&self, addr: &std::net::SocketAddr) -> IoResult<Self::UdpSocket> {
260 net::UdpSocket::bind(*addr).await
261 }
262}
263
264pub(crate) fn create_runtime() -> IoResult<TokioRuntimeHandle> {
266 let runtime = async_executors::exec::TokioTp::new()
267 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
268 Ok(runtime.into())
269}
270
271#[derive(Clone, Debug)]
284pub struct TokioRuntimeHandle {
285 owned: Option<async_executors::TokioTp>,
291 handle: tokio_crate::runtime::Handle,
293}
294
295impl TokioRuntimeHandle {
296 pub(crate) fn new(handle: tokio_crate::runtime::Handle) -> Self {
304 handle.into()
305 }
306
307 pub fn is_owned(&self) -> bool {
309 self.owned.is_some()
310 }
311}
312
313impl From<tokio_crate::runtime::Handle> for TokioRuntimeHandle {
314 fn from(handle: tokio_crate::runtime::Handle) -> Self {
315 Self {
316 owned: None,
317 handle,
318 }
319 }
320}
321
322impl From<async_executors::TokioTp> for TokioRuntimeHandle {
323 fn from(owner: async_executors::TokioTp) -> TokioRuntimeHandle {
324 let handle = owner.block_on(async { tokio_crate::runtime::Handle::current() });
325 Self {
326 owned: Some(owner),
327 handle,
328 }
329 }
330}
331
332impl ToplevelBlockOn for TokioRuntimeHandle {
333 #[track_caller]
334 fn block_on<F: Future>(&self, f: F) -> F::Output {
335 self.handle.block_on(f)
336 }
337}
338
339impl Blocking for TokioRuntimeHandle {
340 type ThreadHandle<T: Send + 'static> = async_executors::BlockingHandle<T>;
341
342 #[track_caller]
343 fn spawn_blocking<F, T>(&self, f: F) -> async_executors::BlockingHandle<T>
344 where
345 F: FnOnce() -> T + Send + 'static,
346 T: Send + 'static,
347 {
348 async_executors::BlockingHandle::tokio(self.handle.spawn_blocking(f))
349 }
350
351 #[track_caller]
352 fn reenter_block_on<F: Future>(&self, future: F) -> F::Output {
353 self.handle.block_on(future)
354 }
355
356 #[track_caller]
357 fn blocking_io<F, T>(&self, f: F) -> impl Future<Output = T>
358 where
359 F: FnOnce() -> T + Send + 'static,
360 T: Send + 'static,
361 {
362 let r = tokio_crate::task::block_in_place(f);
363 std::future::ready(r)
364 }
365}
366
367impl futures::task::Spawn for TokioRuntimeHandle {
368 #[track_caller]
369 fn spawn_obj(
370 &self,
371 future: futures::task::FutureObj<'static, ()>,
372 ) -> Result<(), futures::task::SpawnError> {
373 let join_handle = self.handle.spawn(future);
374 drop(join_handle); Ok(())
376 }
377}