1
//! Support for streams and listeners on `general::SocketAddr`.
2

            
3
use async_trait::async_trait;
4
use futures::{stream, AsyncRead, AsyncWrite, StreamExt as _};
5
use std::io::{Error as IoError, ErrorKind as IoErrorKind, Result as IoResult};
6
use std::net;
7
use std::task::Poll;
8
use std::{pin::Pin, task::Context};
9
use tor_general_addr::unix;
10

            
11
use crate::{NetStreamListener, NetStreamProvider, StreamOps};
12
use tor_general_addr::general;
13

            
14
pub use general::{AddrParseError, SocketAddr};
15

            
16
/// Helper trait to allow us to create a type-erased stream.
17
///
18
/// (Rust doesn't allow "dyn AsyncRead + AsyncWrite")
19
trait ReadAndWrite: AsyncRead + AsyncWrite + StreamOps + Send + Sync {}
20
impl<T> ReadAndWrite for T where T: AsyncRead + AsyncWrite + StreamOps + Send + Sync {}
21

            
22
/// A stream returned by a `NetStreamProvider<GeneralizedAddr>`
23
pub struct Stream(Pin<Box<dyn ReadAndWrite>>);
24
impl AsyncRead for Stream {
25
    fn poll_read(
26
        mut self: Pin<&mut Self>,
27
        cx: &mut Context<'_>,
28
        buf: &mut [u8],
29
    ) -> Poll<IoResult<usize>> {
30
        self.0.as_mut().poll_read(cx, buf)
31
    }
32
}
33
impl AsyncWrite for Stream {
34
    fn poll_write(
35
        mut self: Pin<&mut Self>,
36
        cx: &mut Context<'_>,
37
        buf: &[u8],
38
    ) -> Poll<IoResult<usize>> {
39
        self.0.as_mut().poll_write(cx, buf)
40
    }
41

            
42
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
43
        self.0.as_mut().poll_flush(cx)
44
    }
45

            
46
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
47
        self.0.as_mut().poll_close(cx)
48
    }
49
}
50

            
51
impl StreamOps for Stream {
52
    fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
53
        self.0.set_tcp_notsent_lowat(notsent_lowat)
54
    }
55

            
56
    fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
57
        self.0.new_handle()
58
    }
59
}
60

            
61
/// The type of the result from an [`IncomingStreams`].
62
type StreamItem = IoResult<(Stream, general::SocketAddr)>;
63

            
64
/// A stream of incoming connections on a [`general::Listener`](Listener).
65
pub struct IncomingStreams(Pin<Box<dyn stream::Stream<Item = StreamItem> + Send + Sync>>);
66

            
67
impl stream::Stream for IncomingStreams {
68
    type Item = IoResult<(Stream, general::SocketAddr)>;
69

            
70
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
71
        self.0.as_mut().poll_next(cx)
72
    }
73
}
74

            
75
/// A listener returned by a `NetStreamProvider<general::SocketAddr>`.
76
pub struct Listener {
77
    /// The `futures::Stream` of incoming network streams.
78
    streams: IncomingStreams,
79
    /// The local address on which we're listening.
80
    local_addr: general::SocketAddr,
81
}
82

            
83
impl NetStreamListener<general::SocketAddr> for Listener {
84
    type Stream = Stream;
85
    type Incoming = IncomingStreams;
86

            
87
    fn incoming(self) -> IncomingStreams {
88
        self.streams
89
    }
90

            
91
    fn local_addr(&self) -> IoResult<general::SocketAddr> {
92
        Ok(self.local_addr.clone())
93
    }
94
}
95

            
96
/// Use `provider` to launch a `NetStreamListener` at `address`, and wrap that listener
97
/// as a `Listener`.
98
async fn abstract_listener_on<ADDR, P>(provider: &P, address: &ADDR) -> IoResult<Listener>
99
where
100
    P: NetStreamProvider<ADDR>,
101
    general::SocketAddr: From<ADDR>,
102
{
103
    let lis = provider.listen(address).await?;
104
    let local_addr = general::SocketAddr::from(lis.local_addr()?);
105
    let streams = lis.incoming().map(|result| {
106
        result.map(|(socket, addr)| (Stream(Box::pin(socket)), general::SocketAddr::from(addr)))
107
    });
108
    let streams = IncomingStreams(Box::pin(streams));
109
    Ok(Listener {
110
        streams,
111
        local_addr,
112
    })
113
}
114

            
115
#[async_trait]
116
impl<T> NetStreamProvider<general::SocketAddr> for T
117
where
118
    T: NetStreamProvider<net::SocketAddr> + NetStreamProvider<unix::SocketAddr>,
119
{
120
    type Stream = Stream;
121
    type Listener = Listener;
122

            
123
    async fn connect(&self, addr: &general::SocketAddr) -> IoResult<Stream> {
124
        use general::SocketAddr as G;
125
        match addr {
126
            G::Inet(a) => Ok(Stream(Box::pin(self.connect(a).await?))),
127
            G::Unix(a) => Ok(Stream(Box::pin(self.connect(a).await?))),
128
            other => Err(IoError::new(
129
                IoErrorKind::InvalidInput,
130
                UnsupportedAddress(other.clone()),
131
            )),
132
        }
133
    }
134
    async fn listen(&self, addr: &general::SocketAddr) -> IoResult<Listener> {
135
        use general::SocketAddr as G;
136
        match addr {
137
            G::Inet(a) => abstract_listener_on(self, a).await,
138
            G::Unix(a) => abstract_listener_on(self, a).await,
139
            other => Err(IoError::new(
140
                IoErrorKind::InvalidInput,
141
                UnsupportedAddress(other.clone()),
142
            )),
143
        }
144
    }
145
}
146

            
147
/// Tried to use a [`general::SocketAddr`] that `tor-rtcompat` didn't understand.
148
#[derive(Clone, Debug, thiserror::Error)]
149
#[error("Socket address {0:?} is not supported by tor-rtcompat")]
150
pub struct UnsupportedAddress(general::SocketAddr);