1#![allow(clippy::missing_docs_in_private_items)] use futures::Stream;
5use tor_rtcompat::{
6 NetStreamListener, NetStreamProvider, NoOpStreamOpsHandle, SleepProvider, StreamOps,
7};
8
9use anyhow::anyhow;
10use async_trait::async_trait;
11use futures::io::{AsyncRead, AsyncWrite};
12use pin_project::pin_project;
13use std::io::{Error as IoError, ErrorKind as IoErrorKind, Result as IoResult};
14use std::net::SocketAddr;
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::{Arc, Mutex};
18use std::task::{Context, Poll};
19use std::time::Duration;
20use tor_basic_utils::RngExt as _;
21
/// What to do when somebody asks for a new TCP connection.
#[derive(Clone, Copy, Debug)]
pub(crate) enum Action {
    /// Hand the attempt to the wrapped provider and return its result.
    Work,
    /// Sleep for a random interval no longer than the given duration, then
    /// fail with an error of the given kind.
    Fail(Duration, IoErrorKind),
    /// Never finish the connection attempt.
    Timeout,
    /// Report success, but hand back a stream that accepts and discards
    /// every write and never produces anything to read.
    Blackhole,
}
34
/// Selects the socket addresses that a breakage action is applied to.
#[derive(Clone, Debug)]
pub(crate) enum ActionPat {
    /// Every address matches.
    Always,
    /// Only IPv4 addresses match.
    V4,
    /// Only IPv6 addresses match.
    V6,
    /// Only addresses whose port is something other than 443 match.
    Non443,
}
47
48#[derive(Debug, Clone)]
52pub(crate) struct ConditionalAction {
53 pub(crate) action: Action,
55
56 pub(crate) when: ActionPat,
58}
59
60impl FromStr for Action {
61 type Err = anyhow::Error;
62
63 fn from_str(s: &str) -> Result<Self, Self::Err> {
64 Ok(match s {
65 "none" | "work" => Action::Work,
66 "error" => Action::Fail(Duration::from_millis(10), IoErrorKind::Other),
67 "timeout" => Action::Timeout,
68 "blackhole" => Action::Blackhole,
69 _ => return Err(anyhow!("unrecognized tcp breakage action {:?}", s)),
70 })
71 }
72}
73
74impl FromStr for ActionPat {
75 type Err = anyhow::Error;
76
77 fn from_str(s: &str) -> Result<Self, Self::Err> {
78 Ok(match s {
79 "all" => ActionPat::Always,
80 "v4" => ActionPat::V4,
81 "v6" => ActionPat::V6,
82 "non443" => ActionPat::Non443,
83 _ => return Err(anyhow!("unrecognized tcp breakage condition {:?}", s)),
84 })
85 }
86}
87
88impl ConditionalAction {
89 fn applies_to(&self, addr: &SocketAddr) -> bool {
90 match (addr, &self.when) {
91 (_, ActionPat::Always) => true,
92 (SocketAddr::V4(_), ActionPat::V4) => true,
93 (SocketAddr::V6(_), ActionPat::V6) => true,
94 (sa, ActionPat::Non443) if sa.port() != 443 => true,
95 (_, _) => false,
96 }
97 }
98}
99
100impl Default for ConditionalAction {
101 fn default() -> Self {
102 Self {
103 action: Action::Work,
104 when: ActionPat::Always,
105 }
106 }
107}
108
109#[pin_project]
111#[derive(Debug, Clone)]
112pub(crate) struct BrokenTcpProvider<R> {
113 #[pin]
115 inner: R,
116 action: Arc<Mutex<ConditionalAction>>,
118}
119
120impl<R> BrokenTcpProvider<R> {
121 pub(crate) fn new(inner: R, action: ConditionalAction) -> Self {
124 Self {
125 inner,
126 action: Arc::new(Mutex::new(action)),
127 }
128 }
129
130 pub(crate) fn set_action(&self, action: ConditionalAction) {
133 *self.action.lock().expect("Lock poisoned") = action;
134 }
135
136 fn get_action(&self, addr: &SocketAddr) -> Action {
138 let action = self.action.lock().expect("Lock poisoned");
139 if action.applies_to(addr) {
140 action.action
141 } else {
142 Action::Work
143 }
144 }
145}
146
147#[async_trait]
148impl<R: NetStreamProvider + SleepProvider> NetStreamProvider for BrokenTcpProvider<R> {
149 type Stream = BreakableTcpStream<R::Stream>;
150 type Listener = BrokenTcpProvider<R::Listener>;
151
152 async fn connect(&self, addr: &SocketAddr) -> IoResult<Self::Stream> {
153 match self.get_action(addr) {
154 Action::Work => {
155 let conn = self.inner.connect(addr).await?;
156 Ok(BreakableTcpStream::Present(conn))
157 }
158 Action::Fail(dur, kind) => {
159 let d = rand::rng().gen_range_infallible(..=dur);
160 self.inner.sleep(d).await;
161 Err(IoError::new(kind, anyhow::anyhow!("intentional failure")))
162 }
163 Action::Timeout => futures::future::pending().await,
164 Action::Blackhole => Ok(BreakableTcpStream::Broken),
165 }
166 }
167
168 async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::Listener> {
169 let listener = self.inner.listen(addr).await?;
170 Ok(BrokenTcpProvider {
171 inner: listener,
172 action: self.action.clone(),
173 })
174 }
175}
176
177#[pin_project(project = BreakableTcpStreamP)]
179#[derive(Debug, Clone)]
180pub(crate) enum BreakableTcpStream<S> {
181 Broken,
184
185 Present(#[pin] S),
187}
188
189impl<S: AsyncRead> AsyncRead for BreakableTcpStream<S> {
190 fn poll_read(
191 self: Pin<&mut Self>,
192 cx: &mut Context<'_>,
193 buf: &mut [u8],
194 ) -> Poll<IoResult<usize>> {
195 let this = self.project();
196 match this {
197 BreakableTcpStreamP::Present(s) => s.poll_read(cx, buf),
198 BreakableTcpStreamP::Broken => Poll::Pending,
199 }
200 }
201}
202
203impl<S: AsyncWrite> AsyncWrite for BreakableTcpStream<S> {
204 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
205 match self.project() {
206 BreakableTcpStreamP::Present(s) => s.poll_write(cx, buf),
207 BreakableTcpStreamP::Broken => Poll::Ready(Ok(buf.len())),
208 }
209 }
210 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
211 match self.project() {
212 BreakableTcpStreamP::Present(s) => s.poll_flush(cx),
213 BreakableTcpStreamP::Broken => Poll::Ready(Ok(())),
214 }
215 }
216 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
217 match self.project() {
218 BreakableTcpStreamP::Present(s) => s.poll_close(cx),
219 BreakableTcpStreamP::Broken => Poll::Ready(Ok(())),
220 }
221 }
222}
223
224impl<S: StreamOps> StreamOps for BreakableTcpStream<S> {
225 fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
226 match self {
227 BreakableTcpStream::Broken => Ok(()),
228 BreakableTcpStream::Present(s) => s.set_tcp_notsent_lowat(notsent_lowat),
229 }
230 }
231
232 fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
233 match self {
234 BreakableTcpStream::Broken => Box::new(NoOpStreamOpsHandle::default()),
235 BreakableTcpStream::Present(s) => s.new_handle(),
236 }
237 }
238}
239
240impl<S: NetStreamListener + Send + Sync> NetStreamListener for BrokenTcpProvider<S> {
241 type Stream = BreakableTcpStream<S::Stream>;
242 type Incoming = BrokenTcpProvider<S::Incoming>;
243
244 fn incoming(self) -> Self::Incoming {
245 BrokenTcpProvider {
246 inner: self.inner.incoming(),
247 action: self.action,
248 }
249 }
250
251 fn local_addr(&self) -> IoResult<SocketAddr> {
252 self.inner.local_addr()
253 }
254}
255impl<S, T> Stream for BrokenTcpProvider<S>
256where
257 S: Stream<Item = IoResult<(T, SocketAddr)>>,
258{
259 type Item = IoResult<(BreakableTcpStream<T>, SocketAddr)>;
260
261 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
262 match self.project().inner.poll_next(cx) {
263 Poll::Pending => Poll::Pending,
264 Poll::Ready(None) => Poll::Ready(None),
265 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
266 Poll::Ready(Some(Ok((s, a)))) => {
267 Poll::Ready(Some(Ok((BreakableTcpStream::Present(s), a))))
268 }
269 }
270 }
271}