1//! Implement the default transport, which opens TCP connections using a
2//! happy-eyeballs style parallel algorithm.
34use std::{net::SocketAddr, sync::Arc, time::Duration};
56use async_trait::async_trait;
7use futures::{stream::FuturesUnordered, FutureExt, StreamExt, TryFutureExt};
8use safelog::sensitive as sv;
9use tor_error::bad_api_usage;
10use tor_linkspec::{ChannelMethod, HasChanMethod, OwnedChanTarget};
11use tor_rtcompat::{NetStreamProvider, Runtime};
12use tracing::trace;
1314use crate::Error;
1516/// A default transport object that opens TCP connections for a
17/// `ChannelMethod::Direct`.
18///
19/// It opens almost-simultaneous parallel TCP connections to each address, and
20/// chooses the first one to succeed.
21#[derive(Clone, Debug)]
22pub(crate) struct DefaultTransport<R: Runtime> {
23/// The runtime that we use for connecting.
24runtime: R,
25}
2627impl<R: Runtime> DefaultTransport<R> {
28/// Construct a new DefaultTransport
29pub(crate) fn new(runtime: R) -> Self {
30Self { runtime }
31 }
32}
3334#[async_trait]
35impl<R: Runtime> crate::transport::TransportImplHelper for DefaultTransport<R> {
36type Stream = <R as NetStreamProvider>::Stream;
3738/// Implements the transport: makes a TCP connection (possibly
39 /// tunneled over whatever protocol) if possible.
40async fn connect(
41&self,
42 target: &OwnedChanTarget,
43 ) -> crate::Result<(OwnedChanTarget, Self::Stream)> {
44let direct_addrs: Vec<_> = match target.chan_method() {
45 ChannelMethod::Direct(addrs) => addrs,
46#[allow(unreachable_patterns)]
47_ => {
48return Err(Error::UnusableTarget(bad_api_usage!(
49"Used default transport implementation for an unsupported transport."
50)))
51 }
52 };
5354trace!("Launching direct connection for {}", target);
5556let (stream, addr) = connect_to_one(&self.runtime, &direct_addrs).await?;
57let mut using_target = target.clone();
58let _ignore = using_target.chan_method_mut().retain_addrs(|a| a == &addr);
5960Ok((using_target, stream))
61 }
62}
/// Time to wait between starting parallel connections to the same relay.
///
/// The attempt for the `i`th address is started after `i * CONNECTION_DELAY`.
/// This is a `const` rather than a `static`: it is a plain `Copy` value that
/// never needs a single fixed memory address.
const CONNECTION_DELAY: Duration = Duration::from_millis(150);
6667/// Connect to one of the addresses in `addrs` by running connections in parallel until one works.
68///
69/// This implements a basic version of RFC 8305 "happy eyeballs".
70async fn connect_to_one<R: Runtime>(
71 rt: &R,
72 addrs: &[SocketAddr],
73) -> crate::Result<(<R as NetStreamProvider>::Stream, SocketAddr)> {
74// We need *some* addresses to connect to.
75if addrs.is_empty() {
76return Err(Error::UnusableTarget(bad_api_usage!(
77"No addresses for chosen relay"
78)));
79 }
8081// Turn each address into a future that waits (i * CONNECTION_DELAY), then
82 // attempts to connect to the address using the runtime (where i is the
83 // array index). Shove all of these into a `FuturesUnordered`, polling them
84 // simultaneously and returning the results in completion order.
85 //
86 // This is basically the concurrent-connection stuff from RFC 8305, ish.
87 // TODO(eta): sort the addresses first?
88let mut connections = addrs
89 .iter()
90 .enumerate()
91 .map(|(i, a)| {
92let delay = rt.sleep(CONNECTION_DELAY * i as u32);
93 delay.then(move |_| {
94tracing::debug!("Connecting to {}", a);
95 rt.connect(a)
96 .map_ok(move |stream| (stream, *a))
97 .map_err(move |e| (e, *a))
98 })
99 })
100 .collect::<FuturesUnordered<_>>();
101102let mut ret = None;
103let mut errors = vec![];
104105while let Some(result) = connections.next().await {
106match result {
107Ok(s) => {
108// We got a stream (and address).
109ret = Some(s);
110break;
111 }
112Err((e, a)) => {
113// We got a failure on one of the streams. Store the error.
114 // TODO(eta): ideally we'd start the next connection attempt immediately.
115tor_error::warn_report!(e, "Connection to {} failed", sv(a));
116 errors.push((e, a));
117 }
118 }
119 }
120121// Ensure we don't continue trying to make connections.
122drop(connections);
123124 ret.ok_or_else(|| Error::ChannelBuild {
125 addresses: errors
126 .into_iter()
127 .map(|(e, a)| (sv(a), Arc::new(e)))
128 .collect(),
129 })
130}
// Tests for `connect_to_one`, run against a `MockNetwork` rather than real
// sockets.
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::str::FromStr;

    use tor_rtcompat::{test_with_one_runtime, SleepProviderExt};
    use tor_rtmock::net::MockNetwork;

    use super::*;

    /// Exercise `connect_to_one` against a mix of listening, empty, and
    /// black-holed addresses on a mock network.
    #[test]
    fn test_connect_one() {
        let client_addr = "192.0.1.16".parse().unwrap();
        // We'll put a "relay" at this address
        let addr1 = SocketAddr::from_str("192.0.2.17:443").unwrap();
        // We'll put nothing at this address, to generate errors.
        let addr2 = SocketAddr::from_str("192.0.3.18:443").unwrap();
        // We'll put a black hole at this address, to generate timeouts.
        let addr3 = SocketAddr::from_str("192.0.4.19:443").unwrap();
        // We'll put a "relay" at this address too
        let addr4 = SocketAddr::from_str("192.0.9.9:443").unwrap();

        test_with_one_runtime!(|rt| async move {
            // Stub out the internet so that this connection can work.
            let network = MockNetwork::new();

            // Set up a client and server runtime with a given IP
            let client_rt = network
                .builder()
                .add_address(client_addr)
                .runtime(rt.clone());
            let server_rt = network
                .builder()
                .add_address(addr1.ip())
                .add_address(addr4.ip())
                .runtime(rt.clone());
            let _listener = server_rt.mock_net().listen(&addr1).await.unwrap();
            let _listener2 = server_rt.mock_net().listen(&addr4).await.unwrap();
            // TODO: Because this test doesn't mock time, there will actually be
            // delays as we wait for connections to this address to time out. It
            // would be good to use MockSleepProvider instead, once we figure
            // out how to make it both reliable and convenient.
            network.add_blackhole(addr3).unwrap();

            // No addresses? Can't succeed.
            let failure = connect_to_one(&client_rt, &[]).await;
            assert!(failure.is_err());

            // Connect to a set of addresses including addr1? That's a success.
            for addresses in [
                &[addr1][..],
                &[addr1, addr2][..],
                &[addr2, addr1][..],
                &[addr1, addr3][..],
                &[addr3, addr1][..],
                &[addr1, addr2, addr3][..],
                &[addr3, addr2, addr1][..],
            ] {
                let (_conn, addr) = connect_to_one(&client_rt, addresses).await.unwrap();
                assert_eq!(addr, addr1);
            }

            // Connect to a set of addresses that doesn't include addr1 (only
            // addr2 and/or addr3)? That's an error of one kind or another.
            for addresses in [
                &[addr2][..],
                &[addr2, addr3][..],
                &[addr3, addr2][..],
                &[addr3][..],
            ] {
                // When the black hole (addr3) is in the set, connect_to_one is
                // expected not to finish within 300ms, so the outer timeout
                // fires. Otherwise every attempt fails and connect_to_one
                // itself returns an error.
                let expect_timeout = addresses.contains(&addr3);
                let failure = rt
                    .timeout(
                        Duration::from_millis(300),
                        connect_to_one(&client_rt, addresses),
                    )
                    .await;
                if expect_timeout {
                    assert!(failure.is_err());
                } else {
                    assert!(failure.unwrap().is_err());
                }
            }

            // Connect to addr1 and addr4? The first one should win.
            // Both addresses accept connections; the first-listed one is
            // attempted with no delay and is the one expected to succeed.
            let (_conn, addr) = connect_to_one(&client_rt, &[addr1, addr4]).await.unwrap();
            assert_eq!(addr, addr1);
            let (_conn, addr) = connect_to_one(&client_rt, &[addr4, addr1]).await.unwrap();
            assert_eq!(addr, addr4);
        });
    }
}