tor_rtcompat/
general.rs

1//! Support for streams and listeners on `general::SocketAddr`.
2
3use async_trait::async_trait;
4use futures::{stream, AsyncRead, AsyncWrite, StreamExt as _};
5use std::io::{Error as IoError, ErrorKind as IoErrorKind, Result as IoResult};
6use std::net;
7use std::task::Poll;
8use std::{pin::Pin, task::Context};
9use tor_general_addr::unix;
10
11use crate::{NetStreamListener, NetStreamProvider, StreamOps};
12use tor_general_addr::general;
13
14pub use general::{AddrParseError, SocketAddr};
15
16/// Helper trait to allow us to create a type-erased stream.
17///
18/// (Rust doesn't allow "dyn AsyncRead + AsyncWrite")
19trait ReadAndWrite: AsyncRead + AsyncWrite + StreamOps + Send + Sync {}
20impl<T> ReadAndWrite for T where T: AsyncRead + AsyncWrite + StreamOps + Send + Sync {}
21
22/// A stream returned by a `NetStreamProvider<GeneralizedAddr>`
23pub struct Stream(Pin<Box<dyn ReadAndWrite>>);
24impl AsyncRead for Stream {
25    fn poll_read(
26        mut self: Pin<&mut Self>,
27        cx: &mut Context<'_>,
28        buf: &mut [u8],
29    ) -> Poll<IoResult<usize>> {
30        self.0.as_mut().poll_read(cx, buf)
31    }
32}
33impl AsyncWrite for Stream {
34    fn poll_write(
35        mut self: Pin<&mut Self>,
36        cx: &mut Context<'_>,
37        buf: &[u8],
38    ) -> Poll<IoResult<usize>> {
39        self.0.as_mut().poll_write(cx, buf)
40    }
41
42    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
43        self.0.as_mut().poll_flush(cx)
44    }
45
46    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
47        self.0.as_mut().poll_close(cx)
48    }
49}
50
51impl StreamOps for Stream {
52    fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
53        self.0.set_tcp_notsent_lowat(notsent_lowat)
54    }
55
56    fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
57        self.0.new_handle()
58    }
59}
60
61/// The type of the result from an [`IncomingStreams`].
62type StreamItem = IoResult<(Stream, general::SocketAddr)>;
63
64/// A stream of incoming connections on a [`general::Listener`](Listener).
65pub struct IncomingStreams(Pin<Box<dyn stream::Stream<Item = StreamItem> + Send + Sync>>);
66
67impl stream::Stream for IncomingStreams {
68    type Item = IoResult<(Stream, general::SocketAddr)>;
69
70    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
71        self.0.as_mut().poll_next(cx)
72    }
73}
74
75/// A listener returned by a `NetStreamProvider<general::SocketAddr>`.
76pub struct Listener {
77    /// The `futures::Stream` of incoming network streams.
78    streams: IncomingStreams,
79    /// The local address on which we're listening.
80    local_addr: general::SocketAddr,
81}
82
83impl NetStreamListener<general::SocketAddr> for Listener {
84    type Stream = Stream;
85    type Incoming = IncomingStreams;
86
87    fn incoming(self) -> IncomingStreams {
88        self.streams
89    }
90
91    fn local_addr(&self) -> IoResult<general::SocketAddr> {
92        Ok(self.local_addr.clone())
93    }
94}
95
96/// Use `provider` to launch a `NetStreamListener` at `address`, and wrap that listener
97/// as a `Listener`.
98async fn abstract_listener_on<ADDR, P>(provider: &P, address: &ADDR) -> IoResult<Listener>
99where
100    P: NetStreamProvider<ADDR>,
101    general::SocketAddr: From<ADDR>,
102{
103    let lis = provider.listen(address).await?;
104    let local_addr = general::SocketAddr::from(lis.local_addr()?);
105    let streams = lis.incoming().map(|result| {
106        result.map(|(socket, addr)| (Stream(Box::pin(socket)), general::SocketAddr::from(addr)))
107    });
108    let streams = IncomingStreams(Box::pin(streams));
109    Ok(Listener {
110        streams,
111        local_addr,
112    })
113}
114
115#[async_trait]
116impl<T> NetStreamProvider<general::SocketAddr> for T
117where
118    T: NetStreamProvider<net::SocketAddr> + NetStreamProvider<unix::SocketAddr>,
119{
120    type Stream = Stream;
121    type Listener = Listener;
122
123    async fn connect(&self, addr: &general::SocketAddr) -> IoResult<Stream> {
124        use general::SocketAddr as G;
125        match addr {
126            G::Inet(a) => Ok(Stream(Box::pin(self.connect(a).await?))),
127            G::Unix(a) => Ok(Stream(Box::pin(self.connect(a).await?))),
128            other => Err(IoError::new(
129                IoErrorKind::InvalidInput,
130                UnsupportedAddress(other.clone()),
131            )),
132        }
133    }
134    async fn listen(&self, addr: &general::SocketAddr) -> IoResult<Listener> {
135        use general::SocketAddr as G;
136        match addr {
137            G::Inet(a) => abstract_listener_on(self, a).await,
138            G::Unix(a) => abstract_listener_on(self, a).await,
139            other => Err(IoError::new(
140                IoErrorKind::InvalidInput,
141                UnsupportedAddress(other.clone()),
142            )),
143        }
144    }
145}
146
147/// Tried to use a [`general::SocketAddr`] that `tor-rtcompat` didn't understand.
148#[derive(Clone, Debug, thiserror::Error)]
149#[error("Socket address {0:?} is not supported by tor-rtcompat")]
150pub struct UnsupportedAddress(general::SocketAddr);