1use crate::congestion::sendme;
4use crate::stream::queue::StreamQueueSender;
5use crate::stream::{AnyCmdChecker, StreamFlowControl};
6use crate::tunnel::circuit::StreamMpscReceiver;
7use crate::tunnel::halfstream::HalfStream;
8use crate::tunnel::reactor::circuit::RECV_WINDOW_INIT;
9use crate::util::stream_poll_set::{KeyAlreadyInsertedError, StreamPollSet};
10use crate::{Error, Result};
11use pin_project::pin_project;
12use tor_async_utils::peekable_stream::{PeekableStream, UnobtrusivePeekableStream};
13use tor_async_utils::stream_peek::StreamUnobtrusivePeeker;
14use tor_cell::relaycell::flow_ctrl::{Xoff, Xon, XonKbpsEwma};
15use tor_cell::relaycell::{msg::AnyRelayMsg, StreamId};
16use tor_cell::relaycell::{RelayMsg, UnparsedRelayMsg};
17
18use std::collections::hash_map;
19use std::collections::HashMap;
20use std::num::NonZeroU16;
21use std::pin::Pin;
22use std::task::{Poll, Waker};
23use tor_error::{bad_api_usage, internal};
24
25use rand::Rng;
26
27use tracing::debug;
28
/// The entry for a fully open stream in a [`StreamMap`].
///
/// Bundles the per-stream state the map needs while the stream is open:
/// the queue used by `approx_stream_bytes_buffered`, the queue of messages
/// waiting to be sent, and the flow-control state gating those messages.
#[derive(Debug)]
#[pin_project]
pub(super) struct OpenStreamEnt {
    /// Sender half of this stream's queue; its approximate byte count is what
    /// XON/XOFF decisions are based on. NOTE(review): presumably this is where
    /// incoming messages for the stream are delivered — confirm with callers.
    pub(super) sink: StreamQueueSender,
    /// Counter starting at 0; when the stream is terminated it is subtracted
    /// from the new half-stream's receive window (`decrement_n`).
    /// NOTE(review): presumably counts cells we dropped for this stream.
    pub(super) dropped: u16,
    /// Checker for the relay commands acceptable on this stream; handed over
    /// to the `HalfStream` when the stream is terminated.
    pub(super) cmd_checker: AnyCmdChecker,
    /// Flow-control state deciding whether a queued message may be sent.
    flow_ctrl: StreamFlowControl,
    /// Messages waiting to be sent on this stream, wrapped so the next one
    /// can be inspected (and checked against flow control) before it is taken.
    #[pin]
    rx: StreamUnobtrusivePeeker<StreamMpscReceiver<AnyRelayMsg>>,
    /// Waker saved when a peeked message was blocked by flow control; it is
    /// woken after an incoming SENDME is successfully processed.
    flow_ctrl_waker: Option<Waker>,
}
57
58impl OpenStreamEnt {
59 pub(crate) fn can_send<M: RelayMsg>(&self, msg: &M) -> bool {
61 self.flow_ctrl.can_send(msg)
62 }
63
64 pub(crate) fn put_for_incoming_sendme(&mut self, msg: UnparsedRelayMsg) -> Result<()> {
69 self.flow_ctrl.put_for_incoming_sendme(msg)?;
70 if let Some(waker) = self.flow_ctrl_waker.take() {
72 waker.wake();
73 }
74 Ok(())
75 }
76
77 fn approx_stream_bytes_buffered(&self) -> usize {
79 self.sink.approx_stream_bytes()
91 }
92
93 pub(crate) fn maybe_send_xon(&mut self, rate: XonKbpsEwma) -> Result<Option<Xon>> {
98 self.flow_ctrl
99 .maybe_send_xon(rate, self.approx_stream_bytes_buffered())
100 }
101
102 pub(super) fn maybe_send_xoff(&mut self) -> Result<Option<Xoff>> {
107 self.flow_ctrl
108 .maybe_send_xoff(self.approx_stream_bytes_buffered())
109 }
110
111 pub(crate) fn handle_incoming_xon(&mut self, msg: UnparsedRelayMsg) -> Result<()> {
116 self.flow_ctrl.handle_incoming_xon(msg)
117 }
118
119 pub(crate) fn handle_incoming_xoff(&mut self, msg: UnparsedRelayMsg) -> Result<()> {
124 self.flow_ctrl.handle_incoming_xoff(msg)
125 }
126
127 pub(crate) fn take_capacity_to_send<M: RelayMsg>(&mut self, msg: &M) -> Result<()> {
134 self.flow_ctrl.take_capacity_to_send(msg)
135 }
136}
137
/// Wrapper around [`OpenStreamEnt`] that implements `futures::Stream` and
/// the peeking traits, so that it can be stored in a [`StreamPollSet`].
///
/// A queued message is only reported as available once the entry's flow
/// control allows it to be sent.
#[derive(Debug)]
#[pin_project]
struct OpenStreamEntStream {
    /// The wrapped open-stream entry.
    #[pin]
    inner: OpenStreamEnt,
}
148
149impl futures::Stream for OpenStreamEntStream {
150 type Item = AnyRelayMsg;
151
152 fn poll_next(
153 mut self: std::pin::Pin<&mut Self>,
154 cx: &mut std::task::Context<'_>,
155 ) -> Poll<Option<Self::Item>> {
156 if !self.as_mut().poll_peek_mut(cx).is_ready() {
157 return Poll::Pending;
158 };
159 let res = self.project().inner.project().rx.poll_next(cx);
160 debug_assert!(res.is_ready());
161 res
168 }
169}
170
171impl PeekableStream for OpenStreamEntStream {
172 fn poll_peek_mut(
173 self: Pin<&mut Self>,
174 cx: &mut std::task::Context<'_>,
175 ) -> Poll<Option<&mut <Self as futures::Stream>::Item>> {
176 let s = self.project();
177 let inner = s.inner.project();
178 let m = match inner.rx.poll_peek_mut(cx) {
179 Poll::Ready(Some(m)) => m,
180 Poll::Ready(None) => return Poll::Ready(None),
181 Poll::Pending => return Poll::Pending,
182 };
183 if !inner.flow_ctrl.can_send(m) {
184 inner.flow_ctrl_waker.replace(cx.waker().clone());
185 return Poll::Pending;
186 }
187 Poll::Ready(Some(m))
188 }
189}
190
191impl UnobtrusivePeekableStream for OpenStreamEntStream {
192 fn unobtrusive_peek_mut(
193 self: std::pin::Pin<&mut Self>,
194 ) -> Option<&mut <Self as futures::Stream>::Item> {
195 let s = self.project();
196 let inner = s.inner.project();
197 let m = inner.rx.unobtrusive_peek_mut()?;
198 if inner.flow_ctrl.can_send(m) {
199 Some(m)
200 } else {
201 None
202 }
203 }
204}
205
/// Entry for a stream that we have terminated on our side (see
/// `StreamMap::terminate`), but for which no END has yet been received.
#[derive(Debug)]
pub(super) struct EndSentStreamEnt {
    /// Half-stream state built from the former open entry's flow control,
    /// receive window and command checker. NOTE(review): presumably used to
    /// validate cells that still arrive on this stream — confirm with callers.
    pub(super) half_stream: HalfStream,
    /// True once the stream has been terminated with
    /// `TerminateReason::StreamTargetClosed`.
    explicitly_dropped: bool,
}
217
/// State of a stream that has been ended by exactly one side.
///
/// Once both sides have ended the stream, its entry is removed from the map.
#[derive(Debug)]
enum ClosedStreamEnt {
    /// We received an END for this stream, but have not terminated it ourselves yet.
    EndReceived,
    /// We terminated this stream, but have not received an END for it yet.
    EndSent(EndSentStreamEnt),
}
231
/// A mutable view of one entry of a [`StreamMap`], as returned by
/// `StreamMap::get_mut`.
pub(super) enum StreamEntMut<'a> {
    /// The stream is open.
    Open(&'a mut OpenStreamEnt),
    /// We have received an END for this stream, but have not terminated it ourselves.
    EndReceived,
    /// We have terminated this stream, but have not received an END for it.
    EndSent(&'a mut EndSentStreamEnt),
}
243
244impl<'a> From<&'a mut ClosedStreamEnt> for StreamEntMut<'a> {
245 fn from(value: &'a mut ClosedStreamEnt) -> Self {
246 match value {
247 ClosedStreamEnt::EndReceived => Self::EndReceived,
248 ClosedStreamEnt::EndSent(e) => Self::EndSent(e),
249 }
250 }
251}
252
253impl<'a> From<&'a mut OpenStreamEntStream> for StreamEntMut<'a> {
254 fn from(value: &'a mut OpenStreamEntStream) -> Self {
255 Self::Open(&mut value.inner)
256 }
257}
258
/// Outcome of `StreamMap::terminate`: whether the caller must send an END
/// message for the terminated stream.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(super) enum ShouldSendEnd {
    /// The stream was open until now; an END should be sent.
    Send,
    /// No END should be sent (the stream was already ended by one side).
    DontSend,
}
268
/// Priority key for an open stream in the [`StreamMap`]'s poll set.
///
/// Values are handed out in increasing order by `StreamMap::take_next_priority`,
/// and a stream receives a fresh value each time a message is taken from it.
/// NOTE(review): presumably the poll set prefers lower values, giving
/// round-robin service across streams — confirm against `StreamPollSet`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]
struct Priority(u64);
272
/// A map from stream IDs to the state of the corresponding streams.
///
/// Open streams live in a poll set so those with a sendable message can be
/// found efficiently. Streams that have been ended by exactly one side are
/// tracked separately until the other side ends them too.
pub(super) struct StreamMap {
    /// Open streams, keyed by ID and ordered by [`Priority`].
    open_streams: StreamPollSet<StreamId, Priority, OpenStreamEntStream>,
    /// Streams for which an END has been sent or received, but not both.
    closed_streams: HashMap<StreamId, ClosedStreamEnt>,
    /// The next stream ID that `add_ent` will try to allocate.
    next_stream_id: StreamId,
    /// The next priority value to hand out.
    next_priority: Priority,
}
293
294impl StreamMap {
295 pub(super) fn new() -> Self {
297 let mut rng = rand::rng();
298 let next_stream_id: NonZeroU16 = rng.random();
299 StreamMap {
300 open_streams: StreamPollSet::new(),
301 closed_streams: HashMap::new(),
302 next_stream_id: next_stream_id.into(),
303 next_priority: Priority(0),
304 }
305 }
306
307 pub(super) fn n_open_streams(&self) -> usize {
309 self.open_streams.len()
310 }
311
312 fn take_next_priority(&mut self) -> Priority {
314 let rv = self.next_priority;
315 self.next_priority = Priority(rv.0 + 1);
316 rv
317 }
318
319 pub(super) fn add_ent(
321 &mut self,
322 sink: StreamQueueSender,
323 rx: StreamMpscReceiver<AnyRelayMsg>,
324 flow_ctrl: StreamFlowControl,
325 cmd_checker: AnyCmdChecker,
326 ) -> Result<StreamId> {
327 let mut stream_ent = OpenStreamEntStream {
328 inner: OpenStreamEnt {
329 sink,
330 flow_ctrl,
331 dropped: 0,
332 cmd_checker,
333 rx: StreamUnobtrusivePeeker::new(rx),
334 flow_ctrl_waker: None,
335 },
336 };
337 let priority = self.take_next_priority();
338 for _ in 1..=65536 {
343 let id: StreamId = self.next_stream_id;
344 self.next_stream_id = wrapping_next_stream_id(self.next_stream_id);
345 stream_ent = match self.open_streams.try_insert(id, priority, stream_ent) {
346 Ok(_) => return Ok(id),
347 Err(KeyAlreadyInsertedError {
348 key: _,
349 priority: _,
350 stream,
351 }) => stream,
352 };
353 }
354
355 Err(Error::IdRangeFull)
356 }
357
358 #[cfg(feature = "hs-service")]
360 pub(super) fn add_ent_with_id(
361 &mut self,
362 sink: StreamQueueSender,
363 rx: StreamMpscReceiver<AnyRelayMsg>,
364 flow_ctrl: StreamFlowControl,
365 id: StreamId,
366 cmd_checker: AnyCmdChecker,
367 ) -> Result<()> {
368 let stream_ent = OpenStreamEntStream {
369 inner: OpenStreamEnt {
370 sink,
371 flow_ctrl,
372 dropped: 0,
373 cmd_checker,
374 rx: StreamUnobtrusivePeeker::new(rx),
375 flow_ctrl_waker: None,
376 },
377 };
378 let priority = self.take_next_priority();
379 self.open_streams
380 .try_insert(id, priority, stream_ent)
381 .map_err(|_| Error::IdUnavailable(id))
382 }
383
384 pub(super) fn get_mut(&mut self, id: StreamId) -> Option<StreamEntMut<'_>> {
386 if let Some(e) = self.open_streams.stream_mut(&id) {
387 return Some(e.into());
388 }
389 if let Some(e) = self.closed_streams.get_mut(&id) {
390 return Some(e.into());
391 }
392 None
393 }
394
395 pub(super) fn ending_msg_received(&mut self, id: StreamId) -> Result<()> {
400 if self.open_streams.remove(&id).is_some() {
401 let prev = self.closed_streams.insert(id, ClosedStreamEnt::EndReceived);
402 debug_assert!(prev.is_none(), "Unexpected duplicate entry for {id}");
403 return Ok(());
404 }
405 let hash_map::Entry::Occupied(closed_entry) = self.closed_streams.entry(id) else {
406 return Err(Error::CircProto(
407 "Received END cell on nonexistent stream".into(),
408 ));
409 };
410 match closed_entry.get() {
412 ClosedStreamEnt::EndReceived => Err(Error::CircProto(
413 "Received two END cells on same stream".into(),
414 )),
415 ClosedStreamEnt::EndSent { .. } => {
416 debug!("Actually got an end cell on a half-closed stream!");
417 closed_entry.remove_entry();
420 Ok(())
421 }
422 }
423 }
424
425 pub(super) fn terminate(
429 &mut self,
430 id: StreamId,
431 why: TerminateReason,
432 ) -> Result<ShouldSendEnd> {
433 use TerminateReason as TR;
434
435 if let Some((_id, _priority, ent)) = self.open_streams.remove(&id) {
436 let OpenStreamEntStream {
437 inner:
438 OpenStreamEnt {
439 flow_ctrl,
440 dropped,
441 cmd_checker,
442 ..
445 },
446 } = ent;
447 let mut recv_window = sendme::StreamRecvWindow::new(RECV_WINDOW_INIT);
451 recv_window.decrement_n(dropped)?;
452 let half_stream = HalfStream::new(flow_ctrl, recv_window, cmd_checker);
454 let explicitly_dropped = why == TR::StreamTargetClosed;
455 let prev = self.closed_streams.insert(
456 id,
457 ClosedStreamEnt::EndSent(EndSentStreamEnt {
458 half_stream,
459 explicitly_dropped,
460 }),
461 );
462 debug_assert!(prev.is_none(), "Unexpected duplicate entry for {id}");
463 return Ok(ShouldSendEnd::Send);
464 }
465
466 match self
468 .closed_streams
469 .remove(&id)
470 .ok_or_else(|| Error::from(internal!("Somehow we terminated a nonexistent stream?")))?
471 {
472 ClosedStreamEnt::EndReceived => Ok(ShouldSendEnd::DontSend),
473 ClosedStreamEnt::EndSent(EndSentStreamEnt {
474 ref mut explicitly_dropped,
475 ..
476 }) => match (*explicitly_dropped, why) {
477 (false, TR::StreamTargetClosed) => {
478 *explicitly_dropped = true;
479 Ok(ShouldSendEnd::DontSend)
480 }
481 (true, TR::StreamTargetClosed) => {
482 Err(bad_api_usage!("Tried to close an already closed stream.").into())
483 }
484 (_, TR::ExplicitEnd) => Err(bad_api_usage!(
485 "Tried to end an already closed stream. (explicitly_dropped={:?})",
486 *explicitly_dropped
487 )
488 .into()),
489 },
490 }
491 }
492
493 pub(super) fn poll_ready_streams_iter<'a>(
503 &'a mut self,
504 cx: &mut std::task::Context,
505 ) -> impl Iterator<Item = (StreamId, Option<&'a AnyRelayMsg>)> + 'a {
506 self.open_streams
507 .poll_ready_iter_mut(cx)
508 .map(|(sid, _priority, ent)| {
509 let ent = Pin::new(ent);
510 let msg = ent.unobtrusive_peek();
511 (*sid, msg)
512 })
513 }
514
515 pub(super) fn take_ready_msg(&mut self, sid: StreamId) -> Option<AnyRelayMsg> {
519 let new_priority = self.take_next_priority();
520 let (_prev_priority, val) = self
521 .open_streams
522 .take_ready_value_and_reprioritize(&sid, new_priority)?;
523 Some(val)
524 }
525
526 }
529
/// Why `StreamMap::terminate` is being called for a stream.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(super) enum TerminateReason {
    /// The stream's target was closed. A stream terminated for this reason is
    /// marked `explicitly_dropped` in its `EndSentStreamEnt`.
    /// NOTE(review): presumably "target" means the local stream object.
    StreamTargetClosed,
    /// The stream is being ended explicitly.
    ExplicitEnd,
}
542
543fn wrapping_next_stream_id(id: StreamId) -> StreamId {
545 let next_val = NonZeroU16::from(id)
546 .checked_add(1)
547 .unwrap_or_else(|| NonZeroU16::new(1).expect("Impossibly got 0 value"));
548 next_val.into()
549}
550
#[cfg(test)]
mod test {
    // Lints relaxed for test code.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use crate::stream::queue::fake_stream_queue;
    use crate::tunnel::circuit::test::fake_mpsc;
    use crate::{congestion::sendme::StreamSendWindow, stream::DataCmdChecker};

    /// `wrapping_next_stream_id` increments, and wraps from 0xffff to 1 (skipping 0).
    #[test]
    fn test_wrapping_next_stream_id() {
        let one = StreamId::new(1).unwrap();
        let two = StreamId::new(2).unwrap();
        let max = StreamId::new(0xffff).unwrap();
        assert_eq!(wrapping_next_stream_id(one), two);
        assert_eq!(wrapping_next_stream_id(max), one);
    }

    /// Walk streams through the open / END-received / END-sent states.
    #[test]
    #[allow(clippy::cognitive_complexity)]
    fn streammap_basics() -> Result<()> {
        let mut map = StreamMap::new();
        let mut next_id = map.next_stream_id;
        let mut ids = Vec::new();

        assert_eq!(map.n_open_streams(), 0);

        // Add 128 streams: IDs must be allocated sequentially from `next_stream_id`.
        for n in 1..=128 {
            let (sink, _) = fake_stream_queue(128);
            let (_, rx) = fake_mpsc(2);
            let id = map.add_ent(
                sink,
                rx,
                StreamFlowControl::new_window_based(StreamSendWindow::new(500)),
                DataCmdChecker::new_any(),
            )?;
            let expect_id: StreamId = next_id;
            assert_eq!(expect_id, id);
            next_id = wrapping_next_stream_id(next_id);
            ids.push(id);
            assert_eq!(map.n_open_streams(), n);
        }

        // `next_id` was never allocated, so it names no stream.
        let nonesuch_id = next_id;
        assert!(matches!(
            map.get_mut(ids[0]),
            Some(StreamEntMut::Open { .. })
        ));
        assert!(map.get_mut(nonesuch_id).is_none());

        // Receiving END: rejected for unknown IDs; moves an open stream to
        // EndReceived; rejected a second time.
        assert!(map.ending_msg_received(nonesuch_id).is_err());
        assert_eq!(map.n_open_streams(), 128);
        assert!(map.ending_msg_received(ids[1]).is_ok());
        assert_eq!(map.n_open_streams(), 127);
        assert!(matches!(
            map.get_mut(ids[1]),
            Some(StreamEntMut::EndReceived)
        ));
        assert!(map.ending_msg_received(ids[1]).is_err());

        // Terminating: rejected for unknown IDs; an open stream becomes
        // EndSent and needs an END sent.
        use TerminateReason as TR;
        assert!(map.terminate(nonesuch_id, TR::ExplicitEnd).is_err());
        assert_eq!(map.n_open_streams(), 127);
        assert_eq!(
            map.terminate(ids[2], TR::ExplicitEnd).unwrap(),
            ShouldSendEnd::Send
        );
        assert_eq!(map.n_open_streams(), 126);
        assert!(matches!(
            map.get_mut(ids[2]),
            Some(StreamEntMut::EndSent { .. })
        ));
        // Terminating an EndReceived stream needs no END, and forgets the stream.
        assert_eq!(
            map.terminate(ids[1], TR::ExplicitEnd).unwrap(),
            ShouldSendEnd::DontSend
        );
        assert_eq!(map.n_open_streams(), 126);
        assert!(map.get_mut(ids[1]).is_none());

        // Receiving END on an EndSent stream forgets it too.
        assert!(map.ending_msg_received(ids[2]).is_ok());
        assert!(map.get_mut(ids[2]).is_none());
        assert_eq!(map.n_open_streams(), 126);

        Ok(())
    }
}