1
//! Implement the default transport, which opens TCP connections using a
//! "happy eyeballs"-style parallel algorithm (see RFC 8305).
3

            
4
use std::{net::SocketAddr, sync::Arc, time::Duration};
5

            
6
use async_trait::async_trait;
7
use futures::{stream::FuturesUnordered, FutureExt, StreamExt, TryFutureExt};
8
use safelog::sensitive as sv;
9
use tor_error::bad_api_usage;
10
use tor_linkspec::{ChannelMethod, HasChanMethod, OwnedChanTarget};
11
use tor_rtcompat::{NetStreamProvider, Runtime};
12
use tracing::trace;
13

            
14
use crate::Error;
15

            
16
/// A default transport object that opens TCP connections for a
17
/// `ChannelMethod::Direct`.
18
///
19
/// It opens almost-simultaneous parallel TCP connections to each address, and
20
/// chooses the first one to succeed.
21
#[derive(Clone, Debug)]
22
pub(crate) struct DefaultTransport<R: Runtime> {
23
    /// The runtime that we use for connecting.
24
    runtime: R,
25
}
26

            
27
impl<R: Runtime> DefaultTransport<R> {
28
    /// Construct a new DefaultTransport
29
46
    pub(crate) fn new(runtime: R) -> Self {
30
46
        Self { runtime }
31
46
    }
32
}
33

            
34
#[async_trait]
35
impl<R: Runtime> crate::transport::TransportImplHelper for DefaultTransport<R> {
36
    type Stream = <R as NetStreamProvider>::Stream;
37

            
38
    /// Implements the transport: makes a TCP connection (possibly
39
    /// tunneled over whatever protocol) if possible.
40
    async fn connect(
41
        &self,
42
        target: &OwnedChanTarget,
43
2
    ) -> crate::Result<(OwnedChanTarget, Self::Stream)> {
44
2
        let direct_addrs: Vec<_> = match target.chan_method() {
45
2
            ChannelMethod::Direct(addrs) => addrs,
46
            #[allow(unreachable_patterns)]
47
            _ => {
48
                return Err(Error::UnusableTarget(bad_api_usage!(
49
                    "Used default transport implementation for an unsupported transport."
50
                )))
51
            }
52
        };
53

            
54
2
        trace!("Launching direct connection for {}", target);
55

            
56
2
        let (stream, addr) = connect_to_one(&self.runtime, &direct_addrs).await?;
57
2
        let mut using_target = target.clone();
58
2
        let _ignore = using_target.chan_method_mut().retain_addrs(|a| a == &addr);
59
2

            
60
2
        Ok((using_target, stream))
61
4
    }
62
}
63

            
64
/// Time to wait between starting parallel connections to the same relay.
///
/// The attempt for the address at (zero-based) index `i` is not started until
/// `i` times this delay has elapsed.
///
/// This is a `const` rather than a `static`: it is a small `Copy` value that
/// never needs a single fixed address in memory.
const CONNECTION_DELAY: Duration = Duration::from_millis(150);
66

            
67
/// Connect to one of the addresses in `addrs` by running connections in parallel until one works.
68
///
69
/// This implements a basic version of RFC 8305 "happy eyeballs".
70
30
async fn connect_to_one<R: Runtime>(
71
30
    rt: &R,
72
30
    addrs: &[SocketAddr],
73
30
) -> crate::Result<(<R as NetStreamProvider>::Stream, SocketAddr)> {
74
30
    // We need *some* addresses to connect to.
75
30
    if addrs.is_empty() {
76
2
        return Err(Error::UnusableTarget(bad_api_usage!(
77
2
            "No addresses for chosen relay"
78
2
        )));
79
28
    }
80
28

            
81
28
    // Turn each address into a future that waits (i * CONNECTION_DELAY), then
82
28
    // attempts to connect to the address using the runtime (where i is the
83
28
    // array index). Shove all of these into a `FuturesUnordered`, polling them
84
28
    // simultaneously and returning the results in completion order.
85
28
    //
86
28
    // This is basically the concurrent-connection stuff from RFC 8305, ish.
87
28
    // TODO(eta): sort the addresses first?
88
28
    let mut connections = addrs
89
28
        .iter()
90
28
        .enumerate()
91
52
        .map(|(i, a)| {
92
52
            let delay = rt.sleep(CONNECTION_DELAY * i as u32);
93
52
            delay.then(move |_| {
94
40
                tracing::debug!("Connecting to {}", a);
95
40
                rt.connect(a)
96
40
                    .map_ok(move |stream| (stream, *a))
97
40
                    .map_err(move |e| (e, *a))
98
52
            })
99
52
        })
100
28
        .collect::<FuturesUnordered<_>>();
101
28

            
102
28
    let mut ret = None;
103
28
    let mut errors = vec![];
104

            
105
38
    while let Some(result) = connections.next().await {
106
30
        match result {
107
20
            Ok(s) => {
108
20
                // We got a stream (and address).
109
20
                ret = Some(s);
110
20
                break;
111
            }
112
10
            Err((e, a)) => {
113
10
                // We got a failure on one of the streams. Store the error.
114
10
                // TODO(eta): ideally we'd start the next connection attempt immediately.
115
10
                tor_error::warn_report!(e, "Connection to {} failed", sv(a));
116
10
                errors.push((e, a));
117
            }
118
        }
119
    }
120

            
121
    // Ensure we don't continue trying to make connections.
122
22
    drop(connections);
123
22

            
124
22
    ret.ok_or_else(|| Error::ChannelBuild {
125
2
        addresses: errors
126
2
            .into_iter()
127
2
            .map(|(e, a)| (sv(a), Arc::new(e)))
128
2
            .collect(),
129
22
    })
130
24
}
131

            
132
// Tests for `connect_to_one`, run against a mocked network.
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::str::FromStr;

    use tor_rtcompat::{test_with_one_runtime, SleepProviderExt};
    use tor_rtmock::net::MockNetwork;

    use super::*;

    /// Exercise `connect_to_one` against reachable, refusing, and
    /// black-holed addresses, and check which address wins a race.
    #[test]
    fn test_connect_one() {
        let client_addr = "192.0.1.16".parse().unwrap();
        // We'll put a "relay" at this address
        let addr1 = SocketAddr::from_str("192.0.2.17:443").unwrap();
        // We'll put nothing at this address, to generate errors.
        let addr2 = SocketAddr::from_str("192.0.3.18:443").unwrap();
        // We'll put a black hole at this address, to generate timeouts.
        let addr3 = SocketAddr::from_str("192.0.4.19:443").unwrap();
        // We'll put a "relay" at this address too
        let addr4 = SocketAddr::from_str("192.0.9.9:443").unwrap();

        test_with_one_runtime!(|rt| async move {
            // Stub out the internet so that this connection can work.
            let network = MockNetwork::new();

            // Set up a client and server runtime with a given IP
            let client_rt = network
                .builder()
                .add_address(client_addr)
                .runtime(rt.clone());
            let server_rt = network
                .builder()
                .add_address(addr1.ip())
                .add_address(addr4.ip())
                .runtime(rt.clone());
            // Keep the listeners alive for the whole test; dropping them
            // would make addr1/addr4 unreachable.
            let _listener = server_rt.mock_net().listen(&addr1).await.unwrap();
            let _listener2 = server_rt.mock_net().listen(&addr4).await.unwrap();
            // TODO: Because this test doesn't mock time, there will actually be
            // delays as we wait for connections to this address to time out. It
            // would be good to use MockSleepProvider instead, once we figure
            // out how to make it both reliable and convenient.
            network.add_blackhole(addr3).unwrap();

            // No addresses? Can't succeed.
            let failure = connect_to_one(&client_rt, &[]).await;
            assert!(failure.is_err());

            // Connect to a set of addresses including addr1? That's a success.
            // Failing or hanging addresses listed before addr1 only delay it.
            for addresses in [
                &[addr1][..],
                &[addr1, addr2][..],
                &[addr2, addr1][..],
                &[addr1, addr3][..],
                &[addr3, addr1][..],
                &[addr1, addr2, addr3][..],
                &[addr3, addr2, addr1][..],
            ] {
                let (_conn, addr) = connect_to_one(&client_rt, addresses).await.unwrap();
                assert_eq!(addr, addr1);
            }

            // Connect to a set of addresses including addr2 but not addr1?
            // That's an error of one kind or another.
            for addresses in [
                &[addr2][..],
                &[addr2, addr3][..],
                &[addr3, addr2][..],
                &[addr3][..],
            ] {
                // An attempt to the black hole never completes, so whenever
                // addr3 is present only the outer 300ms timeout can end the
                // call. Otherwise `connect_to_one` is expected to return its
                // own error before the timeout fires.
                let expect_timeout = addresses.contains(&addr3);
                let failure = rt
                    .timeout(
                        Duration::from_millis(300),
                        connect_to_one(&client_rt, addresses),
                    )
                    .await;
                if expect_timeout {
                    assert!(failure.is_err());
                } else {
                    assert!(failure.unwrap().is_err());
                }
            }

            // Connect to addr1 and addr4?  The first one should win.
            // Both have listeners, but the second-listed attempt does not
            // start until CONNECTION_DELAY has elapsed.
            let (_conn, addr) = connect_to_one(&client_rt, &[addr1, addr4]).await.unwrap();
            assert_eq!(addr, addr1);
            let (_conn, addr) = connect_to_one(&client_rt, &[addr4, addr1]).await.unwrap();
            assert_eq!(addr, addr4);
        });
    }
}