tor_proto/
congestion.rs

1//! Congestion control subsystem.
2//!
3//! This object is attached to a circuit hop (CircHop) and controls the logic for the congestion
4//! control support of the Tor Network. It also manages the circuit level SENDME logic which is
5//! part of congestion control.
6//!
7//! # Implementation
8//!
//! The basic idea of this subsystem is that it is notified whenever a DATA cell is received or
//! sent. This in turn updates the congestion control state so that the very important
//! [`can_send`](CongestionControl::can_send) function can accurately decide whether a DATA cell
//! can be sent or not.
13//!
//! Any part of the arti code that wants to send a DATA cell on the wire needs to call
//! [`can_send`](CongestionControl::can_send) first; otherwise we risk leaving the circuit in a
//! protocol violation state.
17//!
//! Furthermore, as we receive and emit SENDMEs, it also has entry points for those two events in
//! order to update the state.
20
21#[cfg(any(test, feature = "testing"))]
22pub(crate) mod test_utils;
23
24mod fixed;
25pub mod params;
26mod rtt;
27pub(crate) mod sendme;
28mod vegas;
29
30use std::time::Instant;
31
32use crate::{Error, Result};
33
34use self::{
35    params::{Algorithm, CongestionControlParams, CongestionWindowParams},
36    rtt::RoundtripTimeEstimator,
37    sendme::SendmeValidator,
38};
39use tor_cell::relaycell::msg::SendmeTag;
40
41/// This trait defines what a congestion control algorithm must implement in order to interface
42/// with the circuit reactor.
43///
44/// Note that all functions informing the algorithm, as in not getters, return a Result meaning
45/// that on error, it means we can't recover or that there is a protocol violation. In both
46/// cases, the circuit MUST be closed.
47pub(crate) trait CongestionControlAlgorithm: Send + std::fmt::Debug {
48    /// Return true iff this algorithm uses stream level SENDMEs.
49    fn uses_stream_sendme(&self) -> bool;
50    /// Return true iff the next cell is expected to be a SENDME.
51    fn is_next_cell_sendme(&self) -> bool;
52    /// Return true iff a cell can be sent on the wire according to the congestion control
53    /// algorithm.
54    fn can_send(&self) -> bool;
55    /// Return the congestion window object. The reason is returns an Option is because not all
56    /// algorithm uses one and so we avoid acting on it if so.
57    fn cwnd(&self) -> Option<&CongestionWindow>;
58
59    /// Inform the algorithm that we just got a DATA cell.
60    ///
61    /// Return true if a SENDME should be sent immediately or false if not.
62    fn data_received(&mut self) -> Result<bool>;
63    /// Inform the algorithm that we just sent a DATA cell.
64    fn data_sent(&mut self) -> Result<()>;
65    /// Inform the algorithm that we've just received a SENDME.
66    ///
67    /// This is a core function because the algorithm massively update its state when receiving a
68    /// SENDME by using the RTT value and congestion signals.
69    fn sendme_received(
70        &mut self,
71        state: &mut State,
72        rtt: &mut RoundtripTimeEstimator,
73        signals: CongestionSignals,
74    ) -> Result<()>;
75    /// Inform the algorithm that we just sent a SENDME.
76    fn sendme_sent(&mut self) -> Result<()>;
77
78    /// Return the number of in-flight cells (sent but awaiting SENDME ack).
79    ///
80    /// Optional, because not all algorithms track this.
81    #[cfg(feature = "conflux")]
82    fn inflight(&self) -> Option<u32>;
83
84    /// Test Only: Return the congestion window.
85    #[cfg(test)]
86    fn send_window(&self) -> u32;
87}
88
/// Congestion signals used by a congestion control algorithm to make decisions.
///
/// These signals reflect various parts of our internal state. This is not an exhaustive list.
#[derive(Copy, Clone)]
pub(crate) struct CongestionSignals {
    /// Indicate if the channel is blocked.
    pub(crate) channel_blocked: bool,
    /// The size of the channel outbound queue.
    pub(crate) channel_outbound_size: u32,
}

impl CongestionSignals {
    /// Constructor.
    ///
    /// `channel_outbound_size` is saturated to `u32::MAX` when it does not fit in a `u32`.
    pub(crate) fn new(channel_blocked: bool, channel_outbound_size: usize) -> Self {
        Self {
            channel_blocked,
            // An `as` cast would silently truncate an oversized queue length (possibly to a
            // tiny value); report it as the largest representable size instead.
            channel_outbound_size: u32::try_from(channel_outbound_size).unwrap_or(u32::MAX),
        }
    }
}
108
/// Congestion control state.
#[derive(Copy, Clone, Default)]
pub(crate) enum State {
    /// Starting state of every circuit: the amount of data being transmitted is increased
    /// gradually so that it converges towards the optimal capacity.
    #[default]
    SlowStart,
    /// Steady state, representing what we think is optimal. This always comes after slow start.
    Steady,
}

impl State {
    /// Return true iff this is [`State::SlowStart`].
    pub(crate) fn in_slow_start(&self) -> bool {
        match self {
            State::SlowStart => true,
            State::Steady => false,
        }
    }
}
126
127/// A congestion window. This is generic for all algorithms but their parameters' value will differ
128/// depending on the selected algorithm.
129#[derive(Clone, Debug)]
130pub(crate) struct CongestionWindow {
131    /// Congestion window parameters from the consensus.
132    params: CongestionWindowParams,
133    /// The actual value of our congestion window.
134    value: u32,
135    /// The congestion window is full.
136    is_full: bool,
137}
138
139impl CongestionWindow {
140    /// Constructor taking consensus parameters.
141    fn new(params: &CongestionWindowParams) -> Self {
142        Self {
143            value: params.cwnd_init(),
144            params: params.clone(),
145            is_full: false,
146        }
147    }
148
149    /// Decrement the window by the increment value.
150    pub(crate) fn dec(&mut self) {
151        self.value = self
152            .value
153            .saturating_sub(self.increment())
154            .max(self.params.cwnd_min());
155    }
156
157    /// Increment the window by the increment value.
158    pub(crate) fn inc(&mut self) {
159        self.value = self
160            .value
161            .saturating_add(self.increment())
162            .min(self.params.cwnd_max());
163    }
164
165    /// Return the current value.
166    pub(crate) fn get(&self) -> u32 {
167        self.value
168    }
169
170    /// Return the expected rate for which the congestion window should be updated at.
171    ///
172    /// See `CWND_UPDATE_RATE` in prop324.
173    pub(crate) fn update_rate(&self, state: &State) -> u32 {
174        if state.in_slow_start() {
175            1
176        } else {
177            (self.get() + self.increment_rate() * self.sendme_inc() / 2)
178                / (self.increment_rate() * self.sendme_inc())
179        }
180    }
181
182    /// Return minimum value of the congestion window.
183    pub(crate) fn min(&self) -> u32 {
184        self.params.cwnd_min()
185    }
186
187    /// Set the congestion window value with a new value.
188    pub(crate) fn set(&mut self, value: u32) {
189        self.value = value;
190    }
191
192    /// Return the increment value.
193    pub(crate) fn increment(&self) -> u32 {
194        self.params.cwnd_inc()
195    }
196
197    /// Return the rate at which we should increment the window.
198    pub(crate) fn increment_rate(&self) -> u32 {
199        self.params.cwnd_inc_rate()
200    }
201
202    /// Return true iff this congestion window is full.
203    pub(crate) fn is_full(&self) -> bool {
204        self.is_full
205    }
206
207    /// Reset the full flag meaning it is now not full.
208    pub(crate) fn reset_full(&mut self) {
209        self.is_full = false;
210    }
211
212    /// Return the number of expected SENDMEs per congestion window.
213    ///
214    /// Spec: prop324 SENDME_PER_CWND definition
215    pub(crate) fn sendme_per_cwnd(&self) -> u32 {
216        (self.get() + (self.sendme_inc() / 2)) / self.sendme_inc()
217    }
218
219    /// Return the RFC3742 slow start increment value.
220    ///
221    /// Spec: prop324 rfc3742_ss_inc definition
222    pub(crate) fn rfc3742_ss_inc(&mut self, ss_cap: u32) -> u32 {
223        let inc = if self.get() <= ss_cap {
224            ((self.params.cwnd_inc_pct_ss().as_percent() * self.sendme_inc()) + 50) / 100
225        } else {
226            (((self.sendme_inc() * ss_cap) + self.get()) / (self.get() * 2)).max(1)
227        };
228        self.value += inc;
229        inc
230    }
231
232    /// Evaluate the fullness of the window with the given parameters.
233    ///
234    /// Spec: prop324 see cwnd_is_full and cwnd_is_nonfull definition.
235    /// C-tor: cwnd_became_full() and cwnd_became_nonfull()
236    pub(crate) fn eval_fullness(&mut self, inflight: u32, full_gap: u32, full_minpct: u32) {
237        if (inflight + (self.sendme_inc() * full_gap)) >= self.get() {
238            self.is_full = true;
239        } else if (100 * inflight) < (full_minpct * self.get()) {
240            self.is_full = false;
241        }
242    }
243
244    /// Return the SENDME increment value.
245    pub(crate) fn sendme_inc(&self) -> u32 {
246        self.params.sendme_inc()
247    }
248
249    /// Return the congestion window params.
250    #[cfg(any(test, feature = "conflux"))]
251    pub(crate) fn params(&self) -> &CongestionWindowParams {
252        &self.params
253    }
254}
255
256/// Congestion control state of a hop on a circuit.
257///
258/// This controls the entire logic of congestion control and circuit level SENDMEs.
259pub(crate) struct CongestionControl {
260    /// Which congestion control state are we in?
261    state: State,
262    /// This is the SENDME validator as in it keeps track of the circuit tag found within an
263    /// authenticated SENDME cell. It can store the tags and validate a tag against our queue of
264    /// expected values.
265    sendme_validator: SendmeValidator<SendmeTag>,
266    /// The RTT estimator for the circuit we are attached on.
267    rtt: RoundtripTimeEstimator,
268    /// The congestion control algorithm.
269    algorithm: Box<dyn CongestionControlAlgorithm>,
270}
271
272impl CongestionControl {
273    /// Construct a new CongestionControl
274    pub(crate) fn new(params: &CongestionControlParams) -> Self {
275        let state = State::default();
276        // Use what the consensus tells us to use.
277        let algorithm: Box<dyn CongestionControlAlgorithm> = match params.alg() {
278            Algorithm::FixedWindow(p) => Box::new(fixed::FixedWindow::new(p.circ_window_start())),
279            Algorithm::Vegas(ref p) => {
280                let cwnd = CongestionWindow::new(params.cwnd_params());
281                Box::new(vegas::Vegas::new(p, &state, cwnd))
282            }
283        };
284        Self {
285            algorithm,
286            rtt: RoundtripTimeEstimator::new(params.rtt_params()),
287            sendme_validator: SendmeValidator::new(),
288            state,
289        }
290    }
291
292    /// Return true iff the underlying algorithm uses stream level SENDMEs.
293    /// At the moment, only FixedWindow uses it. It has been eliminated with Vegas.
294    pub(crate) fn uses_stream_sendme(&self) -> bool {
295        self.algorithm.uses_stream_sendme()
296    }
297
298    /// Return true iff a DATA cell is allowed to be sent based on the congestion control state.
299    pub(crate) fn can_send(&self) -> bool {
300        self.algorithm.can_send()
301    }
302
303    /// Called when a SENDME cell is received.
304    ///
305    /// An error is returned if there is a protocol violation with regards to congestion control.
306    pub(crate) fn note_sendme_received(
307        &mut self,
308        tag: SendmeTag,
309        signals: CongestionSignals,
310    ) -> Result<()> {
311        // This MUST be the first thing that we do that is validate the SENDME. Any error leads to
312        // closing the circuit.
313        self.sendme_validator.validate(Some(tag))?;
314
315        // Update our RTT estimate if the algorithm yields back a congestion window. RTT
316        // measurements only make sense for a congestion window. For example, FixedWindow here
317        // doesn't use it and so no need for the RTT.
318        if let Some(cwnd) = self.algorithm.cwnd() {
319            self.rtt
320                .update(Instant::now(), &self.state, cwnd)
321                .map_err(|e| Error::CircProto(e.to_string()))?;
322        }
323
324        // Notify the algorithm that we've received a SENDME.
325        self.algorithm
326            .sendme_received(&mut self.state, &mut self.rtt, signals)
327    }
328
329    /// Called when a SENDME cell is sent.
330    pub(crate) fn note_sendme_sent(&mut self) -> Result<()> {
331        self.algorithm.sendme_sent()
332    }
333
334    /// Called when a DATA cell is received.
335    ///
336    /// Returns true iff a SENDME should be sent false otherwise. An error is returned if there is
337    /// a protocol violation with regards to flow or congestion control.
338    pub(crate) fn note_data_received(&mut self) -> Result<bool> {
339        self.algorithm.data_received()
340    }
341
342    /// Called when a DATA cell is sent.
343    ///
344    /// An error is returned if there is a protocol violation with regards to flow or congestion
345    /// control.
346    pub(crate) fn note_data_sent<U>(&mut self, tag: &U) -> Result<()>
347    where
348        U: Clone + Into<SendmeTag>,
349    {
350        // Inform the algorithm that the data was just sent. This is important to be the very first
351        // thing so the congestion window can be updated accordingly making the following calls
352        // using the latest data.
353        self.algorithm.data_sent()?;
354
355        // If next cell is a SENDME, we need to record the tag of this cell in order to validate
356        // the next SENDME when it arrives.
357        if self.algorithm.is_next_cell_sendme() {
358            self.sendme_validator.record(tag);
359            // Only keep the SENDME timestamp if the algorithm has a congestion window.
360            if self.algorithm.cwnd().is_some() {
361                self.rtt.expect_sendme(Instant::now());
362            }
363        }
364
365        Ok(())
366    }
367
368    /// Return the number of in-flight cells (sent but awaiting SENDME ack).
369    ///
370    /// Optional, because not all algorithms track this.
371    #[cfg(feature = "conflux")]
372    pub(crate) fn inflight(&self) -> Option<u32> {
373        self.algorithm.inflight()
374    }
375
376    /// Return the congestion window object.
377    ///
378    /// Optional, because not all algorithms track this.
379    #[cfg(feature = "conflux")]
380    pub(crate) fn cwnd(&self) -> Option<&CongestionWindow> {
381        self.algorithm.cwnd()
382    }
383
384    /// Return a reference to the RTT estimator.
385    ///
386    /// Used for conflux, for choosing the best circuit to send on.
387    #[cfg(feature = "conflux")]
388    pub(crate) fn rtt(&self) -> &RoundtripTimeEstimator {
389        &self.rtt
390    }
391}
392
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use crate::congestion::test_utils::new_cwnd;

    use super::{CongestionControl, State};
    use tor_cell::relaycell::msg::SendmeTag;

    impl CongestionControl {
        /// For testing: get a copy of the current send window, and the
        /// expected incoming tags.
        pub(crate) fn send_window_and_expected_tags(&self) -> (u32, Vec<SendmeTag>) {
            (
                self.algorithm.send_window(),
                self.sendme_validator.expected_tags(),
            )
        }
    }

    #[test]
    fn test_cwnd() {
        let mut cwnd = new_cwnd();

        // Validate the getters are coherent with initialization.
        assert_eq!(cwnd.get(), cwnd.params().cwnd_init());
        assert_eq!(cwnd.min(), cwnd.params().cwnd_min());
        assert_eq!(cwnd.increment(), cwnd.params().cwnd_inc());
        assert_eq!(cwnd.increment_rate(), cwnd.params().cwnd_inc_rate());
        assert_eq!(cwnd.sendme_inc(), cwnd.params().sendme_inc());
        assert!(!cwnd.is_full());

        // Validate changes.
        cwnd.inc();
        assert_eq!(
            cwnd.get(),
            cwnd.params().cwnd_init() + cwnd.params().cwnd_inc()
        );
        cwnd.dec();
        assert_eq!(cwnd.get(), cwnd.params().cwnd_init());
    }

    #[test]
    fn test_cwnd_bounds() {
        let mut cwnd = new_cwnd();
        let max = cwnd.params().cwnd_max();
        let min = cwnd.min();

        // Incrementing a window already at its maximum keeps it at the maximum.
        cwnd.set(max);
        cwnd.inc();
        assert_eq!(cwnd.get(), max);

        // Decrementing a window already at its minimum keeps it at the minimum.
        cwnd.set(min);
        cwnd.dec();
        assert_eq!(cwnd.get(), min);

        // In slow start, the window is updated on every SENDME.
        assert_eq!(cwnd.update_rate(&State::SlowStart), 1);

        // With as many cells in flight as the window holds, the window is full; resetting
        // clears the flag.
        let current = cwnd.get();
        cwnd.eval_fullness(current, 0, 0);
        assert!(cwnd.is_full());
        cwnd.reset_full();
        assert!(!cwnd.is_full());
    }
}