tor_proto/tunnel/streammap.rs

//! Types and code for mapping StreamIDs to streams on a circuit.

use crate::congestion::sendme;
use crate::stream::{AnyCmdChecker, StreamSendFlowControl};
use crate::tunnel::circuit::{StreamMpscReceiver, StreamMpscSender};
use crate::tunnel::halfstream::HalfStream;
use crate::tunnel::reactor::circuit::RECV_WINDOW_INIT;
use crate::util::stream_poll_set::{KeyAlreadyInsertedError, StreamPollSet};
use crate::{Error, Result};
use pin_project::pin_project;
use tor_async_utils::peekable_stream::{PeekableStream, UnobtrusivePeekableStream};
use tor_async_utils::stream_peek::StreamUnobtrusivePeeker;
use tor_cell::relaycell::{msg::AnyRelayMsg, StreamId};
use tor_cell::relaycell::{RelayMsg, UnparsedRelayMsg};

use std::collections::hash_map;
use std::collections::HashMap;
use std::num::NonZeroU16;
use std::pin::Pin;
use std::task::{Poll, Waker};
use tor_error::{bad_api_usage, internal};

use rand::Rng;

use tracing::debug;

/// Entry for an open stream
///
/// (For the purposes of this module, an open stream is one where we have not
/// sent or received any message indicating that the stream is ended.)
#[derive(Debug)]
#[pin_project]
pub(super) struct OpenStreamEnt {
    /// Sink to send relay cells tagged for this stream into.
    pub(super) sink: StreamMpscSender<UnparsedRelayMsg>,
    /// Number of cells dropped due to the stream disappearing before we can
    /// transform this into an `EndSent`.
    pub(super) dropped: u16,
    /// A `CmdChecker` used to tell whether cells on this stream are valid.
    pub(super) cmd_checker: AnyCmdChecker,
    /// Flow control for this stream.
    // Non-pub because we need to proxy `put_for_incoming_sendme` to ensure
    // `flow_ctrl_waker` is woken.
    flow_ctrl: StreamSendFlowControl,
    /// Stream for cells that should be sent down this stream.
    // Not directly exposed. This should only be polled via
    // `OpenStreamEntStream`s implementation of `Stream`, which in turn should
    // only be used through `StreamPollSet`.
    #[pin]
    rx: StreamUnobtrusivePeeker<StreamMpscReceiver<AnyRelayMsg>>,
    /// Waker to be woken when more sending capacity becomes available (e.g.
    /// receiving a SENDME).
    flow_ctrl_waker: Option<Waker>,
}

impl OpenStreamEnt {
    /// Whether this stream is ready to send `msg`.
    pub(crate) fn can_send<M: RelayMsg>(&self, msg: &M) -> bool {
        self.flow_ctrl.can_send(msg)
    }

    /// Handle an incoming sendme.
    ///
    /// On success, credit the flow-control window and wake the stream if it
    /// was blocked waiting for send capacity.
    ///
    /// On failure, return an error: the caller should close the stream or
    /// circuit with a protocol error.
    pub(crate) fn put_for_incoming_sendme(&mut self) -> Result<()> {
        self.flow_ctrl.put_for_incoming_sendme()?;
        // Wake the stream if it was blocked on flow control.
        if let Some(waker) = self.flow_ctrl_waker.take() {
            waker.wake();
        }
        Ok(())
    }

    /// Take capacity to send `msg`. If there's insufficient capacity, returns
    /// an error. Should be called at the point we've fully committed to
    /// sending the message.
    //
    // TODO: Consider not exposing this, and instead taking the capacity in
    // `StreamMap::take_ready_msg`.
    pub(crate) fn take_capacity_to_send<M: RelayMsg>(&mut self, msg: &M) -> Result<()> {
        self.flow_ctrl.take_capacity_to_send(msg)
    }
}

/// Private wrapper over `OpenStreamEnt`. We implement `futures::Stream` for
/// this wrapper, and not directly for `OpenStreamEnt`, so that client code
/// can't directly access the stream.
#[derive(Debug)]
#[pin_project]
struct OpenStreamEntStream {
    /// Inner value.
    #[pin]
    inner: OpenStreamEnt,
}

impl futures::Stream for OpenStreamEntStream {
    type Item = AnyRelayMsg;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
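        // Peek before popping: `poll_peek_mut` only reports a message as ready
        // when flow control currently allows sending it, so `poll_next` never
        // yields a message we lack the capacity to send.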
        if !self.as_mut().poll_peek_mut(cx).is_ready() {
            return Poll::Pending;
        };
        let res = self.project().inner.project().rx.poll_next(cx);
        debug_assert!(res.is_ready());
        // TODO: consider calling `inner.flow_ctrl.take_capacity_to_send` here;
        // particularly if we change it to return a wrapper type that proves
        // we've taken the capacity. Otherwise it'd make it tricky in the reactor
        // to be sure we've correctly taken the capacity, since messages can originate
        // in other parts of the code (currently none of those should be of types that
        // count towards flow control, but that may change).
        res
    }
}

impl PeekableStream for OpenStreamEntStream {
    fn poll_peek_mut(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<&mut <Self as futures::Stream>::Item>> {
        let s = self.project();
        let inner = s.inner.project();
        let m = match inner.rx.poll_peek_mut(cx) {
            Poll::Ready(Some(m)) => m,
            Poll::Ready(None) => return Poll::Ready(None),
            Poll::Pending => return Poll::Pending,
        };
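        // A message is queued, but we only report it as ready if flow control
        // lets us send it right now. Otherwise, stash the waker so that
        // `put_for_incoming_sendme` can wake us once capacity arrives.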
        if !inner.flow_ctrl.can_send(m) {
            inner.flow_ctrl_waker.replace(cx.waker().clone());
            return Poll::Pending;
        }
        Poll::Ready(Some(m))
    }
}

impl UnobtrusivePeekableStream for OpenStreamEntStream {
    fn unobtrusive_peek_mut(
        self: std::pin::Pin<&mut Self>,
    ) -> Option<&mut <Self as futures::Stream>::Item> {
        let s = self.project();
        let inner = s.inner.project();
        let m = inner.rx.unobtrusive_peek_mut()?;
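        // As in `poll_peek_mut`, a queued message only counts as peekable if
        // flow control allows sending it. There is no `Context` here, so
        // (unlike `poll_peek_mut`) we cannot register `flow_ctrl_waker`.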
        if inner.flow_ctrl.can_send(m) {
            Some(m)
        } else {
            None
        }
    }
}

/// Entry for a stream where we have sent an END, or other message
/// indicating that the stream is terminated.
#[derive(Debug)]
pub(super) struct EndSentStreamEnt {
    /// A "half-stream" that we use to check the validity of incoming
    /// messages on this stream.
    pub(super) half_stream: HalfStream,
    /// True if the sender on this stream has been explicitly dropped;
    /// false if we got an explicit close from `close_pending`.
    explicitly_dropped: bool,
}

/// The entry for a stream that is no longer open.
#[derive(Debug)]
enum ClosedStreamEnt {
    /// A stream for which we have received an END cell, but not yet
    /// had the stream object get dropped.
    EndReceived,
    /// A stream for which we have sent an END cell but not yet received an END
    /// cell.
    ///
    /// TODO(arti#264) Can we ever throw this out? Do we really get END cells for
    /// these?
    EndSent(EndSentStreamEnt),
}

/// Mutable reference to a stream entry.
pub(super) enum StreamEntMut<'a> {
    /// An open stream.
    Open(&'a mut OpenStreamEnt),
    /// A stream for which we have received an END cell, but not yet
    /// had the stream object get dropped.
    EndReceived,
    /// A stream for which we have sent an END cell but not yet received an END
    /// cell.
    EndSent(&'a mut EndSentStreamEnt),
}

impl<'a> From<&'a mut ClosedStreamEnt> for StreamEntMut<'a> {
    fn from(value: &'a mut ClosedStreamEnt) -> Self {
        match value {
            ClosedStreamEnt::EndReceived => Self::EndReceived,
            ClosedStreamEnt::EndSent(e) => Self::EndSent(e),
        }
    }
}

impl<'a> From<&'a mut OpenStreamEntStream> for StreamEntMut<'a> {
    fn from(value: &'a mut OpenStreamEntStream) -> Self {
        Self::Open(&mut value.inner)
    }
}

/// Return value to indicate whether or not we send an END cell upon
/// terminating a given stream.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(super) enum ShouldSendEnd {
    /// An END cell should be sent.
    Send,
    /// An END cell should not be sent.
    DontSend,
}

/// A priority for use with [`StreamPollSet`].
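///
/// `StreamMap` hands out strictly increasing priorities: a stream is
/// reprioritized to the newest value whenever a message is taken from it,
/// which puts it at the back of the round-robin queue.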
#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]
struct Priority(u64);

/// A map from stream IDs to stream entries. Each circuit has one for each
/// hop.
pub(super) struct StreamMap {
    /// Open streams.
    // Invariants:
    // * Keys are disjoint with `closed_streams`.
    open_streams: StreamPollSet<StreamId, Priority, OpenStreamEntStream>,
    /// Closed streams.
    // Invariants:
    // * Keys are disjoint with `open_streams`.
    closed_streams: HashMap<StreamId, ClosedStreamEnt>,
    /// The next StreamId that we should use for a newly allocated
    /// stream.
    next_stream_id: StreamId,
    /// Next priority to use in `open_streams`. We implement round-robin
    /// scheduling of outgoing messages from streams by assigning a stream the
    /// next priority whenever an outgoing message is processed from that
    /// stream, putting it last in line.
    next_priority: Priority,
}

impl StreamMap {
    /// Make a new empty StreamMap.
    pub(super) fn new() -> Self {
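        // Choose a random nonzero starting point for stream ID allocation.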
        let mut rng = rand::rng();
        let next_stream_id: NonZeroU16 = rng.random();
        StreamMap {
            open_streams: StreamPollSet::new(),
            closed_streams: HashMap::new(),
            next_stream_id: next_stream_id.into(),
            next_priority: Priority(0),
        }
    }

    /// Return the number of open streams in this map.
    pub(super) fn n_open_streams(&self) -> usize {
        self.open_streams.len()
    }

    /// Return the next available priority.
    fn take_next_priority(&mut self) -> Priority {
        let rv = self.next_priority;
        self.next_priority = Priority(rv.0 + 1);
        rv
    }

    /// Add an entry to this map; return the newly allocated StreamId.
    pub(super) fn add_ent(
        &mut self,
        sink: StreamMpscSender<UnparsedRelayMsg>,
        rx: StreamMpscReceiver<AnyRelayMsg>,
        flow_ctrl: StreamSendFlowControl,
        cmd_checker: AnyCmdChecker,
    ) -> Result<StreamId> {
        let mut stream_ent = OpenStreamEntStream {
            inner: OpenStreamEnt {
                sink,
                flow_ctrl,
                dropped: 0,
                cmd_checker,
                rx: StreamUnobtrusivePeeker::new(rx),
                flow_ctrl_waker: None,
            },
        };
        let priority = self.take_next_priority();
        // This "65536" seems too aggressive, but it's what tor does.
        //
        // Also, going around in a loop here is (sadly) needed in order
        // to look like Tor clients.
        for _ in 1..=65536 {
            let id: StreamId = self.next_stream_id;
            self.next_stream_id = wrapping_next_stream_id(self.next_stream_id);
            stream_ent = match self.open_streams.try_insert(id, priority, stream_ent) {
                Ok(_) => return Ok(id),
                Err(KeyAlreadyInsertedError {
                    key: _,
                    priority: _,
                    stream,
                }) => stream,
            };
        }

        Err(Error::IdRangeFull)
    }

    /// Add an entry to this map using the specified StreamId.
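    ///
    /// Unlike [`Self::add_ent`], the caller supplies the stream ID; this is
    /// for streams initiated by the other side of the circuit (hence the
    /// `hs-service` feature gate).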
    #[cfg(feature = "hs-service")]
    pub(super) fn add_ent_with_id(
        &mut self,
        sink: StreamMpscSender<UnparsedRelayMsg>,
        rx: StreamMpscReceiver<AnyRelayMsg>,
        flow_ctrl: StreamSendFlowControl,
        id: StreamId,
        cmd_checker: AnyCmdChecker,
    ) -> Result<()> {
        let stream_ent = OpenStreamEntStream {
            inner: OpenStreamEnt {
                sink,
                flow_ctrl,
                dropped: 0,
                cmd_checker,
                rx: StreamUnobtrusivePeeker::new(rx),
                flow_ctrl_waker: None,
            },
        };
        let priority = self.take_next_priority();
        self.open_streams
            .try_insert(id, priority, stream_ent)
            .map_err(|_| Error::IdUnavailable(id))
    }

    /// Return the entry for `id` in this map, if any.
    pub(super) fn get_mut(&mut self, id: StreamId) -> Option<StreamEntMut<'_>> {
        if let Some(e) = self.open_streams.stream_mut(&id) {
            return Some(e.into());
        }
        if let Some(e) = self.closed_streams.get_mut(&id) {
            return Some(e.into());
        }
        None
    }

    /// Note that we received an END message (or other message indicating the end of
    /// the stream) on the stream with `id`.
    ///
    /// Returns an error if there was no such stream, or if we had already
    /// received an END on it.
    pub(super) fn ending_msg_received(&mut self, id: StreamId) -> Result<()> {
        if self.open_streams.remove(&id).is_some() {
            let prev = self.closed_streams.insert(id, ClosedStreamEnt::EndReceived);
            debug_assert!(prev.is_none(), "Unexpected duplicate entry for {id}");
            return Ok(());
        }
        let hash_map::Entry::Occupied(closed_entry) = self.closed_streams.entry(id) else {
            return Err(Error::CircProto(
                "Received END cell on nonexistent stream".into(),
            ));
        };
        // Progress the stream's state machine accordingly
        match closed_entry.get() {
            ClosedStreamEnt::EndReceived => Err(Error::CircProto(
                "Received two END cells on same stream".into(),
            )),
            ClosedStreamEnt::EndSent { .. } => {
                debug!("Actually got an end cell on a half-closed stream!");
                // We got an END, and we already sent an END. Great!
                // We can forget about this stream.
                closed_entry.remove_entry();
                Ok(())
            }
        }
    }

    /// Handle a termination of the stream with `id` from this side of
    /// the circuit. Return [`ShouldSendEnd::Send`] if the stream was open and
    /// an END ought to be sent.
    pub(super) fn terminate(
        &mut self,
        id: StreamId,
        why: TerminateReason,
    ) -> Result<ShouldSendEnd> {
        use TerminateReason as TR;

        if let Some((_id, _priority, ent)) = self.open_streams.remove(&id) {
            let OpenStreamEntStream {
                inner:
                    OpenStreamEnt {
                        flow_ctrl,
                        dropped,
                        cmd_checker,
                        // notably absent: the channels for sink and stream, which will get dropped and
                        // closed (meaning reads/writes from/to this stream will now fail)
                        ..
                    },
            } = ent;
            // FIXME(eta): we don't copy the receive window, instead just creating a new one,
            //             so a malicious peer can send us slightly more data than they should
            //             be able to; see arti#230.
            let mut recv_window = sendme::StreamRecvWindow::new(RECV_WINDOW_INIT);
            recv_window.decrement_n(dropped)?;
            // TODO: would be nice to avoid new_ref.
            let half_stream = HalfStream::new(flow_ctrl, recv_window, cmd_checker);
            let explicitly_dropped = why == TR::StreamTargetClosed;
            let prev = self.closed_streams.insert(
                id,
                ClosedStreamEnt::EndSent(EndSentStreamEnt {
                    half_stream,
                    explicitly_dropped,
                }),
            );
            debug_assert!(prev.is_none(), "Unexpected duplicate entry for {id}");
            return Ok(ShouldSendEnd::Send);
        }

        // Progress the stream's state machine accordingly
        match self
            .closed_streams
            .remove(&id)
            .ok_or_else(|| Error::from(internal!("Somehow we terminated a nonexistent stream?")))?
        {
            ClosedStreamEnt::EndReceived => Ok(ShouldSendEnd::DontSend),
            ClosedStreamEnt::EndSent(EndSentStreamEnt {
                ref mut explicitly_dropped,
                ..
            }) => match (*explicitly_dropped, why) {
                (false, TR::StreamTargetClosed) => {
                    *explicitly_dropped = true;
                    Ok(ShouldSendEnd::DontSend)
                }
                (true, TR::StreamTargetClosed) => {
                    Err(bad_api_usage!("Tried to close an already closed stream.").into())
                }
                (_, TR::ExplicitEnd) => Err(bad_api_usage!(
                    "Tried to end an already closed stream. (explicitly_dropped={:?})",
                    *explicitly_dropped
                )
                .into()),
            },
        }
    }

    /// Get an up-to-date iterator of streams with ready items. `Option<AnyRelayMsg>::None`
    /// indicates that the local sender has been dropped.
    ///
    /// Conceptually all streams are in a queue; new streams are added to the
    /// back of the queue, and a stream is sent to the back of the queue
    /// whenever a ready message is taken from it (via
    /// [`Self::take_ready_msg`]). The returned iterator is an ordered view of
    /// this queue, showing the subset of streams that have a message ready to
    /// send, or whose sender has been dropped.
    pub(super) fn poll_ready_streams_iter<'a>(
        &'a mut self,
        cx: &mut std::task::Context,
    ) -> impl Iterator<Item = (StreamId, Option<&'a AnyRelayMsg>)> + 'a {
        self.open_streams
            .poll_ready_iter_mut(cx)
            .map(|(sid, _priority, ent)| {
                let ent = Pin::new(ent);
                let msg = ent.unobtrusive_peek();
                (*sid, msg)
            })
    }

    /// If the stream `sid` has a message ready, take it, and reprioritize `sid`
    /// to the "back of the line" with respect to
    /// [`Self::poll_ready_streams_iter`].
    pub(super) fn take_ready_msg(&mut self, sid: StreamId) -> Option<AnyRelayMsg> {
        let new_priority = self.take_next_priority();
        let (_prev_priority, val) = self
            .open_streams
            .take_ready_value_and_reprioritize(&sid, new_priority)?;
        Some(val)
    }

    // TODO: Eventually if we want relay support, we'll need to support
    // stream IDs chosen by somebody else. But for now, we don't need those.
}

/// A reason for terminating a stream.
///
/// We use this type to ensure that we obey the API restrictions of
/// [`StreamMap::terminate`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(super) enum TerminateReason {
    /// Closing a stream because the receiver got `Ok(None)`, indicating that the
    /// corresponding senders were all dropped.
    StreamTargetClosed,
    /// Closing a stream because we were explicitly told to end it via
    /// [`StreamTarget::close_pending`](crate::tunnel::StreamTarget::close_pending).
    ExplicitEnd,
}

/// Convenience function for doing a wrapping increment of a `StreamId`.
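///
/// Since a `StreamId` is nonzero, the increment wraps from `0xffff` back to
/// `1`, never to `0`.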
fn wrapping_next_stream_id(id: StreamId) -> StreamId {
    let next_val = NonZeroU16::from(id)
        .checked_add(1)
        .unwrap_or_else(|| NonZeroU16::new(1).expect("Impossibly got 0 value"));
    next_val.into()
}

#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use crate::tunnel::circuit::test::fake_mpsc;
    use crate::{congestion::sendme::StreamSendWindow, stream::DataCmdChecker};

    #[test]
    fn test_wrapping_next_stream_id() {
        let one = StreamId::new(1).unwrap();
        let two = StreamId::new(2).unwrap();
        let max = StreamId::new(0xffff).unwrap();
        assert_eq!(wrapping_next_stream_id(one), two);
        assert_eq!(wrapping_next_stream_id(max), one);
    }

    #[test]
    #[allow(clippy::cognitive_complexity)]
    fn streammap_basics() -> Result<()> {
        let mut map = StreamMap::new();
        let mut next_id = map.next_stream_id;
        let mut ids = Vec::new();

        assert_eq!(map.n_open_streams(), 0);

        // Try add_ent
        for n in 1..=128 {
            let (sink, _) = fake_mpsc(128);
            let (_, rx) = fake_mpsc(2);
            let id = map.add_ent(
                sink,
                rx,
                StreamSendFlowControl::new_window_based(StreamSendWindow::new(500)),
                DataCmdChecker::new_any(),
            )?;
            let expect_id: StreamId = next_id;
            assert_eq!(expect_id, id);
            next_id = wrapping_next_stream_id(next_id);
            ids.push(id);
            assert_eq!(map.n_open_streams(), n);
        }

        // Test get_mut.
        let nonesuch_id = next_id;
        assert!(matches!(
            map.get_mut(ids[0]),
            Some(StreamEntMut::Open { .. })
        ));
        assert!(map.get_mut(nonesuch_id).is_none());

        // Test end_received
        assert!(map.ending_msg_received(nonesuch_id).is_err());
        assert_eq!(map.n_open_streams(), 128);
        assert!(map.ending_msg_received(ids[1]).is_ok());
        assert_eq!(map.n_open_streams(), 127);
        assert!(matches!(
            map.get_mut(ids[1]),
            Some(StreamEntMut::EndReceived)
        ));
        assert!(map.ending_msg_received(ids[1]).is_err());

        // Test terminate
        use TerminateReason as TR;
        assert!(map.terminate(nonesuch_id, TR::ExplicitEnd).is_err());
        assert_eq!(map.n_open_streams(), 127);
        assert_eq!(
            map.terminate(ids[2], TR::ExplicitEnd).unwrap(),
            ShouldSendEnd::Send
        );
        assert_eq!(map.n_open_streams(), 126);
        assert!(matches!(
            map.get_mut(ids[2]),
            Some(StreamEntMut::EndSent { .. })
        ));
        assert_eq!(
            map.terminate(ids[1], TR::ExplicitEnd).unwrap(),
            ShouldSendEnd::DontSend
        );
        // This stream was already closed when we called `ending_msg_received`
        // above.
        assert_eq!(map.n_open_streams(), 126);
        assert!(map.get_mut(ids[1]).is_none());

        // Try receiving an end after a terminate.
        assert!(map.ending_msg_received(ids[2]).is_ok());
        assert!(map.get_mut(ids[2]).is_none());
        assert_eq!(map.n_open_streams(), 126);

        Ok(())
    }
}