1use std::sync::{Arc, Mutex};
4
5use futures::{
6 select_biased, task::SpawnExt as _, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, Future,
7 FutureExt as _, Stream, StreamExt as _,
8};
9use itertools::iproduct;
10use oneshot_fused_workaround as oneshot;
11use safelog::sensitive as sv;
12use std::collections::HashMap;
13use std::io::{Error as IoError, Result as IoResult};
14use strum::IntoEnumIterator;
15use tor_cell::relaycell::msg as relaymsg;
16use tor_error::{debug_report, ErrorKind, HasKind};
17use tor_hsservice::{HsNickname, RendRequest, StreamRequest};
18use tor_log_ratelim::log_ratelim;
19use tor_proto::stream::{DataStream, IncomingStreamRequest};
20use tor_rtcompat::Runtime;
21
22use crate::config::{
23 Encapsulation, ProxyAction, ProxyActionDiscriminants, ProxyConfig, TargetAddr,
24};
25
/// A reverse proxy that decides what to do with each stream request arriving
/// at an onion service, according to a [`ProxyConfig`].
///
/// Created with [`OnionServiceReverseProxy::new`], and driven by
/// [`OnionServiceReverseProxy::handle_requests`].
#[derive(Debug)]
pub struct OnionServiceReverseProxy {
    /// Mutable state: the current configuration and the shutdown channel.
    state: Mutex<State>,
}
33
/// The mutable state of an [`OnionServiceReverseProxy`], kept behind its mutex.
#[derive(Debug)]
struct State {
    /// The configuration used to choose an action for each stream request.
    config: ProxyConfig,
    /// Sender half of the shutdown channel.
    ///
    /// `shutdown` takes and drops it; `None` means shutdown was requested.
    shutdown_tx: Option<oneshot::Sender<void::Void>>,
    /// Receiver half of the shutdown channel, shared so that every
    /// `handle_requests` loop can wait on its own clone.
    ///
    /// Since `void::Void` has no values, nothing is ever sent on the channel:
    /// the receiver only completes once the sender is dropped.
    shutdown_rx: futures::future::Shared<oneshot::Receiver<void::Void>>,
}
44
/// An error that stops [`OnionServiceReverseProxy::handle_requests`] from
/// handling any more requests.
#[derive(Clone, Debug, thiserror::Error)]
#[non_exhaustive]
pub enum HandleRequestsError {
    /// We could not spawn the task that handles a stream request.
    #[error("Unable to spawn a task")]
    Spawn(#[source] Arc<futures::task::SpawnError>),
}
53
54impl HasKind for HandleRequestsError {
55 fn kind(&self) -> ErrorKind {
56 match self {
57 HandleRequestsError::Spawn(e) => e.kind(),
58 }
59 }
60}
61
62impl OnionServiceReverseProxy {
63 pub fn new(config: ProxyConfig) -> Arc<Self> {
65 let (shutdown_tx, shutdown_rx) = oneshot::channel();
66 Arc::new(Self {
67 state: Mutex::new(State {
68 config,
69 shutdown_tx: Some(shutdown_tx),
70 shutdown_rx: shutdown_rx.shared(),
71 }),
72 })
73 }
74
75 pub fn reconfigure(
80 &self,
81 config: ProxyConfig,
82 how: tor_config::Reconfigure,
83 ) -> Result<(), tor_config::ReconfigureError> {
84 if how == tor_config::Reconfigure::CheckAllOrNothing {
85 return Ok(());
87 }
88 let mut state = self.state.lock().expect("poisoned lock");
89 state.config = config;
90 Ok(())
95 }
96
97 pub fn shutdown(&self) {
99 let mut state = self.state.lock().expect("poisoned lock");
100 let _ = state.shutdown_tx.take();
101 }
102
103 pub async fn handle_requests<R, S>(
110 &self,
111 runtime: R,
112 nickname: HsNickname,
113 requests: S,
114 ) -> Result<(), HandleRequestsError>
115 where
116 R: Runtime,
117 S: Stream<Item = RendRequest> + Unpin,
118 {
119 let mut stream_requests = tor_hsservice::handle_rend_requests(requests).fuse();
120 let mut shutdown_rx = self
121 .state
122 .lock()
123 .expect("poisoned lock")
124 .shutdown_rx
125 .clone()
126 .fuse();
127 let nickname = Arc::new(nickname);
128
129 #[cfg(feature = "metrics")]
131 #[derive(Clone, Copy, Eq, PartialEq, Hash)]
132 enum CounterSelector {
133 Ret(Result<(), ()>),
135 Total,
137 }
138
139 #[cfg(feature = "metrics")]
140 let metrics_counters = {
141 use CounterSelector as CS;
142
143 let counters = iproduct!(
144 ProxyActionDiscriminants::iter(),
145 [
146 (CS::Total, "arti_hss_proxy_connections_total"),
147 (CS::Ret(Ok(())), "arti_hss_proxy_connections_ok_total"),
148 (CS::Ret(Err(())), "arti_hss_proxy_connections_failed_total"),
149 ],
150 )
151 .map(|(action, (outcome, name))| {
152 let k = (action, outcome);
153 let nickname = nickname.to_string();
154 let action: &str = action.into();
155 let v = metrics::counter!(name, "nickname" => nickname, "action" => action);
156 (k, v)
157 })
158 .collect::<HashMap<(ProxyActionDiscriminants, CounterSelector), _>>();
159
160 Arc::new(counters)
161 };
162
163 loop {
164 let stream_request = select_biased! {
165 _ = shutdown_rx => return Ok(()),
166 stream_request = stream_requests.next() => match stream_request {
167 None => return Ok(()),
168 Some(s) => s,
169 }
170 };
171
172 runtime.spawn({
173 let action = self.choose_action(stream_request.request());
174 let runtime = runtime.clone();
175 let nickname = nickname.clone();
176 let req = stream_request.request().clone();
177
178 #[cfg(feature = "metrics")]
179 let metrics_counters = metrics_counters.clone();
180
181 async move {
182 let outcome =
183 run_action(runtime, nickname.as_ref(), action.clone(), stream_request).await;
184
185 #[cfg(feature = "metrics")]
186 {
187 use CounterSelector as CS;
188
189 let action = ProxyActionDiscriminants::from(&action);
190 let outcome = outcome.as_ref().map(|_|()).map_err(|_|());
191 for outcome in [CS::Total, CS::Ret(outcome)] {
192 if let Some(counter) = metrics_counters.get(&(action, outcome)) {
193 counter.increment(1);
194 } else {
195 }
197 }
198 }
199
200 log_ratelim!(
201 "Performing action on {}", nickname;
202 outcome;
203 Err(_) => WARN, "Unable to take action {:?} for request {:?}", sv(action), sv(req)
204 );
205 }
206 })
207 .map_err(|e| HandleRequestsError::Spawn(Arc::new(e)))?;
208 }
209 }
210
211 fn choose_action(&self, stream_request: &IncomingStreamRequest) -> ProxyAction {
214 let port: u16 = match stream_request {
215 IncomingStreamRequest::Begin(begin) => {
216 begin.port()
219 }
220 other => {
221 tracing::warn!(
222 "Rejecting onion service request for invalid command {:?}. Internal error.",
223 other
224 );
225 return ProxyAction::DestroyCircuit;
226 }
227 };
228
229 self.state
230 .lock()
231 .expect("poisoned lock")
232 .config
233 .resolve_port_for_begin(port)
234 .cloned()
235 .unwrap_or(ProxyAction::DestroyCircuit)
237 }
238}
239
240async fn run_action<R: Runtime>(
242 runtime: R,
243 nickname: &HsNickname,
244 action: ProxyAction,
245 request: StreamRequest,
246) -> Result<(), RequestFailed> {
247 match action {
248 ProxyAction::DestroyCircuit => {
249 request
250 .shutdown_circuit()
251 .map_err(RequestFailed::CantDestroy)?;
252 }
253 ProxyAction::Forward(encap, target) => match (encap, target) {
254 (Encapsulation::Simple, ref addr @ TargetAddr::Inet(a)) => {
255 let rt_clone = runtime.clone();
256 forward_connection(rt_clone, request, runtime.connect(&a), nickname, addr).await?;
257 } },
263 ProxyAction::RejectStream => {
264 let end = relaymsg::End::new_with_reason(relaymsg::EndReason::DONE);
266
267 request
268 .reject(end)
269 .await
270 .map_err(RequestFailed::CantReject)?;
271 }
272 ProxyAction::IgnoreStream => drop(request),
273 };
274 Ok(())
275}
276
/// An error from carrying out a [`ProxyAction`] on a single stream request.
#[derive(thiserror::Error, Debug, Clone)]
enum RequestFailed {
    /// Destroying the circuit, for [`ProxyAction::DestroyCircuit`], failed.
    #[error("Unable to destroy onion service circuit")]
    CantDestroy(#[source] tor_error::Bug),

    /// Rejecting the client's stream request failed.
    #[error("Unable to reject onion service request")]
    CantReject(#[source] tor_hsservice::ClientError),

    /// Accepting the client's stream failed, after the connection to the
    /// local target had succeeded.
    #[error("Unable to accept onion service connection")]
    AcceptRemote(#[source] tor_hsservice::ClientError),

    /// A task relaying data between the two streams could not be spawned.
    #[error("Unable to spawn task")]
    Spawn(#[source] Arc<futures::task::SpawnError>),
}
297
298impl HasKind for RequestFailed {
299 fn kind(&self) -> ErrorKind {
300 match self {
301 RequestFailed::CantDestroy(e) => e.kind(),
302 RequestFailed::CantReject(e) => e.kind(),
303 RequestFailed::AcceptRemote(e) => e.kind(),
304 RequestFailed::Spawn(e) => e.kind(),
305 }
306 }
307}
308
309async fn forward_connection<R, FUT, TS>(
317 runtime: R,
318 request: StreamRequest,
319 target_stream_future: FUT,
320 nickname: &HsNickname,
321 addr: &TargetAddr,
322) -> Result<(), RequestFailed>
323where
324 R: Runtime,
325 FUT: Future<Output = Result<TS, IoError>>,
326 TS: AsyncRead + AsyncWrite + Send + 'static,
327{
328 let local_stream = target_stream_future.await.map_err(Arc::new);
329
330 log_ratelim!(
333 "Connecting to {} for onion service {}", sv(addr), nickname;
334 local_stream
335 );
336
337 let local_stream = match local_stream {
338 Ok(s) => s,
339 Err(_) => {
340 let end = relaymsg::End::new_with_reason(relaymsg::EndReason::DONE);
341 if let Err(e_rejecting) = request.reject(end).await {
342 debug_report!(
343 &e_rejecting,
344 "Unable to reject onion service request from client"
345 );
346 return Err(RequestFailed::CantReject(e_rejecting));
347 }
348 return Ok(());
351 }
352 };
353
354 let onion_service_stream: DataStream = {
355 let connected = relaymsg::Connected::new_empty();
356 request
357 .accept(connected)
358 .await
359 .map_err(RequestFailed::AcceptRemote)?
360 };
361
362 let (svc_r, svc_w) = onion_service_stream.split();
363 let (local_r, local_w) = local_stream.split();
364
365 runtime
366 .spawn(copy_interactive(local_r, svc_w).map(|_| ()))
367 .map_err(|e| RequestFailed::Spawn(Arc::new(e)))?;
368 runtime
369 .spawn(copy_interactive(svc_r, local_w).map(|_| ()))
370 .map_err(|e| RequestFailed::Spawn(Arc::new(e)))?;
371
372 Ok(())
373}
374
375async fn copy_interactive<R, W>(mut reader: R, mut writer: W) -> IoResult<()>
391where
392 R: AsyncRead + Unpin,
393 W: AsyncWrite + Unpin,
394{
395 use futures::{poll, task::Poll};
396
397 let mut buf = [0_u8; 1024];
398
399 let loop_result: IoResult<()> = loop {
406 let mut read_future = reader.read(&mut buf[..]);
407 match poll!(&mut read_future) {
408 Poll::Ready(Err(e)) => break Err(e),
409 Poll::Ready(Ok(0)) => break Ok(()), Poll::Ready(Ok(n)) => {
411 writer.write_all(&buf[..n]).await?;
412 continue;
413 }
414 Poll::Pending => writer.flush().await?,
415 }
416
417 match read_future.await {
419 Err(e) => break Err(e),
420 Ok(0) => break Ok(()),
421 Ok(n) => writer.write_all(&buf[..n]).await?,
422 }
423 };
424
425 let flush_result = if loop_result.is_ok() {
430 writer.close().await
431 } else {
432 writer.flush().await
433 };
434
435 loop_result.or(flush_result)
436}