arti_testing/rt/
badtcp.rs

1//! Implement a NetStreamProvider that can break things.
2#![allow(clippy::missing_docs_in_private_items)] // required for pin_project(enum)
3
4use futures::Stream;
5use tor_rtcompat::{
6    NetStreamListener, NetStreamProvider, NoOpStreamOpsHandle, SleepProvider, StreamOps,
7};
8
9use anyhow::anyhow;
10use async_trait::async_trait;
11use futures::io::{AsyncRead, AsyncWrite};
12use pin_project::pin_project;
13use std::io::{Error as IoError, ErrorKind as IoErrorKind, Result as IoResult};
14use std::net::SocketAddr;
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::{Arc, Mutex};
18use std::task::{Context, Poll};
19use std::time::Duration;
20use tor_basic_utils::RngExt as _;
21
/// What to do when asked to open an outbound TCP connection.
#[derive(Clone, Copy, Debug)]
pub(crate) enum Action {
    /// Behave normally: really connect through the wrapped provider.
    Work,
    /// Sleep for a random delay no longer than the given bound, then fail with
    /// an I/O error of the given kind.
    Fail(Duration, IoErrorKind),
    /// Never finish the connection attempt at all.
    Timeout,
    /// Report a successful connection whose stream silently discards everything.
    Blackhole,
}
34
/// A condition on the destination address, selecting which connection
/// attempts an [`Action`] is applied to.
#[derive(Clone, Debug)]
pub(crate) enum ActionPat {
    /// Matches every destination.
    Always,
    /// Matches only IPv4 destinations.
    V4,
    /// Matches only IPv6 destinations.
    V6,
    /// Matches every destination whose port is anything other than 443.
    Non443,
}
47
48/// An Action plus a set of conditions when it applies.
49///
50/// (When the action doesn't apply, connections will just `Action::Work`.
51#[derive(Debug, Clone)]
52pub(crate) struct ConditionalAction {
53    /// The underlying action
54    pub(crate) action: Action,
55
56    /// When should the action apply?
57    pub(crate) when: ActionPat,
58}
59
60impl FromStr for Action {
61    type Err = anyhow::Error;
62
63    fn from_str(s: &str) -> Result<Self, Self::Err> {
64        Ok(match s {
65            "none" | "work" => Action::Work,
66            "error" => Action::Fail(Duration::from_millis(10), IoErrorKind::Other),
67            "timeout" => Action::Timeout,
68            "blackhole" => Action::Blackhole,
69            _ => return Err(anyhow!("unrecognized tcp breakage action {:?}", s)),
70        })
71    }
72}
73
74impl FromStr for ActionPat {
75    type Err = anyhow::Error;
76
77    fn from_str(s: &str) -> Result<Self, Self::Err> {
78        Ok(match s {
79            "all" => ActionPat::Always,
80            "v4" => ActionPat::V4,
81            "v6" => ActionPat::V6,
82            "non443" => ActionPat::Non443,
83            _ => return Err(anyhow!("unrecognized tcp breakage condition {:?}", s)),
84        })
85    }
86}
87
88impl ConditionalAction {
89    fn applies_to(&self, addr: &SocketAddr) -> bool {
90        match (addr, &self.when) {
91            (_, ActionPat::Always) => true,
92            (SocketAddr::V4(_), ActionPat::V4) => true,
93            (SocketAddr::V6(_), ActionPat::V6) => true,
94            (sa, ActionPat::Non443) if sa.port() != 443 => true,
95            (_, _) => false,
96        }
97    }
98}
99
100impl Default for ConditionalAction {
101    fn default() -> Self {
102        Self {
103            action: Action::Work,
104            when: ActionPat::Always,
105        }
106    }
107}
108
109/// A NetStreamProvider that can make its connections fail.
110#[pin_project]
111#[derive(Debug, Clone)]
112pub(crate) struct BrokenTcpProvider<R> {
113    /// An underlying NetStreamProvider to use when we actually want our connections to succeed
114    #[pin]
115    inner: R,
116    /// The action to take when we try to make an outbound connection.
117    action: Arc<Mutex<ConditionalAction>>,
118}
119
120impl<R> BrokenTcpProvider<R> {
121    /// Construct a new BrokenTcpProvider which responds to all outbound
122    /// connections by taking the specified action.
123    pub(crate) fn new(inner: R, action: ConditionalAction) -> Self {
124        Self {
125            inner,
126            action: Arc::new(Mutex::new(action)),
127        }
128    }
129
130    /// Cause the provider to respond to all outbound connection attempts
131    /// with the specified action.
132    pub(crate) fn set_action(&self, action: ConditionalAction) {
133        *self.action.lock().expect("Lock poisoned") = action;
134    }
135
136    /// Return the action to take for a connection to `addr`.
137    fn get_action(&self, addr: &SocketAddr) -> Action {
138        let action = self.action.lock().expect("Lock poisoned");
139        if action.applies_to(addr) {
140            action.action
141        } else {
142            Action::Work
143        }
144    }
145}
146
147#[async_trait]
148impl<R: NetStreamProvider + SleepProvider> NetStreamProvider for BrokenTcpProvider<R> {
149    type Stream = BreakableTcpStream<R::Stream>;
150    type Listener = BrokenTcpProvider<R::Listener>;
151
152    async fn connect(&self, addr: &SocketAddr) -> IoResult<Self::Stream> {
153        match self.get_action(addr) {
154            Action::Work => {
155                let conn = self.inner.connect(addr).await?;
156                Ok(BreakableTcpStream::Present(conn))
157            }
158            Action::Fail(dur, kind) => {
159                let d = rand::rng().gen_range_infallible(..=dur);
160                self.inner.sleep(d).await;
161                Err(IoError::new(kind, anyhow::anyhow!("intentional failure")))
162            }
163            Action::Timeout => futures::future::pending().await,
164            Action::Blackhole => Ok(BreakableTcpStream::Broken),
165        }
166    }
167
168    async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::Listener> {
169        let listener = self.inner.listen(addr).await?;
170        Ok(BrokenTcpProvider {
171            inner: listener,
172            action: self.action.clone(),
173        })
174    }
175}
176
177/// A TCP stream that is either present, or black-holed.
178#[pin_project(project = BreakableTcpStreamP)]
179#[derive(Debug, Clone)]
180pub(crate) enum BreakableTcpStream<S> {
181    /// The stream is black-holed: there is nothing to read, and all writes
182    /// succeed but are ignored.
183    Broken,
184
185    /// The stream is present and should work normally.
186    Present(#[pin] S),
187}
188
189impl<S: AsyncRead> AsyncRead for BreakableTcpStream<S> {
190    fn poll_read(
191        self: Pin<&mut Self>,
192        cx: &mut Context<'_>,
193        buf: &mut [u8],
194    ) -> Poll<IoResult<usize>> {
195        let this = self.project();
196        match this {
197            BreakableTcpStreamP::Present(s) => s.poll_read(cx, buf),
198            BreakableTcpStreamP::Broken => Poll::Pending,
199        }
200    }
201}
202
203impl<S: AsyncWrite> AsyncWrite for BreakableTcpStream<S> {
204    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
205        match self.project() {
206            BreakableTcpStreamP::Present(s) => s.poll_write(cx, buf),
207            BreakableTcpStreamP::Broken => Poll::Ready(Ok(buf.len())),
208        }
209    }
210    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
211        match self.project() {
212            BreakableTcpStreamP::Present(s) => s.poll_flush(cx),
213            BreakableTcpStreamP::Broken => Poll::Ready(Ok(())),
214        }
215    }
216    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
217        match self.project() {
218            BreakableTcpStreamP::Present(s) => s.poll_close(cx),
219            BreakableTcpStreamP::Broken => Poll::Ready(Ok(())),
220        }
221    }
222}
223
224impl<S: StreamOps> StreamOps for BreakableTcpStream<S> {
225    fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
226        match self {
227            BreakableTcpStream::Broken => Ok(()),
228            BreakableTcpStream::Present(s) => s.set_tcp_notsent_lowat(notsent_lowat),
229        }
230    }
231
232    fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
233        match self {
234            BreakableTcpStream::Broken => Box::new(NoOpStreamOpsHandle::default()),
235            BreakableTcpStream::Present(s) => s.new_handle(),
236        }
237    }
238}
239
240impl<S: NetStreamListener + Send + Sync> NetStreamListener for BrokenTcpProvider<S> {
241    type Stream = BreakableTcpStream<S::Stream>;
242    type Incoming = BrokenTcpProvider<S::Incoming>;
243
244    fn incoming(self) -> Self::Incoming {
245        BrokenTcpProvider {
246            inner: self.inner.incoming(),
247            action: self.action,
248        }
249    }
250
251    fn local_addr(&self) -> IoResult<SocketAddr> {
252        self.inner.local_addr()
253    }
254}
255impl<S, T> Stream for BrokenTcpProvider<S>
256where
257    S: Stream<Item = IoResult<(T, SocketAddr)>>,
258{
259    type Item = IoResult<(BreakableTcpStream<T>, SocketAddr)>;
260
261    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
262        match self.project().inner.poll_next(cx) {
263            Poll::Pending => Poll::Pending,
264            Poll::Ready(None) => Poll::Ready(None),
265            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
266            Poll::Ready(Some(Ok((s, a)))) => {
267                Poll::Ready(Some(Ok((BreakableTcpStream::Present(s), a))))
268            }
269        }
270    }
271}