tor_proto/tunnel/
streammap.rs

1//! Types and code for mapping StreamIDs to streams on a circuit.
2
3use crate::congestion::sendme;
4use crate::stream::{AnyCmdChecker, StreamSendFlowControl};
5use crate::tunnel::circuit::{StreamMpscReceiver, StreamMpscSender};
6use crate::tunnel::halfstream::HalfStream;
7use crate::tunnel::reactor::circuit::RECV_WINDOW_INIT;
8use crate::util::stream_poll_set::{KeyAlreadyInsertedError, StreamPollSet};
9use crate::{Error, Result};
10use pin_project::pin_project;
11use tor_async_utils::peekable_stream::{PeekableStream, UnobtrusivePeekableStream};
12use tor_async_utils::stream_peek::StreamUnobtrusivePeeker;
13use tor_cell::relaycell::{msg::AnyRelayMsg, StreamId};
14use tor_cell::relaycell::{RelayMsg, UnparsedRelayMsg};
15
16use std::collections::hash_map;
17use std::collections::HashMap;
18use std::num::NonZeroU16;
19use std::pin::Pin;
20use std::task::{Poll, Waker};
21use tor_error::{bad_api_usage, internal};
22
23use rand::Rng;
24
25use tracing::debug;
26
27/// Entry for an open stream
28///
29/// (For the purposes of this module, an open stream is one where we have not
30/// sent or received any message indicating that the stream is ended.)
31#[derive(Debug)]
32#[pin_project]
33pub(super) struct OpenStreamEnt {
34    /// Sink to send relay cells tagged for this stream into.
35    pub(super) sink: StreamMpscSender<UnparsedRelayMsg>,
36    /// Number of cells dropped due to the stream disappearing before we can
37    /// transform this into an `EndSent`.
38    pub(super) dropped: u16,
39    /// A `CmdChecker` used to tell whether cells on this stream are valid.
40    pub(super) cmd_checker: AnyCmdChecker,
41    /// Flow control for this stream.
42    // Non-pub because we need to proxy `put_for_incoming_sendme` to ensure
43    // `flow_ctrl_waker` is woken.
44    flow_ctrl: StreamSendFlowControl,
45    /// Stream for cells that should be sent down this stream.
46    // Not directly exposed. This should only be polled via
47    // `OpenStreamEntStream`s implementation of `Stream`, which in turn should
48    // only be used through `StreamPollSet`.
49    #[pin]
50    rx: StreamUnobtrusivePeeker<StreamMpscReceiver<AnyRelayMsg>>,
51    /// Waker to be woken when more sending capacity becomes available (e.g.
52    /// receiving a SENDME).
53    flow_ctrl_waker: Option<Waker>,
54}
55
56impl OpenStreamEnt {
57    /// Whether this stream is ready to send `msg`.
58    pub(crate) fn can_send<M: RelayMsg>(&self, msg: &M) -> bool {
59        self.flow_ctrl.can_send(msg)
60    }
61
62    /// Handle an incoming sendme.
63    ///
64    /// On failure, return an error: the caller should close the stream or
65    /// circuit with a protocol error.
66    pub(crate) fn put_for_incoming_sendme(&mut self, msg: UnparsedRelayMsg) -> Result<()> {
67        self.flow_ctrl.put_for_incoming_sendme(msg)?;
68        // Wake the stream if it was blocked on flow control.
69        if let Some(waker) = self.flow_ctrl_waker.take() {
70            waker.wake();
71        }
72        Ok(())
73    }
74
75    /// Handle an incoming XON message.
76    ///
77    /// On failure, return an error: the caller should close the stream or
78    /// circuit with a protocol error.
79    pub(crate) fn handle_incoming_xon(&mut self, msg: UnparsedRelayMsg) -> Result<()> {
80        self.flow_ctrl.handle_incoming_xon(msg)
81    }
82
83    /// Handle an incoming XOFF message.
84    ///
85    /// On failure, return an error: the caller should close the stream or
86    /// circuit with a protocol error.
87    pub(crate) fn handle_incoming_xoff(&mut self, msg: UnparsedRelayMsg) -> Result<()> {
88        self.flow_ctrl.handle_incoming_xoff(msg)
89    }
90
91    /// Take capacity to send `msg`. If there's insufficient capacity, returns
92    /// an error. Should be called at the point we've fully committed to
93    /// sending the message.
94    //
95    // TODO: Consider not exposing this, and instead taking the capacity in
96    // `StreamMap::take_ready_msg`.
97    pub(crate) fn take_capacity_to_send<M: RelayMsg>(&mut self, msg: &M) -> Result<()> {
98        self.flow_ctrl.take_capacity_to_send(msg)
99    }
100}
101
102/// Private wrapper over `OpenStreamEnt`. We implement `futures::Stream` for
103/// this wrapper, and not directly for `OpenStreamEnt`, so that client code
104/// can't directly access the stream.
105#[derive(Debug)]
106#[pin_project]
107struct OpenStreamEntStream {
108    /// Inner value.
109    #[pin]
110    inner: OpenStreamEnt,
111}
112
113impl futures::Stream for OpenStreamEntStream {
114    type Item = AnyRelayMsg;
115
116    fn poll_next(
117        mut self: std::pin::Pin<&mut Self>,
118        cx: &mut std::task::Context<'_>,
119    ) -> Poll<Option<Self::Item>> {
120        if !self.as_mut().poll_peek_mut(cx).is_ready() {
121            return Poll::Pending;
122        };
123        let res = self.project().inner.project().rx.poll_next(cx);
124        debug_assert!(res.is_ready());
125        // TODO: consider calling `inner.flow_ctrl.take_capacity_to_send` here;
126        // particularly if we change it to return a wrapper type that proves
127        // we've taken the capacity. Otherwise it'd make it tricky in the reactor
128        // to be sure we've correctly taken the capacity, since messages can originate
129        // in other parts of the code (currently none of those should be of types that
130        // count towards flow control, but that may change).
131        res
132    }
133}
134
135impl PeekableStream for OpenStreamEntStream {
136    fn poll_peek_mut(
137        self: Pin<&mut Self>,
138        cx: &mut std::task::Context<'_>,
139    ) -> Poll<Option<&mut <Self as futures::Stream>::Item>> {
140        let s = self.project();
141        let inner = s.inner.project();
142        let m = match inner.rx.poll_peek_mut(cx) {
143            Poll::Ready(Some(m)) => m,
144            Poll::Ready(None) => return Poll::Ready(None),
145            Poll::Pending => return Poll::Pending,
146        };
147        if !inner.flow_ctrl.can_send(m) {
148            inner.flow_ctrl_waker.replace(cx.waker().clone());
149            return Poll::Pending;
150        }
151        Poll::Ready(Some(m))
152    }
153}
154
155impl UnobtrusivePeekableStream for OpenStreamEntStream {
156    fn unobtrusive_peek_mut(
157        self: std::pin::Pin<&mut Self>,
158    ) -> Option<&mut <Self as futures::Stream>::Item> {
159        let s = self.project();
160        let inner = s.inner.project();
161        let m = inner.rx.unobtrusive_peek_mut()?;
162        if inner.flow_ctrl.can_send(m) {
163            Some(m)
164        } else {
165            None
166        }
167    }
168}
169
170/// Entry for a stream where we have sent an END, or other message
171/// indicating that the stream is terminated.
172#[derive(Debug)]
173pub(super) struct EndSentStreamEnt {
174    /// A "half-stream" that we use to check the validity of incoming
175    /// messages on this stream.
176    pub(super) half_stream: HalfStream,
177    /// True if the sender on this stream has been explicitly dropped;
178    /// false if we got an explicit close from `close_pending`
179    explicitly_dropped: bool,
180}
181
182/// The entry for a stream.
183#[derive(Debug)]
184enum ClosedStreamEnt {
185    /// A stream for which we have received an END cell, but not yet
186    /// had the stream object get dropped.
187    EndReceived,
188    /// A stream for which we have sent an END cell but not yet received an END
189    /// cell.
190    ///
191    /// TODO(arti#264) Can we ever throw this out? Do we really get END cells for
192    /// these?
193    EndSent(EndSentStreamEnt),
194}
195
196/// Mutable reference to a stream entry.
197pub(super) enum StreamEntMut<'a> {
198    /// An open stream.
199    Open(&'a mut OpenStreamEnt),
200    /// A stream for which we have received an END cell, but not yet
201    /// had the stream object get dropped.
202    EndReceived,
203    /// A stream for which we have sent an END cell but not yet received an END
204    /// cell.
205    EndSent(&'a mut EndSentStreamEnt),
206}
207
208impl<'a> From<&'a mut ClosedStreamEnt> for StreamEntMut<'a> {
209    fn from(value: &'a mut ClosedStreamEnt) -> Self {
210        match value {
211            ClosedStreamEnt::EndReceived => Self::EndReceived,
212            ClosedStreamEnt::EndSent(e) => Self::EndSent(e),
213        }
214    }
215}
216
217impl<'a> From<&'a mut OpenStreamEntStream> for StreamEntMut<'a> {
218    fn from(value: &'a mut OpenStreamEntStream) -> Self {
219        Self::Open(&mut value.inner)
220    }
221}
222
223/// Return value to indicate whether or not we send an END cell upon
224/// terminating a given stream.
225#[derive(Debug, Copy, Clone, Eq, PartialEq)]
226pub(super) enum ShouldSendEnd {
227    /// An END cell should be sent.
228    Send,
229    /// An END cell should not be sent.
230    DontSend,
231}
232
/// Scheduling priority of a stream within a [`StreamPollSet`]; a stream
/// given a larger value is placed later in line.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct Priority(u64);
236
237/// A map from stream IDs to stream entries. Each circuit has one for each
238/// hop.
239pub(super) struct StreamMap {
240    /// Open streams.
241    // Invariants:
242    // * Keys are disjoint with `closed_streams`.
243    open_streams: StreamPollSet<StreamId, Priority, OpenStreamEntStream>,
244    /// Closed streams.
245    // Invariants:
246    // * Keys are disjoint with `open_streams`.
247    closed_streams: HashMap<StreamId, ClosedStreamEnt>,
248    /// The next StreamId that we should use for a newly allocated
249    /// circuit.
250    next_stream_id: StreamId,
251    /// Next priority to use in `rxs`. We implement round-robin scheduling of
252    /// handling outgoing messages from streams by assigning a stream the next
253    /// priority whenever an outgoing message is processed from that stream,
254    /// putting it last in line.
255    next_priority: Priority,
256}
257
258impl StreamMap {
259    /// Make a new empty StreamMap.
260    pub(super) fn new() -> Self {
261        let mut rng = rand::rng();
262        let next_stream_id: NonZeroU16 = rng.random();
263        StreamMap {
264            open_streams: StreamPollSet::new(),
265            closed_streams: HashMap::new(),
266            next_stream_id: next_stream_id.into(),
267            next_priority: Priority(0),
268        }
269    }
270
271    /// Return the number of open streams in this map.
272    pub(super) fn n_open_streams(&self) -> usize {
273        self.open_streams.len()
274    }
275
276    /// Return the next available priority.
277    fn take_next_priority(&mut self) -> Priority {
278        let rv = self.next_priority;
279        self.next_priority = Priority(rv.0 + 1);
280        rv
281    }
282
283    /// Add an entry to this map; return the newly allocated StreamId.
284    pub(super) fn add_ent(
285        &mut self,
286        sink: StreamMpscSender<UnparsedRelayMsg>,
287        rx: StreamMpscReceiver<AnyRelayMsg>,
288        flow_ctrl: StreamSendFlowControl,
289        cmd_checker: AnyCmdChecker,
290    ) -> Result<StreamId> {
291        let mut stream_ent = OpenStreamEntStream {
292            inner: OpenStreamEnt {
293                sink,
294                flow_ctrl,
295                dropped: 0,
296                cmd_checker,
297                rx: StreamUnobtrusivePeeker::new(rx),
298                flow_ctrl_waker: None,
299            },
300        };
301        let priority = self.take_next_priority();
302        // This "65536" seems too aggressive, but it's what tor does.
303        //
304        // Also, going around in a loop here is (sadly) needed in order
305        // to look like Tor clients.
306        for _ in 1..=65536 {
307            let id: StreamId = self.next_stream_id;
308            self.next_stream_id = wrapping_next_stream_id(self.next_stream_id);
309            stream_ent = match self.open_streams.try_insert(id, priority, stream_ent) {
310                Ok(_) => return Ok(id),
311                Err(KeyAlreadyInsertedError {
312                    key: _,
313                    priority: _,
314                    stream,
315                }) => stream,
316            };
317        }
318
319        Err(Error::IdRangeFull)
320    }
321
322    /// Add an entry to this map using the specified StreamId.
323    #[cfg(feature = "hs-service")]
324    pub(super) fn add_ent_with_id(
325        &mut self,
326        sink: StreamMpscSender<UnparsedRelayMsg>,
327        rx: StreamMpscReceiver<AnyRelayMsg>,
328        flow_ctrl: StreamSendFlowControl,
329        id: StreamId,
330        cmd_checker: AnyCmdChecker,
331    ) -> Result<()> {
332        let stream_ent = OpenStreamEntStream {
333            inner: OpenStreamEnt {
334                sink,
335                flow_ctrl,
336                dropped: 0,
337                cmd_checker,
338                rx: StreamUnobtrusivePeeker::new(rx),
339                flow_ctrl_waker: None,
340            },
341        };
342        let priority = self.take_next_priority();
343        self.open_streams
344            .try_insert(id, priority, stream_ent)
345            .map_err(|_| Error::IdUnavailable(id))
346    }
347
348    /// Return the entry for `id` in this map, if any.
349    pub(super) fn get_mut(&mut self, id: StreamId) -> Option<StreamEntMut<'_>> {
350        if let Some(e) = self.open_streams.stream_mut(&id) {
351            return Some(e.into());
352        }
353        if let Some(e) = self.closed_streams.get_mut(&id) {
354            return Some(e.into());
355        }
356        None
357    }
358
359    /// Note that we received an END message (or other message indicating the end of
360    /// the stream) on the stream with `id`.
361    ///
362    /// Returns true if there was really a stream there.
363    pub(super) fn ending_msg_received(&mut self, id: StreamId) -> Result<()> {
364        if self.open_streams.remove(&id).is_some() {
365            let prev = self.closed_streams.insert(id, ClosedStreamEnt::EndReceived);
366            debug_assert!(prev.is_none(), "Unexpected duplicate entry for {id}");
367            return Ok(());
368        }
369        let hash_map::Entry::Occupied(closed_entry) = self.closed_streams.entry(id) else {
370            return Err(Error::CircProto(
371                "Received END cell on nonexistent stream".into(),
372            ));
373        };
374        // Progress the stream's state machine accordingly
375        match closed_entry.get() {
376            ClosedStreamEnt::EndReceived => Err(Error::CircProto(
377                "Received two END cells on same stream".into(),
378            )),
379            ClosedStreamEnt::EndSent { .. } => {
380                debug!("Actually got an end cell on a half-closed stream!");
381                // We got an END, and we already sent an END. Great!
382                // we can forget about this stream.
383                closed_entry.remove_entry();
384                Ok(())
385            }
386        }
387    }
388
389    /// Handle a termination of the stream with `id` from this side of
390    /// the circuit. Return true if the stream was open and an END
391    /// ought to be sent.
392    pub(super) fn terminate(
393        &mut self,
394        id: StreamId,
395        why: TerminateReason,
396    ) -> Result<ShouldSendEnd> {
397        use TerminateReason as TR;
398
399        if let Some((_id, _priority, ent)) = self.open_streams.remove(&id) {
400            let OpenStreamEntStream {
401                inner:
402                    OpenStreamEnt {
403                        flow_ctrl,
404                        dropped,
405                        cmd_checker,
406                        // notably absent: the channels for sink and stream, which will get dropped and
407                        // closed (meaning reads/writes from/to this stream will now fail)
408                        ..
409                    },
410            } = ent;
411            // FIXME(eta): we don't copy the receive window, instead just creating a new one,
412            //             so a malicious peer can send us slightly more data than they should
413            //             be able to; see arti#230.
414            let mut recv_window = sendme::StreamRecvWindow::new(RECV_WINDOW_INIT);
415            recv_window.decrement_n(dropped)?;
416            // TODO: would be nice to avoid new_ref.
417            let half_stream = HalfStream::new(flow_ctrl, recv_window, cmd_checker);
418            let explicitly_dropped = why == TR::StreamTargetClosed;
419            let prev = self.closed_streams.insert(
420                id,
421                ClosedStreamEnt::EndSent(EndSentStreamEnt {
422                    half_stream,
423                    explicitly_dropped,
424                }),
425            );
426            debug_assert!(prev.is_none(), "Unexpected duplicate entry for {id}");
427            return Ok(ShouldSendEnd::Send);
428        }
429
430        // Progress the stream's state machine accordingly
431        match self
432            .closed_streams
433            .remove(&id)
434            .ok_or_else(|| Error::from(internal!("Somehow we terminated a nonexistent stream?")))?
435        {
436            ClosedStreamEnt::EndReceived => Ok(ShouldSendEnd::DontSend),
437            ClosedStreamEnt::EndSent(EndSentStreamEnt {
438                ref mut explicitly_dropped,
439                ..
440            }) => match (*explicitly_dropped, why) {
441                (false, TR::StreamTargetClosed) => {
442                    *explicitly_dropped = true;
443                    Ok(ShouldSendEnd::DontSend)
444                }
445                (true, TR::StreamTargetClosed) => {
446                    Err(bad_api_usage!("Tried to close an already closed stream.").into())
447                }
448                (_, TR::ExplicitEnd) => Err(bad_api_usage!(
449                    "Tried to end an already closed stream. (explicitly_dropped={:?})",
450                    *explicitly_dropped
451                )
452                .into()),
453            },
454        }
455    }
456
457    /// Get an up-to-date iterator of streams with ready items. `Option<AnyRelayMsg>::None`
458    /// indicates that the local sender has been dropped.
459    ///
460    /// Conceptually all streams are in a queue; new streams are added to the
461    /// back of the queue, and a stream is sent to the back of the queue
462    /// whenever a ready message is taken from it (via
463    /// [`Self::take_ready_msg`]). The returned iterator is an ordered view of
464    /// this queue, showing the subset of streams that have a message ready to
465    /// send, or whose sender has been dropped.
466    pub(super) fn poll_ready_streams_iter<'a>(
467        &'a mut self,
468        cx: &mut std::task::Context,
469    ) -> impl Iterator<Item = (StreamId, Option<&'a AnyRelayMsg>)> + 'a {
470        self.open_streams
471            .poll_ready_iter_mut(cx)
472            .map(|(sid, _priority, ent)| {
473                let ent = Pin::new(ent);
474                let msg = ent.unobtrusive_peek();
475                (*sid, msg)
476            })
477    }
478
479    /// If the stream `sid` has a message ready, take it, and reprioritize `sid`
480    /// to the "back of the line" with respect to
481    /// [`Self::poll_ready_streams_iter`].
482    pub(super) fn take_ready_msg(&mut self, sid: StreamId) -> Option<AnyRelayMsg> {
483        let new_priority = self.take_next_priority();
484        let (_prev_priority, val) = self
485            .open_streams
486            .take_ready_value_and_reprioritize(&sid, new_priority)?;
487        Some(val)
488    }
489
490    // TODO: Eventually if we want relay support, we'll need to support
491    // stream IDs chosen by somebody else. But for now, we don't need those.
492}
493
494/// A reason for terminating a stream.
495///
496/// We use this type in order to ensure that we obey the API restrictions of [`StreamMap::terminate`]
497#[derive(Copy, Clone, Debug, PartialEq, Eq)]
498pub(super) enum TerminateReason {
499    /// Closing a stream because the receiver got `Ok(None)`, indicating that the
500    /// corresponding senders were all dropped.
501    StreamTargetClosed,
502    /// Closing a stream because we were explicitly told to end it via
503    /// [`StreamTarget::close_pending`](crate::tunnel::StreamTarget::close_pending).
504    ExplicitEnd,
505}
506
507/// Convenience function for doing a wrapping increment of a `StreamId`.
508fn wrapping_next_stream_id(id: StreamId) -> StreamId {
509    let next_val = NonZeroU16::from(id)
510        .checked_add(1)
511        .unwrap_or_else(|| NonZeroU16::new(1).expect("Impossibly got 0 value"));
512    next_val.into()
513}
514
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use crate::tunnel::circuit::test::fake_mpsc;
    use crate::{congestion::sendme::StreamSendWindow, stream::DataCmdChecker};

    /// Add a new open stream with throwaway channels to `map`, returning the
    /// ID that `add_ent` allocated for it.
    fn add_dummy_stream(map: &mut StreamMap) -> Result<StreamId> {
        let (sink, _) = fake_mpsc(128);
        let (_, rx) = fake_mpsc(2);
        let flow_ctrl = StreamSendFlowControl::new_window_based(StreamSendWindow::new(500));
        map.add_ent(sink, rx, flow_ctrl, DataCmdChecker::new_any())
    }

    #[test]
    fn test_wrapping_next_stream_id() {
        let sid = |n| StreamId::new(n).unwrap();
        assert_eq!(wrapping_next_stream_id(sid(1)), sid(2));
        // The successor of the largest ID wraps around to 1, skipping 0.
        assert_eq!(wrapping_next_stream_id(sid(0xffff)), sid(1));
    }

    #[test]
    #[allow(clippy::cognitive_complexity)]
    fn streammap_basics() -> Result<()> {
        use TerminateReason as TR;

        let mut map = StreamMap::new();
        let mut expected_id = map.next_stream_id;
        let mut ids = Vec::new();
        assert_eq!(map.n_open_streams(), 0);

        // add_ent hands out sequential IDs, starting at `next_stream_id`.
        for count in 1..=128 {
            let id = add_dummy_stream(&mut map)?;
            assert_eq!(id, expected_id);
            ids.push(id);
            expected_id = wrapping_next_stream_id(expected_id);
            assert_eq!(map.n_open_streams(), count);
        }

        // get_mut finds open streams, and nothing for an unused ID.
        let unused_id = expected_id;
        assert!(matches!(map.get_mut(ids[0]), Some(StreamEntMut::Open(_))));
        assert!(map.get_mut(unused_id).is_none());

        // ending_msg_received: unknown streams are an error, open streams move
        // to EndReceived, and a second END is an error.
        assert!(map.ending_msg_received(unused_id).is_err());
        assert_eq!(map.n_open_streams(), 128);
        assert!(map.ending_msg_received(ids[1]).is_ok());
        assert_eq!(map.n_open_streams(), 127);
        assert!(matches!(
            map.get_mut(ids[1]),
            Some(StreamEntMut::EndReceived)
        ));
        assert!(map.ending_msg_received(ids[1]).is_err());

        // terminate: unknown streams are an error; open streams move to
        // EndSent and want an END sent.
        assert!(map.terminate(unused_id, TR::ExplicitEnd).is_err());
        assert_eq!(map.n_open_streams(), 127);
        assert_eq!(
            map.terminate(ids[2], TR::ExplicitEnd).unwrap(),
            ShouldSendEnd::Send
        );
        assert_eq!(map.n_open_streams(), 126);
        assert!(matches!(map.get_mut(ids[2]), Some(StreamEntMut::EndSent(_))));

        // Terminating a stream that already got an END needs no END from us,
        // and forgets the stream. It was already not counted as open.
        assert_eq!(
            map.terminate(ids[1], TR::ExplicitEnd).unwrap(),
            ShouldSendEnd::DontSend
        );
        assert_eq!(map.n_open_streams(), 126);
        assert!(map.get_mut(ids[1]).is_none());

        // An END arriving after we terminated the stream lets us forget it.
        assert!(map.ending_msg_received(ids[2]).is_ok());
        assert!(map.get_mut(ids[2]).is_none());
        assert_eq!(map.n_open_streams(), 126);

        Ok(())
    }
}