tor_chanmgr/transport/
default.rs

//! Implement the default transport, which opens TCP connections using a
//! happy-eyeballs-style parallel algorithm.
3
4use std::{net::SocketAddr, sync::Arc, time::Duration};
5
6use async_trait::async_trait;
7use futures::{stream::FuturesUnordered, FutureExt, StreamExt, TryFutureExt};
8use safelog::sensitive as sv;
9use tor_error::bad_api_usage;
10use tor_linkspec::{ChannelMethod, HasChanMethod, OwnedChanTarget};
11use tor_rtcompat::{NetStreamProvider, Runtime};
12use tracing::trace;
13
14use crate::Error;
15
16/// A default transport object that opens TCP connections for a
17/// `ChannelMethod::Direct`.
18///
19/// It opens almost-simultaneous parallel TCP connections to each address, and
20/// chooses the first one to succeed.
21#[derive(Clone, Debug)]
22pub(crate) struct DefaultTransport<R: Runtime> {
23    /// The runtime that we use for connecting.
24    runtime: R,
25}
26
27impl<R: Runtime> DefaultTransport<R> {
28    /// Construct a new DefaultTransport
29    pub(crate) fn new(runtime: R) -> Self {
30        Self { runtime }
31    }
32}
33
34#[async_trait]
35impl<R: Runtime> crate::transport::TransportImplHelper for DefaultTransport<R> {
36    type Stream = <R as NetStreamProvider>::Stream;
37
38    /// Implements the transport: makes a TCP connection (possibly
39    /// tunneled over whatever protocol) if possible.
40    async fn connect(
41        &self,
42        target: &OwnedChanTarget,
43    ) -> crate::Result<(OwnedChanTarget, Self::Stream)> {
44        let direct_addrs: Vec<_> = match target.chan_method() {
45            ChannelMethod::Direct(addrs) => addrs,
46            #[allow(unreachable_patterns)]
47            _ => {
48                return Err(Error::UnusableTarget(bad_api_usage!(
49                    "Used default transport implementation for an unsupported transport."
50                )))
51            }
52        };
53
54        trace!("Launching direct connection for {}", target);
55
56        let (stream, addr) = connect_to_one(&self.runtime, &direct_addrs).await?;
57        let mut using_target = target.clone();
58        let _ignore = using_target.chan_method_mut().retain_addrs(|a| a == &addr);
59
60        Ok((using_target, stream))
61    }
62}
63
/// Time to wait between starting parallel connections to the same relay.
///
/// The attempt to the `i`th address (counting from zero) is delayed by
/// `i * CONNECTION_DELAY`.
const CONNECTION_DELAY: Duration = Duration::from_millis(150);
66
67/// Connect to one of the addresses in `addrs` by running connections in parallel until one works.
68///
69/// This implements a basic version of RFC 8305 "happy eyeballs".
70async fn connect_to_one<R: Runtime>(
71    rt: &R,
72    addrs: &[SocketAddr],
73) -> crate::Result<(<R as NetStreamProvider>::Stream, SocketAddr)> {
74    // We need *some* addresses to connect to.
75    if addrs.is_empty() {
76        return Err(Error::UnusableTarget(bad_api_usage!(
77            "No addresses for chosen relay"
78        )));
79    }
80
81    // Turn each address into a future that waits (i * CONNECTION_DELAY), then
82    // attempts to connect to the address using the runtime (where i is the
83    // array index). Shove all of these into a `FuturesUnordered`, polling them
84    // simultaneously and returning the results in completion order.
85    //
86    // This is basically the concurrent-connection stuff from RFC 8305, ish.
87    // TODO(eta): sort the addresses first?
88    let mut connections = addrs
89        .iter()
90        .enumerate()
91        .map(|(i, a)| {
92            let delay = rt.sleep(CONNECTION_DELAY * i as u32);
93            delay.then(move |_| {
94                tracing::debug!("Connecting to {}", a);
95                rt.connect(a)
96                    .map_ok(move |stream| (stream, *a))
97                    .map_err(move |e| (e, *a))
98            })
99        })
100        .collect::<FuturesUnordered<_>>();
101
102    let mut ret = None;
103    let mut errors = vec![];
104
105    while let Some(result) = connections.next().await {
106        match result {
107            Ok(s) => {
108                // We got a stream (and address).
109                ret = Some(s);
110                break;
111            }
112            Err((e, a)) => {
113                // We got a failure on one of the streams. Store the error.
114                // TODO(eta): ideally we'd start the next connection attempt immediately.
115                tor_error::warn_report!(e, "Connection to {} failed", sv(a));
116                errors.push((e, a));
117            }
118        }
119    }
120
121    // Ensure we don't continue trying to make connections.
122    drop(connections);
123
124    ret.ok_or_else(|| Error::ChannelBuild {
125        addresses: errors
126            .into_iter()
127            .map(|(e, a)| (sv(a), Arc::new(e)))
128            .collect(),
129    })
130}
131
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::str::FromStr;

    use tor_rtcompat::{test_with_one_runtime, SleepProviderExt};
    use tor_rtmock::net::MockNetwork;

    use super::*;

    /// Drive `connect_to_one` against a mock network that holds two listening
    /// "relays", one address with nothing behind it, and one black hole.
    #[test]
    fn test_connect_one() {
        let client_ip = "192.0.1.16".parse().unwrap();
        // A listening "relay" will live here.
        let relay_a = SocketAddr::from_str("192.0.2.17:443").unwrap();
        // Nothing will live here, so connecting produces an error.
        let refused = SocketAddr::from_str("192.0.3.18:443").unwrap();
        // We'll put a black hole here, so connecting times out.
        let blackhole = SocketAddr::from_str("192.0.4.19:443").unwrap();
        // A second listening "relay" will live here.
        let relay_b = SocketAddr::from_str("192.0.9.9:443").unwrap();

        test_with_one_runtime!(|rt| async move {
            // A fake internet shared by the client and the "relays".
            let network = MockNetwork::new();

            // One runtime for the client, and one server-side runtime that
            // owns both relay addresses.
            let client_rt = network
                .builder()
                .add_address(client_ip)
                .runtime(rt.clone());
            let server_rt = network
                .builder()
                .add_address(relay_a.ip())
                .add_address(relay_b.ip())
                .runtime(rt.clone());
            let _listener_a = server_rt.mock_net().listen(&relay_a).await.unwrap();
            let _listener_b = server_rt.mock_net().listen(&relay_b).await.unwrap();
            // TODO: Because this test doesn't mock time, there will actually be
            // delays as we wait for connections to this address to time out. It
            // would be good to use MockSleepProvider instead, once we figure
            // out how to make it both reliable and convenient.
            network.add_blackhole(blackhole).unwrap();

            // An empty address list can never succeed.
            let empty_outcome = connect_to_one(&client_rt, &[]).await;
            assert!(empty_outcome.is_err());

            // Any list containing `relay_a` (but not `relay_b`) must end up
            // connected to `relay_a`, whatever else is in it.
            for addresses in [
                &[relay_a][..],
                &[relay_a, refused][..],
                &[refused, relay_a][..],
                &[relay_a, blackhole][..],
                &[blackhole, relay_a][..],
                &[relay_a, refused, blackhole][..],
                &[blackhole, refused, relay_a][..],
            ] {
                let (_conn, chosen) = connect_to_one(&client_rt, addresses).await.unwrap();
                assert_eq!(chosen, relay_a);
            }

            // Without any relay, every attempt must fail. If the black hole is
            // involved, our outer timeout fires first. Otherwise the connection
            // attempt itself reports an error.
            for addresses in [
                &[refused][..],
                &[refused, blackhole][..],
                &[blackhole, refused][..],
                &[blackhole][..],
            ] {
                let hits_blackhole = addresses.contains(&blackhole);
                let outcome = rt
                    .timeout(
                        Duration::from_millis(300),
                        connect_to_one(&client_rt, addresses),
                    )
                    .await;
                if hits_blackhole {
                    assert!(outcome.is_err());
                } else {
                    let inner = outcome.unwrap();
                    assert!(inner.is_err());
                }
            }

            // With both relays reachable, the earlier address in the list wins,
            // since its attempt is started first.
            let (_conn_ab, winner_ab) = connect_to_one(&client_rt, &[relay_a, relay_b])
                .await
                .unwrap();
            assert_eq!(winner_ab, relay_a);
            let (_conn_ba, winner_ba) = connect_to_one(&client_rt, &[relay_b, relay_a])
                .await
                .unwrap();
            assert_eq!(winner_ba, relay_b);
        });
    }
}