1use crate::congestion::sendme;
4use crate::stream::{AnyCmdChecker, StreamSendFlowControl};
5use crate::tunnel::circuit::{StreamMpscReceiver, StreamMpscSender};
6use crate::tunnel::halfstream::HalfStream;
7use crate::tunnel::reactor::circuit::RECV_WINDOW_INIT;
8use crate::util::stream_poll_set::{KeyAlreadyInsertedError, StreamPollSet};
9use crate::{Error, Result};
10use pin_project::pin_project;
11use tor_async_utils::peekable_stream::{PeekableStream, UnobtrusivePeekableStream};
12use tor_async_utils::stream_peek::StreamUnobtrusivePeeker;
13use tor_cell::relaycell::{msg::AnyRelayMsg, StreamId};
14use tor_cell::relaycell::{RelayMsg, UnparsedRelayMsg};
15
16use std::collections::hash_map;
17use std::collections::HashMap;
18use std::num::NonZeroU16;
19use std::pin::Pin;
20use std::task::{Poll, Waker};
21use tor_error::{bad_api_usage, internal};
22
23use rand::Rng;
24
25use tracing::debug;
26
/// Map entry for a stream that is still open.
///
/// Bundles the channel ends kept for the stream together with its send-side
/// flow-control state.
#[derive(Debug)]
#[pin_project]
pub(super) struct OpenStreamEnt {
    /// Sending half of a channel carrying unparsed relay messages for this
    /// stream.
    // NOTE(review): presumably used to deliver incoming cells to the stream's
    // reader; nothing in the visible code sends on it -- confirm against the
    // reactor.
    pub(super) sink: StreamMpscSender<UnparsedRelayMsg>,
    /// Amount subtracted from a freshly created receive window when
    /// `StreamMap::terminate` converts this entry into a half-closed one.
    ///
    /// Initialized to 0 when the entry is added to the map.
    // NOTE(review): presumably a count of cells dropped on this stream; it is
    // not modified in the code visible here -- confirm.
    pub(super) dropped: u16,
    /// Command checker for this stream; it is handed to the `HalfStream`
    /// created when the stream is terminated.
    pub(super) cmd_checker: AnyCmdChecker,
    /// Send-side flow-control state, consulted before a queued message is
    /// reported as ready.
    flow_ctrl: StreamSendFlowControl,
    /// Receiving half of the channel of outgoing messages, wrapped so that its
    /// next item can be peeked.
    #[pin]
    rx: StreamUnobtrusivePeeker<StreamMpscReceiver<AnyRelayMsg>>,
    /// Waker recorded by `poll_peek_mut` when the next message cannot be sent
    /// under flow control; taken and woken by `put_for_incoming_sendme`.
    flow_ctrl_waker: Option<Waker>,
}
55
56impl OpenStreamEnt {
57 pub(crate) fn can_send<M: RelayMsg>(&self, msg: &M) -> bool {
59 self.flow_ctrl.can_send(msg)
60 }
61
62 pub(crate) fn put_for_incoming_sendme(&mut self) -> Result<()> {
69 self.flow_ctrl.put_for_incoming_sendme()?;
70 if let Some(waker) = self.flow_ctrl_waker.take() {
72 waker.wake();
73 }
74 Ok(())
75 }
76
77 pub(crate) fn take_capacity_to_send<M: RelayMsg>(&mut self, msg: &M) -> Result<()> {
84 self.flow_ctrl.take_capacity_to_send(msg)
85 }
86}
87
/// Private wrapper around an [`OpenStreamEnt`] that implements
/// `futures::Stream` and the peekable-stream traits.
///
/// Its stream implementation only yields a message when the entry's
/// flow-control state says that message can be sent.
#[derive(Debug)]
#[pin_project]
struct OpenStreamEntStream {
    /// The wrapped open-stream entry.
    #[pin]
    inner: OpenStreamEnt,
}
98
99impl futures::Stream for OpenStreamEntStream {
100 type Item = AnyRelayMsg;
101
102 fn poll_next(
103 mut self: std::pin::Pin<&mut Self>,
104 cx: &mut std::task::Context<'_>,
105 ) -> Poll<Option<Self::Item>> {
106 if !self.as_mut().poll_peek_mut(cx).is_ready() {
107 return Poll::Pending;
108 };
109 let res = self.project().inner.project().rx.poll_next(cx);
110 debug_assert!(res.is_ready());
111 res
118 }
119}
120
121impl PeekableStream for OpenStreamEntStream {
122 fn poll_peek_mut(
123 self: Pin<&mut Self>,
124 cx: &mut std::task::Context<'_>,
125 ) -> Poll<Option<&mut <Self as futures::Stream>::Item>> {
126 let s = self.project();
127 let inner = s.inner.project();
128 let m = match inner.rx.poll_peek_mut(cx) {
129 Poll::Ready(Some(m)) => m,
130 Poll::Ready(None) => return Poll::Ready(None),
131 Poll::Pending => return Poll::Pending,
132 };
133 if !inner.flow_ctrl.can_send(m) {
134 inner.flow_ctrl_waker.replace(cx.waker().clone());
135 return Poll::Pending;
136 }
137 Poll::Ready(Some(m))
138 }
139}
140
141impl UnobtrusivePeekableStream for OpenStreamEntStream {
142 fn unobtrusive_peek_mut(
143 self: std::pin::Pin<&mut Self>,
144 ) -> Option<&mut <Self as futures::Stream>::Item> {
145 let s = self.project();
146 let inner = s.inner.project();
147 let m = inner.rx.unobtrusive_peek_mut()?;
148 if inner.flow_ctrl.can_send(m) {
149 Some(m)
150 } else {
151 None
152 }
153 }
154}
155
/// Map entry for a stream that was terminated locally while it was open.
///
/// `StreamMap::terminate` creates this entry and returns
/// `ShouldSendEnd::Send` for it.
#[derive(Debug)]
pub(super) struct EndSentStreamEnt {
    /// Half-closed stream state, built from the open entry's flow-control
    /// state, a new receive window, and its command checker.
    pub(super) half_stream: HalfStream,
    /// True if the stream was terminated with
    /// `TerminateReason::StreamTargetClosed`.
    explicitly_dropped: bool,
}
167
/// Map entry for a stream that is no longer open.
#[derive(Debug)]
enum ClosedStreamEnt {
    /// An ending message was received while the stream was open
    /// (recorded by `StreamMap::ending_msg_received`).
    EndReceived,
    /// The stream was terminated locally while it was open
    /// (recorded by `StreamMap::terminate`).
    EndSent(EndSentStreamEnt),
}
181
/// Mutable view of a stream's entry, as returned by `StreamMap::get_mut`.
pub(super) enum StreamEntMut<'a> {
    /// The stream is open.
    Open(&'a mut OpenStreamEnt),
    /// An ending message has been received for the stream.
    EndReceived,
    /// The stream was terminated locally and is half-closed.
    EndSent(&'a mut EndSentStreamEnt),
}
193
194impl<'a> From<&'a mut ClosedStreamEnt> for StreamEntMut<'a> {
195 fn from(value: &'a mut ClosedStreamEnt) -> Self {
196 match value {
197 ClosedStreamEnt::EndReceived => Self::EndReceived,
198 ClosedStreamEnt::EndSent(e) => Self::EndSent(e),
199 }
200 }
201}
202
203impl<'a> From<&'a mut OpenStreamEntStream> for StreamEntMut<'a> {
204 fn from(value: &'a mut OpenStreamEntStream) -> Self {
205 Self::Open(&mut value.inner)
206 }
207}
208
/// Result of `StreamMap::terminate`: whether an END should now be sent for
/// the terminated stream.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(super) enum ShouldSendEnd {
    /// Returned when the terminated stream was still open.
    Send,
    /// Returned when the stream already had a closed-stream entry.
    DontSend,
}
218
/// Priority value attached to each open stream in the `StreamPollSet`.
///
/// Values are handed out from a counter that starts at 0 and increases by one
/// per allocation (see `StreamMap::take_next_priority`).
// NOTE(review): how the poll set orders streams by this value is not visible
// in this file -- confirm against `StreamPollSet`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]
struct Priority(u64);
222
/// A map from stream IDs to stream entries.
///
/// Open streams and closed streams are kept in separate collections.
pub(super) struct StreamMap {
    /// Entries for open streams, keyed by stream ID and tagged with a
    /// priority.
    open_streams: StreamPollSet<StreamId, Priority, OpenStreamEntStream>,
    /// Entries for streams that have received or sent an END.
    closed_streams: HashMap<StreamId, ClosedStreamEnt>,
    /// The stream ID that `add_ent` will try first on its next call.
    next_stream_id: StreamId,
    /// The priority that the next call to `take_next_priority` will return.
    next_priority: Priority,
}
243
244impl StreamMap {
245 pub(super) fn new() -> Self {
247 let mut rng = rand::rng();
248 let next_stream_id: NonZeroU16 = rng.random();
249 StreamMap {
250 open_streams: StreamPollSet::new(),
251 closed_streams: HashMap::new(),
252 next_stream_id: next_stream_id.into(),
253 next_priority: Priority(0),
254 }
255 }
256
257 pub(super) fn n_open_streams(&self) -> usize {
259 self.open_streams.len()
260 }
261
262 fn take_next_priority(&mut self) -> Priority {
264 let rv = self.next_priority;
265 self.next_priority = Priority(rv.0 + 1);
266 rv
267 }
268
269 pub(super) fn add_ent(
271 &mut self,
272 sink: StreamMpscSender<UnparsedRelayMsg>,
273 rx: StreamMpscReceiver<AnyRelayMsg>,
274 flow_ctrl: StreamSendFlowControl,
275 cmd_checker: AnyCmdChecker,
276 ) -> Result<StreamId> {
277 let mut stream_ent = OpenStreamEntStream {
278 inner: OpenStreamEnt {
279 sink,
280 flow_ctrl,
281 dropped: 0,
282 cmd_checker,
283 rx: StreamUnobtrusivePeeker::new(rx),
284 flow_ctrl_waker: None,
285 },
286 };
287 let priority = self.take_next_priority();
288 for _ in 1..=65536 {
293 let id: StreamId = self.next_stream_id;
294 self.next_stream_id = wrapping_next_stream_id(self.next_stream_id);
295 stream_ent = match self.open_streams.try_insert(id, priority, stream_ent) {
296 Ok(_) => return Ok(id),
297 Err(KeyAlreadyInsertedError {
298 key: _,
299 priority: _,
300 stream,
301 }) => stream,
302 };
303 }
304
305 Err(Error::IdRangeFull)
306 }
307
308 #[cfg(feature = "hs-service")]
310 pub(super) fn add_ent_with_id(
311 &mut self,
312 sink: StreamMpscSender<UnparsedRelayMsg>,
313 rx: StreamMpscReceiver<AnyRelayMsg>,
314 flow_ctrl: StreamSendFlowControl,
315 id: StreamId,
316 cmd_checker: AnyCmdChecker,
317 ) -> Result<()> {
318 let stream_ent = OpenStreamEntStream {
319 inner: OpenStreamEnt {
320 sink,
321 flow_ctrl,
322 dropped: 0,
323 cmd_checker,
324 rx: StreamUnobtrusivePeeker::new(rx),
325 flow_ctrl_waker: None,
326 },
327 };
328 let priority = self.take_next_priority();
329 self.open_streams
330 .try_insert(id, priority, stream_ent)
331 .map_err(|_| Error::IdUnavailable(id))
332 }
333
334 pub(super) fn get_mut(&mut self, id: StreamId) -> Option<StreamEntMut<'_>> {
336 if let Some(e) = self.open_streams.stream_mut(&id) {
337 return Some(e.into());
338 }
339 if let Some(e) = self.closed_streams.get_mut(&id) {
340 return Some(e.into());
341 }
342 None
343 }
344
345 pub(super) fn ending_msg_received(&mut self, id: StreamId) -> Result<()> {
350 if self.open_streams.remove(&id).is_some() {
351 let prev = self.closed_streams.insert(id, ClosedStreamEnt::EndReceived);
352 debug_assert!(prev.is_none(), "Unexpected duplicate entry for {id}");
353 return Ok(());
354 }
355 let hash_map::Entry::Occupied(closed_entry) = self.closed_streams.entry(id) else {
356 return Err(Error::CircProto(
357 "Received END cell on nonexistent stream".into(),
358 ));
359 };
360 match closed_entry.get() {
362 ClosedStreamEnt::EndReceived => Err(Error::CircProto(
363 "Received two END cells on same stream".into(),
364 )),
365 ClosedStreamEnt::EndSent { .. } => {
366 debug!("Actually got an end cell on a half-closed stream!");
367 closed_entry.remove_entry();
370 Ok(())
371 }
372 }
373 }
374
375 pub(super) fn terminate(
379 &mut self,
380 id: StreamId,
381 why: TerminateReason,
382 ) -> Result<ShouldSendEnd> {
383 use TerminateReason as TR;
384
385 if let Some((_id, _priority, ent)) = self.open_streams.remove(&id) {
386 let OpenStreamEntStream {
387 inner:
388 OpenStreamEnt {
389 flow_ctrl,
390 dropped,
391 cmd_checker,
392 ..
395 },
396 } = ent;
397 let mut recv_window = sendme::StreamRecvWindow::new(RECV_WINDOW_INIT);
401 recv_window.decrement_n(dropped)?;
402 let half_stream = HalfStream::new(flow_ctrl, recv_window, cmd_checker);
404 let explicitly_dropped = why == TR::StreamTargetClosed;
405 let prev = self.closed_streams.insert(
406 id,
407 ClosedStreamEnt::EndSent(EndSentStreamEnt {
408 half_stream,
409 explicitly_dropped,
410 }),
411 );
412 debug_assert!(prev.is_none(), "Unexpected duplicate entry for {id}");
413 return Ok(ShouldSendEnd::Send);
414 }
415
416 match self
418 .closed_streams
419 .remove(&id)
420 .ok_or_else(|| Error::from(internal!("Somehow we terminated a nonexistent stream?")))?
421 {
422 ClosedStreamEnt::EndReceived => Ok(ShouldSendEnd::DontSend),
423 ClosedStreamEnt::EndSent(EndSentStreamEnt {
424 ref mut explicitly_dropped,
425 ..
426 }) => match (*explicitly_dropped, why) {
427 (false, TR::StreamTargetClosed) => {
428 *explicitly_dropped = true;
429 Ok(ShouldSendEnd::DontSend)
430 }
431 (true, TR::StreamTargetClosed) => {
432 Err(bad_api_usage!("Tried to close an already closed stream.").into())
433 }
434 (_, TR::ExplicitEnd) => Err(bad_api_usage!(
435 "Tried to end an already closed stream. (explicitly_dropped={:?})",
436 *explicitly_dropped
437 )
438 .into()),
439 },
440 }
441 }
442
443 pub(super) fn poll_ready_streams_iter<'a>(
453 &'a mut self,
454 cx: &mut std::task::Context,
455 ) -> impl Iterator<Item = (StreamId, Option<&'a AnyRelayMsg>)> + 'a {
456 self.open_streams
457 .poll_ready_iter_mut(cx)
458 .map(|(sid, _priority, ent)| {
459 let ent = Pin::new(ent);
460 let msg = ent.unobtrusive_peek();
461 (*sid, msg)
462 })
463 }
464
465 pub(super) fn take_ready_msg(&mut self, sid: StreamId) -> Option<AnyRelayMsg> {
469 let new_priority = self.take_next_priority();
470 let (_prev_priority, val) = self
471 .open_streams
472 .take_ready_value_and_reprioritize(&sid, new_priority)?;
473 Some(val)
474 }
475
476 }
479
/// Why `StreamMap::terminate` is being called for a stream.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(super) enum TerminateReason {
    /// Terminating with this reason marks the resulting half-closed record
    /// as explicitly dropped.
    // NOTE(review): presumably used when the stream's receiving side has been
    // dropped or closed -- confirm against the callers.
    StreamTargetClosed,
    /// Terminating with this reason leaves the half-closed record unmarked;
    /// it is rejected for a stream that already has a half-closed record.
    // NOTE(review): presumably used when an END is sent on purpose while the
    // stream object still exists -- confirm against the callers.
    ExplicitEnd,
}
492
493fn wrapping_next_stream_id(id: StreamId) -> StreamId {
495 let next_val = NonZeroU16::from(id)
496 .checked_add(1)
497 .unwrap_or_else(|| NonZeroU16::new(1).expect("Impossibly got 0 value"));
498 next_val.into()
499}
500
// Unit tests for the stream-ID helper and for `StreamMap`'s state transitions.
#[cfg(test)]
mod test {
    // Lints relaxed for test code.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use crate::tunnel::circuit::test::fake_mpsc;
    use crate::{congestion::sendme::StreamSendWindow, stream::DataCmdChecker};

    /// The successor of an ID is ID + 1, and the maximum ID wraps to 1.
    #[test]
    fn test_wrapping_next_stream_id() {
        let one = StreamId::new(1).unwrap();
        let two = StreamId::new(2).unwrap();
        let max = StreamId::new(0xffff).unwrap();
        assert_eq!(wrapping_next_stream_id(one), two);
        assert_eq!(wrapping_next_stream_id(max), one);
    }

    /// Walk a `StreamMap` through adding, looking up, ending and terminating
    /// streams, checking the open-stream count and entry state at each step.
    #[test]
    #[allow(clippy::cognitive_complexity)]
    fn streammap_basics() -> Result<()> {
        let mut map = StreamMap::new();
        let mut next_id = map.next_stream_id;
        let mut ids = Vec::new();

        assert_eq!(map.n_open_streams(), 0);

        // add_ent: IDs are handed out sequentially from the map's start ID.
        for n in 1..=128 {
            // Only one half of each fake channel is kept; the other is dropped.
            let (sink, _) = fake_mpsc(128);
            let (_, rx) = fake_mpsc(2);
            let id = map.add_ent(
                sink,
                rx,
                StreamSendFlowControl::new_window_based(StreamSendWindow::new(500)),
                DataCmdChecker::new_any(),
            )?;
            let expect_id: StreamId = next_id;
            assert_eq!(expect_id, id);
            next_id = wrapping_next_stream_id(next_id);
            ids.push(id);
            assert_eq!(map.n_open_streams(), n);
        }

        // get_mut: an allocated ID is open; the next unallocated ID is absent.
        let nonesuch_id = next_id;
        assert!(matches!(
            map.get_mut(ids[0]),
            Some(StreamEntMut::Open { .. })
        ));
        assert!(map.get_mut(nonesuch_id).is_none());

        // ending_msg_received: unknown ID is an error; a second END is an error.
        assert!(map.ending_msg_received(nonesuch_id).is_err());
        assert_eq!(map.n_open_streams(), 128);
        assert!(map.ending_msg_received(ids[1]).is_ok());
        assert_eq!(map.n_open_streams(), 127);
        assert!(matches!(
            map.get_mut(ids[1]),
            Some(StreamEntMut::EndReceived)
        ));
        assert!(map.ending_msg_received(ids[1]).is_err());

        // terminate: unknown ID is an error; an open stream becomes EndSent.
        use TerminateReason as TR;
        assert!(map.terminate(nonesuch_id, TR::ExplicitEnd).is_err());
        assert_eq!(map.n_open_streams(), 127);
        assert_eq!(
            map.terminate(ids[2], TR::ExplicitEnd).unwrap(),
            ShouldSendEnd::Send
        );
        assert_eq!(map.n_open_streams(), 126);
        assert!(matches!(
            map.get_mut(ids[2]),
            Some(StreamEntMut::EndSent { .. })
        ));
        // ids[1] already received an END: no END to send, and it is forgotten.
        assert_eq!(
            map.terminate(ids[1], TR::ExplicitEnd).unwrap(),
            ShouldSendEnd::DontSend
        );
        assert_eq!(map.n_open_streams(), 126);
        assert!(map.get_mut(ids[1]).is_none());

        // An END arriving on the half-closed ids[2] removes its entry.
        assert!(map.ending_msg_received(ids[2]).is_ok());
        assert!(map.get_mut(ids[2]).is_none());
        assert_eq!(map.n_open_streams(), 126);

        Ok(())
    }
}