1use crate::congestion::sendme;
4use crate::stream::{AnyCmdChecker, StreamSendFlowControl};
5use crate::tunnel::circuit::{StreamMpscReceiver, StreamMpscSender};
6use crate::tunnel::halfstream::HalfStream;
7use crate::tunnel::reactor::circuit::RECV_WINDOW_INIT;
8use crate::util::stream_poll_set::{KeyAlreadyInsertedError, StreamPollSet};
9use crate::{Error, Result};
10use pin_project::pin_project;
11use tor_async_utils::peekable_stream::{PeekableStream, UnobtrusivePeekableStream};
12use tor_async_utils::stream_peek::StreamUnobtrusivePeeker;
13use tor_cell::relaycell::{msg::AnyRelayMsg, StreamId};
14use tor_cell::relaycell::{RelayMsg, UnparsedRelayMsg};
15
16use std::collections::hash_map;
17use std::collections::HashMap;
18use std::num::NonZeroU16;
19use std::pin::Pin;
20use std::task::{Poll, Waker};
21use tor_error::{bad_api_usage, internal};
22
23use rand::Rng;
24
25use tracing::debug;
26
/// State for a stream that is fully open: neither side has ended it yet.
///
/// Holds the stream's channel ends, its send-side flow control, and the waker of any
/// task that is blocked because flow control does not yet permit the next queued message.
#[derive(Debug)]
#[pin_project]
pub(super) struct OpenStreamEnt {
    /// Sending half of a channel carrying `UnparsedRelayMsg`s for this stream
    /// (presumably incoming relay messages forwarded to the stream's owner — confirm at callers).
    pub(super) sink: StreamMpscSender<UnparsedRelayMsg>,
    /// Count of cells that are charged against the receive window when this stream is
    /// terminated (see `StreamMap::terminate`); starts at 0.
    pub(super) dropped: u16,
    /// Command checker for this stream; handed on to the `HalfStream` when the stream is
    /// terminated.
    pub(super) cmd_checker: AnyCmdChecker,
    /// Send-side flow control, consulted before any queued message is released.
    flow_ctrl: StreamSendFlowControl,
    /// Queue of messages waiting to be sent on this stream, wrapped so that the head
    /// message can be peeked without being consumed.
    #[pin]
    rx: StreamUnobtrusivePeeker<StreamMpscReceiver<AnyRelayMsg>>,
    /// Waker registered when the head message was blocked by flow control; taken and woken
    /// once flow control may have opened up.
    flow_ctrl_waker: Option<Waker>,
}
55
56impl OpenStreamEnt {
57 pub(crate) fn can_send<M: RelayMsg>(&self, msg: &M) -> bool {
59 self.flow_ctrl.can_send(msg)
60 }
61
62 pub(crate) fn put_for_incoming_sendme(&mut self, msg: UnparsedRelayMsg) -> Result<()> {
67 self.flow_ctrl.put_for_incoming_sendme(msg)?;
68 if let Some(waker) = self.flow_ctrl_waker.take() {
70 waker.wake();
71 }
72 Ok(())
73 }
74
75 pub(crate) fn handle_incoming_xon(&mut self, msg: UnparsedRelayMsg) -> Result<()> {
80 self.flow_ctrl.handle_incoming_xon(msg)
81 }
82
83 pub(crate) fn handle_incoming_xoff(&mut self, msg: UnparsedRelayMsg) -> Result<()> {
88 self.flow_ctrl.handle_incoming_xoff(msg)
89 }
90
91 pub(crate) fn take_capacity_to_send<M: RelayMsg>(&mut self, msg: &M) -> Result<()> {
98 self.flow_ctrl.take_capacity_to_send(msg)
99 }
100}
101
/// Wrapper around an [`OpenStreamEnt`] that acts as a (peekable) stream of outgoing messages.
///
/// It only yields or exposes a message when the entry's send-side flow control permits
/// sending it.
#[derive(Debug)]
#[pin_project]
struct OpenStreamEntStream {
    /// The wrapped open-stream state.
    #[pin]
    inner: OpenStreamEnt,
}
112
113impl futures::Stream for OpenStreamEntStream {
114 type Item = AnyRelayMsg;
115
116 fn poll_next(
117 mut self: std::pin::Pin<&mut Self>,
118 cx: &mut std::task::Context<'_>,
119 ) -> Poll<Option<Self::Item>> {
120 if !self.as_mut().poll_peek_mut(cx).is_ready() {
121 return Poll::Pending;
122 };
123 let res = self.project().inner.project().rx.poll_next(cx);
124 debug_assert!(res.is_ready());
125 res
132 }
133}
134
135impl PeekableStream for OpenStreamEntStream {
136 fn poll_peek_mut(
137 self: Pin<&mut Self>,
138 cx: &mut std::task::Context<'_>,
139 ) -> Poll<Option<&mut <Self as futures::Stream>::Item>> {
140 let s = self.project();
141 let inner = s.inner.project();
142 let m = match inner.rx.poll_peek_mut(cx) {
143 Poll::Ready(Some(m)) => m,
144 Poll::Ready(None) => return Poll::Ready(None),
145 Poll::Pending => return Poll::Pending,
146 };
147 if !inner.flow_ctrl.can_send(m) {
148 inner.flow_ctrl_waker.replace(cx.waker().clone());
149 return Poll::Pending;
150 }
151 Poll::Ready(Some(m))
152 }
153}
154
155impl UnobtrusivePeekableStream for OpenStreamEntStream {
156 fn unobtrusive_peek_mut(
157 self: std::pin::Pin<&mut Self>,
158 ) -> Option<&mut <Self as futures::Stream>::Item> {
159 let s = self.project();
160 let inner = s.inner.project();
161 let m = inner.rx.unobtrusive_peek_mut()?;
162 if inner.flow_ctrl.can_send(m) {
163 Some(m)
164 } else {
165 None
166 }
167 }
168}
169
/// State for a stream on which we have sent an END but have not yet received one.
#[derive(Debug)]
pub(super) struct EndSentStreamEnt {
    /// Leftover state of the closed stream (built from its flow control, receive window
    /// and command checker in `StreamMap::terminate`), presumably used to handle messages
    /// that still arrive for it.
    pub(super) half_stream: HalfStream,
    /// True once `StreamMap::terminate` has been called with
    /// `TerminateReason::StreamTargetClosed` for this stream.
    explicitly_dropped: bool,
}
181
/// State for a stream that has been ended in exactly one direction.
///
/// Once both directions have ended, the entry is removed from the map altogether.
#[derive(Debug)]
enum ClosedStreamEnt {
    /// We received an END on this stream and have not sent one.
    EndReceived,
    /// We sent an END on this stream and are waiting for the other side's END.
    EndSent(EndSentStreamEnt),
}
195
/// A mutable view of one stream's entry in a [`StreamMap`], whichever state it is in.
pub(super) enum StreamEntMut<'a> {
    /// The stream is open in both directions.
    Open(&'a mut OpenStreamEnt),
    /// We received an END on the stream but have not sent one.
    EndReceived,
    /// We sent an END on the stream but have not received one.
    EndSent(&'a mut EndSentStreamEnt),
}
207
208impl<'a> From<&'a mut ClosedStreamEnt> for StreamEntMut<'a> {
209 fn from(value: &'a mut ClosedStreamEnt) -> Self {
210 match value {
211 ClosedStreamEnt::EndReceived => Self::EndReceived,
212 ClosedStreamEnt::EndSent(e) => Self::EndSent(e),
213 }
214 }
215}
216
217impl<'a> From<&'a mut OpenStreamEntStream> for StreamEntMut<'a> {
218 fn from(value: &'a mut OpenStreamEntStream) -> Self {
219 Self::Open(&mut value.inner)
220 }
221}
222
/// Whether the caller of `StreamMap::terminate` must now send an END message.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(super) enum ShouldSendEnd {
    /// The stream was open: an END should be sent.
    Send,
    /// The stream was already ended in some direction: no END should be sent.
    DontSend,
}
232
/// Scheduling priority of an open stream within the [`StreamMap`]'s poll set.
///
/// Values come from a counter that only increases (see `StreamMap::take_next_priority`).
/// NOTE(review): presumably lower values are served first by `StreamPollSet` — confirm.
#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]
struct Priority(u64);
236
/// The set of streams on a single hop, keyed by stream ID.
pub(super) struct StreamMap {
    /// Streams open in both directions, polled for sendable messages by priority.
    open_streams: StreamPollSet<StreamId, Priority, OpenStreamEntStream>,
    /// Streams that have been ended in exactly one direction.
    closed_streams: HashMap<StreamId, ClosedStreamEnt>,
    /// The next stream ID that `add_ent` will try to allocate.
    next_stream_id: StreamId,
    /// The priority to hand out next (see `take_next_priority`).
    next_priority: Priority,
}
257
258impl StreamMap {
259 pub(super) fn new() -> Self {
261 let mut rng = rand::rng();
262 let next_stream_id: NonZeroU16 = rng.random();
263 StreamMap {
264 open_streams: StreamPollSet::new(),
265 closed_streams: HashMap::new(),
266 next_stream_id: next_stream_id.into(),
267 next_priority: Priority(0),
268 }
269 }
270
271 pub(super) fn n_open_streams(&self) -> usize {
273 self.open_streams.len()
274 }
275
276 fn take_next_priority(&mut self) -> Priority {
278 let rv = self.next_priority;
279 self.next_priority = Priority(rv.0 + 1);
280 rv
281 }
282
283 pub(super) fn add_ent(
285 &mut self,
286 sink: StreamMpscSender<UnparsedRelayMsg>,
287 rx: StreamMpscReceiver<AnyRelayMsg>,
288 flow_ctrl: StreamSendFlowControl,
289 cmd_checker: AnyCmdChecker,
290 ) -> Result<StreamId> {
291 let mut stream_ent = OpenStreamEntStream {
292 inner: OpenStreamEnt {
293 sink,
294 flow_ctrl,
295 dropped: 0,
296 cmd_checker,
297 rx: StreamUnobtrusivePeeker::new(rx),
298 flow_ctrl_waker: None,
299 },
300 };
301 let priority = self.take_next_priority();
302 for _ in 1..=65536 {
307 let id: StreamId = self.next_stream_id;
308 self.next_stream_id = wrapping_next_stream_id(self.next_stream_id);
309 stream_ent = match self.open_streams.try_insert(id, priority, stream_ent) {
310 Ok(_) => return Ok(id),
311 Err(KeyAlreadyInsertedError {
312 key: _,
313 priority: _,
314 stream,
315 }) => stream,
316 };
317 }
318
319 Err(Error::IdRangeFull)
320 }
321
322 #[cfg(feature = "hs-service")]
324 pub(super) fn add_ent_with_id(
325 &mut self,
326 sink: StreamMpscSender<UnparsedRelayMsg>,
327 rx: StreamMpscReceiver<AnyRelayMsg>,
328 flow_ctrl: StreamSendFlowControl,
329 id: StreamId,
330 cmd_checker: AnyCmdChecker,
331 ) -> Result<()> {
332 let stream_ent = OpenStreamEntStream {
333 inner: OpenStreamEnt {
334 sink,
335 flow_ctrl,
336 dropped: 0,
337 cmd_checker,
338 rx: StreamUnobtrusivePeeker::new(rx),
339 flow_ctrl_waker: None,
340 },
341 };
342 let priority = self.take_next_priority();
343 self.open_streams
344 .try_insert(id, priority, stream_ent)
345 .map_err(|_| Error::IdUnavailable(id))
346 }
347
348 pub(super) fn get_mut(&mut self, id: StreamId) -> Option<StreamEntMut<'_>> {
350 if let Some(e) = self.open_streams.stream_mut(&id) {
351 return Some(e.into());
352 }
353 if let Some(e) = self.closed_streams.get_mut(&id) {
354 return Some(e.into());
355 }
356 None
357 }
358
359 pub(super) fn ending_msg_received(&mut self, id: StreamId) -> Result<()> {
364 if self.open_streams.remove(&id).is_some() {
365 let prev = self.closed_streams.insert(id, ClosedStreamEnt::EndReceived);
366 debug_assert!(prev.is_none(), "Unexpected duplicate entry for {id}");
367 return Ok(());
368 }
369 let hash_map::Entry::Occupied(closed_entry) = self.closed_streams.entry(id) else {
370 return Err(Error::CircProto(
371 "Received END cell on nonexistent stream".into(),
372 ));
373 };
374 match closed_entry.get() {
376 ClosedStreamEnt::EndReceived => Err(Error::CircProto(
377 "Received two END cells on same stream".into(),
378 )),
379 ClosedStreamEnt::EndSent { .. } => {
380 debug!("Actually got an end cell on a half-closed stream!");
381 closed_entry.remove_entry();
384 Ok(())
385 }
386 }
387 }
388
389 pub(super) fn terminate(
393 &mut self,
394 id: StreamId,
395 why: TerminateReason,
396 ) -> Result<ShouldSendEnd> {
397 use TerminateReason as TR;
398
399 if let Some((_id, _priority, ent)) = self.open_streams.remove(&id) {
400 let OpenStreamEntStream {
401 inner:
402 OpenStreamEnt {
403 flow_ctrl,
404 dropped,
405 cmd_checker,
406 ..
409 },
410 } = ent;
411 let mut recv_window = sendme::StreamRecvWindow::new(RECV_WINDOW_INIT);
415 recv_window.decrement_n(dropped)?;
416 let half_stream = HalfStream::new(flow_ctrl, recv_window, cmd_checker);
418 let explicitly_dropped = why == TR::StreamTargetClosed;
419 let prev = self.closed_streams.insert(
420 id,
421 ClosedStreamEnt::EndSent(EndSentStreamEnt {
422 half_stream,
423 explicitly_dropped,
424 }),
425 );
426 debug_assert!(prev.is_none(), "Unexpected duplicate entry for {id}");
427 return Ok(ShouldSendEnd::Send);
428 }
429
430 match self
432 .closed_streams
433 .remove(&id)
434 .ok_or_else(|| Error::from(internal!("Somehow we terminated a nonexistent stream?")))?
435 {
436 ClosedStreamEnt::EndReceived => Ok(ShouldSendEnd::DontSend),
437 ClosedStreamEnt::EndSent(EndSentStreamEnt {
438 ref mut explicitly_dropped,
439 ..
440 }) => match (*explicitly_dropped, why) {
441 (false, TR::StreamTargetClosed) => {
442 *explicitly_dropped = true;
443 Ok(ShouldSendEnd::DontSend)
444 }
445 (true, TR::StreamTargetClosed) => {
446 Err(bad_api_usage!("Tried to close an already closed stream.").into())
447 }
448 (_, TR::ExplicitEnd) => Err(bad_api_usage!(
449 "Tried to end an already closed stream. (explicitly_dropped={:?})",
450 *explicitly_dropped
451 )
452 .into()),
453 },
454 }
455 }
456
457 pub(super) fn poll_ready_streams_iter<'a>(
467 &'a mut self,
468 cx: &mut std::task::Context,
469 ) -> impl Iterator<Item = (StreamId, Option<&'a AnyRelayMsg>)> + 'a {
470 self.open_streams
471 .poll_ready_iter_mut(cx)
472 .map(|(sid, _priority, ent)| {
473 let ent = Pin::new(ent);
474 let msg = ent.unobtrusive_peek();
475 (*sid, msg)
476 })
477 }
478
479 pub(super) fn take_ready_msg(&mut self, sid: StreamId) -> Option<AnyRelayMsg> {
483 let new_priority = self.take_next_priority();
484 let (_prev_priority, val) = self
485 .open_streams
486 .take_ready_value_and_reprioritize(&sid, new_priority)?;
487 Some(val)
488 }
489
490 }
493
/// Why `StreamMap::terminate` is being called on a stream.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(super) enum TerminateReason {
    /// The stream's local target went away (presumably the stream object was dropped);
    /// recorded in `EndSentStreamEnt::explicitly_dropped`.
    StreamTargetClosed,
    /// An explicit END was requested for the stream.
    ExplicitEnd,
}
506
507fn wrapping_next_stream_id(id: StreamId) -> StreamId {
509 let next_val = NonZeroU16::from(id)
510 .checked_add(1)
511 .unwrap_or_else(|| NonZeroU16::new(1).expect("Impossibly got 0 value"));
512 next_val.into()
513}
514
#[cfg(test)]
mod test {
    // Lints relaxed for test code.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use crate::tunnel::circuit::test::fake_mpsc;
    use crate::{congestion::sendme::StreamSendWindow, stream::DataCmdChecker};

    /// Stream IDs advance by one and wrap from 0xffff to 1.
    #[test]
    fn test_wrapping_next_stream_id() {
        let one = StreamId::new(1).unwrap();
        let two = StreamId::new(2).unwrap();
        let max = StreamId::new(0xffff).unwrap();
        assert_eq!(wrapping_next_stream_id(one), two);
        assert_eq!(wrapping_next_stream_id(max), one);
    }

    /// Walk a `StreamMap` through allocation, END receipt, and termination.
    #[test]
    #[allow(clippy::cognitive_complexity)]
    fn streammap_basics() -> Result<()> {
        let mut map = StreamMap::new();
        let mut next_id = map.next_stream_id;
        let mut ids = Vec::new();

        assert_eq!(map.n_open_streams(), 0);

        // Allocate 128 streams; IDs must be handed out sequentially from `next_stream_id`.
        for n in 1..=128 {
            let (sink, _) = fake_mpsc(128);
            let (_, rx) = fake_mpsc(2);
            let id = map.add_ent(
                sink,
                rx,
                StreamSendFlowControl::new_window_based(StreamSendWindow::new(500)),
                DataCmdChecker::new_any(),
            )?;
            let expect_id: StreamId = next_id;
            assert_eq!(expect_id, id);
            next_id = wrapping_next_stream_id(next_id);
            ids.push(id);
            assert_eq!(map.n_open_streams(), n);
        }

        // The ID after the last allocated one is not in the map.
        let nonesuch_id = next_id;
        assert!(matches!(
            map.get_mut(ids[0]),
            Some(StreamEntMut::Open { .. })
        ));
        assert!(map.get_mut(nonesuch_id).is_none());

        // Receiving END: rejected for unknown streams; moves an open stream to
        // EndReceived; rejected a second time on the same stream.
        assert!(map.ending_msg_received(nonesuch_id).is_err());
        assert_eq!(map.n_open_streams(), 128);
        assert!(map.ending_msg_received(ids[1]).is_ok());
        assert_eq!(map.n_open_streams(), 127);
        assert!(matches!(
            map.get_mut(ids[1]),
            Some(StreamEntMut::EndReceived)
        ));
        assert!(map.ending_msg_received(ids[1]).is_err());

        // Terminating: rejected for unknown streams; an open stream moves to EndSent and
        // asks for an END; an EndReceived stream is forgotten without sending END.
        use TerminateReason as TR;
        assert!(map.terminate(nonesuch_id, TR::ExplicitEnd).is_err());
        assert_eq!(map.n_open_streams(), 127);
        assert_eq!(
            map.terminate(ids[2], TR::ExplicitEnd).unwrap(),
            ShouldSendEnd::Send
        );
        assert_eq!(map.n_open_streams(), 126);
        assert!(matches!(
            map.get_mut(ids[2]),
            Some(StreamEntMut::EndSent { .. })
        ));
        assert_eq!(
            map.terminate(ids[1], TR::ExplicitEnd).unwrap(),
            ShouldSendEnd::DontSend
        );
        assert_eq!(map.n_open_streams(), 126);
        assert!(map.get_mut(ids[1]).is_none());

        // Receiving END on an EndSent stream forgets it entirely.
        assert!(map.ending_msg_received(ids[2]).is_ok());
        assert!(map.get_mut(ids[2]).is_none());
        assert_eq!(map.n_open_streams(), 126);

        Ok(())
    }
}