1
//! Congestion control subsystem.
2
//!
3
//! This object is attached to a circuit hop (CircHop) and controls the logic for the congestion
4
//! control support of the Tor Network. It also manages the circuit level SENDME logic which is
5
//! part of congestion control.
6
//!
7
//! # Implementation
8
//!
9
//! The basic idea of this subsystem is that it is notified whenever a DATA cell is received or
//! sent. This in turn updates the congestion control state so that the very important
//! [`can_send`](CongestionControl::can_send) function is accurate when deciding whether a DATA
//! cell can be sent or not.
13
//!
14
//! Any part of the arti code that wants to send a DATA cell on the wire needs to call
//! [`can_send`](CongestionControl::can_send) beforehand, or else we risk leaving the circuit in a
//! protocol violation state.
17
//!
18
//! Furthermore, as we receive and emit SENDMEs, it also has entry points for those two events in
//! order to update the state.
20

            
21
#[cfg(any(test, feature = "testing"))]
22
pub(crate) mod test_utils;
23

            
24
mod fixed;
25
pub mod params;
26
mod rtt;
27
pub(crate) mod sendme;
28
mod vegas;
29

            
30
use std::time::Instant;
31

            
32
use crate::{Error, Result};
33

            
34
use self::{
35
    params::{Algorithm, CongestionControlParams, CongestionWindowParams},
36
    rtt::RoundtripTimeEstimator,
37
    sendme::SendmeValidator,
38
};
39
use tor_cell::relaycell::msg::SendmeTag;
40

            
41
/// This trait defines what a congestion control algorithm must implement in order to interface
42
/// with the circuit reactor.
43
///
44
/// Note that all functions informing the algorithm, as in not getters, return a Result meaning
45
/// that on error, it means we can't recover or that there is a protocol violation. In both
46
/// cases, the circuit MUST be closed.
47
pub(crate) trait CongestionControlAlgorithm: Send + std::fmt::Debug {
48
    /// Return true iff this algorithm uses stream level SENDMEs.
49
    fn uses_stream_sendme(&self) -> bool;
50
    /// Return true iff the next cell is expected to be a SENDME.
51
    fn is_next_cell_sendme(&self) -> bool;
52
    /// Return true iff a cell can be sent on the wire according to the congestion control
53
    /// algorithm.
54
    fn can_send(&self) -> bool;
55
    /// Return the congestion window object. The reason is returns an Option is because not all
56
    /// algorithm uses one and so we avoid acting on it if so.
57
    fn cwnd(&self) -> Option<&CongestionWindow>;
58

            
59
    /// Inform the algorithm that we just got a DATA cell.
60
    ///
61
    /// Return true if a SENDME should be sent immediately or false if not.
62
    fn data_received(&mut self) -> Result<bool>;
63
    /// Inform the algorithm that we just sent a DATA cell.
64
    fn data_sent(&mut self) -> Result<()>;
65
    /// Inform the algorithm that we've just received a SENDME.
66
    ///
67
    /// This is a core function because the algorithm massively update its state when receiving a
68
    /// SENDME by using the RTT value and congestion signals.
69
    fn sendme_received(
70
        &mut self,
71
        state: &mut State,
72
        rtt: &mut RoundtripTimeEstimator,
73
        signals: CongestionSignals,
74
    ) -> Result<()>;
75
    /// Inform the algorithm that we just sent a SENDME.
76
    fn sendme_sent(&mut self) -> Result<()>;
77

            
78
    /// Return the number of in-flight cells (sent but awaiting SENDME ack).
79
    ///
80
    /// Optional, because not all algorithms track this.
81
    #[cfg(feature = "conflux")]
82
    fn inflight(&self) -> Option<u32>;
83

            
84
    /// Test Only: Return the congestion window.
85
    #[cfg(test)]
86
    fn send_window(&self) -> u32;
87
}
88

            
89
/// Congestion signals used by a congestion control algorithm to make decisions.
///
/// These signals reflect various parts of our internal state. This is not an exhaustive list.
#[derive(Copy, Clone, Debug)]
pub(crate) struct CongestionSignals {
    /// Indicate if the channel is blocked.
    pub(crate) channel_blocked: bool,
    /// The size of the channel outbound queue, saturated at `u32::MAX`.
    pub(crate) channel_outbound_size: u32,
}

impl CongestionSignals {
    /// Constructor.
    ///
    /// `channel_outbound_size` is converted to `u32`, saturating at `u32::MAX` when the queue
    /// size does not fit instead of wrapping around.
    pub(crate) fn new(channel_blocked: bool, channel_outbound_size: usize) -> Self {
        Self {
            channel_blocked,
            // A plain `as u32` cast truncates on 64-bit targets (e.g. 2^32 would become 0, i.e.
            // an "empty" queue). Saturate so that a huge queue is still reported as huge.
            channel_outbound_size: u32::try_from(channel_outbound_size).unwrap_or(u32::MAX),
        }
    }
}
108

            
109
/// The phase a congestion control algorithm is currently in.
#[derive(Copy, Clone, Default)]
pub(crate) enum State {
    /// Phase every circuit starts in. The amount of data being transmitted is increased
    /// gradually in order to converge towards the optimal capacity.
    #[default]
    SlowStart,
    /// Phase representing what we believe is optimal. It always comes after slow start.
    Steady,
}

impl State {
    /// Return true iff this is [`State::SlowStart`].
    pub(crate) fn in_slow_start(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            State::SlowStart => true,
            State::Steady => false,
        }
    }
}
126

            
127
/// A congestion window. This is generic for all algorithms but their parameters' value will differ
128
/// depending on the selected algorithm.
129
#[derive(Clone, Debug)]
130
pub(crate) struct CongestionWindow {
131
    /// Congestion window parameters from the consensus.
132
    params: CongestionWindowParams,
133
    /// The actual value of our congestion window.
134
    value: u32,
135
    /// The congestion window is full.
136
    is_full: bool,
137
}
138

            
139
impl CongestionWindow {
140
    /// Constructor taking consensus parameters.
141
36
    fn new(params: &CongestionWindowParams) -> Self {
142
36
        Self {
143
36
            value: params.cwnd_init(),
144
36
            params: params.clone(),
145
36
            is_full: false,
146
36
        }
147
36
    }
148

            
149
    /// Decrement the window by the increment value.
150
10
    pub(crate) fn dec(&mut self) {
151
10
        self.value = self
152
10
            .value
153
10
            .saturating_sub(self.increment())
154
10
            .max(self.params.cwnd_min());
155
10
    }
156

            
157
    /// Increment the window by the increment value.
158
12
    pub(crate) fn inc(&mut self) {
159
12
        self.value = self
160
12
            .value
161
12
            .saturating_add(self.increment())
162
12
            .min(self.params.cwnd_max());
163
12
    }
164

            
165
    /// Return the current value.
166
1118
    pub(crate) fn get(&self) -> u32 {
167
1118
        self.value
168
1118
    }
169

            
170
    /// Return the expected rate for which the congestion window should be updated at.
171
    ///
172
    /// See `CWND_UPDATE_RATE` in prop324.
173
170
    pub(crate) fn update_rate(&self, state: &State) -> u32 {
174
170
        if state.in_slow_start() {
175
102
            1
176
        } else {
177
68
            (self.get() + self.increment_rate() * self.sendme_inc() / 2)
178
68
                / (self.increment_rate() * self.sendme_inc())
179
        }
180
170
    }
181

            
182
    /// Return minimum value of the congestion window.
183
130
    pub(crate) fn min(&self) -> u32 {
184
130
        self.params.cwnd_min()
185
130
    }
186

            
187
    /// Set the congestion window value with a new value.
188
28
    pub(crate) fn set(&mut self, value: u32) {
189
28
        self.value = value;
190
28
    }
191

            
192
    /// Return the increment value.
193
108
    pub(crate) fn increment(&self) -> u32 {
194
108
        self.params.cwnd_inc()
195
108
    }
196

            
197
    /// Return the rate at which we should increment the window.
198
218
    pub(crate) fn increment_rate(&self) -> u32 {
199
218
        self.params.cwnd_inc_rate()
200
218
    }
201

            
202
    /// Return true iff this congestion window is full.
203
224
    pub(crate) fn is_full(&self) -> bool {
204
224
        self.is_full
205
224
    }
206

            
207
    /// Reset the full flag meaning it is now not full.
208
18
    pub(crate) fn reset_full(&mut self) {
209
18
        self.is_full = false;
210
18
    }
211

            
212
    /// Return the number of expected SENDMEs per congestion window.
213
    ///
214
    /// Spec: prop324 SENDME_PER_CWND definition
215
218
    pub(crate) fn sendme_per_cwnd(&self) -> u32 {
216
218
        (self.get() + (self.sendme_inc() / 2)) / self.sendme_inc()
217
218
    }
218

            
219
    /// Return the RFC3742 slow start increment value.
220
    ///
221
    /// Spec: prop324 rfc3742_ss_inc definition
222
80
    pub(crate) fn rfc3742_ss_inc(&mut self, ss_cap: u32) -> u32 {
223
80
        let inc = if self.get() <= ss_cap {
224
74
            ((self.params.cwnd_inc_pct_ss().as_percent() * self.sendme_inc()) + 50) / 100
225
        } else {
226
6
            (((self.sendme_inc() * ss_cap) + self.get()) / (self.get() * 2)).max(1)
227
        };
228
80
        self.value += inc;
229
80
        inc
230
80
    }
231

            
232
    /// Evaluate the fullness of the window with the given parameters.
233
    ///
234
    /// Spec: prop324 see cwnd_is_full and cwnd_is_nonfull definition.
235
    /// C-tor: cwnd_became_full() and cwnd_became_nonfull()
236
120
    pub(crate) fn eval_fullness(&mut self, inflight: u32, full_gap: u32, full_minpct: u32) {
237
120
        if (inflight + (self.sendme_inc() * full_gap)) >= self.get() {
238
94
            self.is_full = true;
239
94
        } else if (100 * inflight) < (full_minpct * self.get()) {
240
12
            self.is_full = false;
241
14
        }
242
120
    }
243

            
244
    /// Return the SENDME increment value.
245
910
    pub(crate) fn sendme_inc(&self) -> u32 {
246
910
        self.params.sendme_inc()
247
910
    }
248

            
249
    /// Return the congestion window params.
250
    #[cfg(any(test, feature = "conflux"))]
251
16
    pub(crate) fn params(&self) -> &CongestionWindowParams {
252
16
        &self.params
253
16
    }
254
}
255

            
256
/// Congestion control state of a hop on a circuit.
257
///
258
/// This controls the entire logic of congestion control and circuit level SENDMEs.
259
pub(crate) struct CongestionControl {
260
    /// Which congestion control state are we in?
261
    state: State,
262
    /// This is the SENDME validator as in it keeps track of the circuit tag found within an
263
    /// authenticated SENDME cell. It can store the tags and validate a tag against our queue of
264
    /// expected values.
265
    sendme_validator: SendmeValidator<SendmeTag>,
266
    /// The RTT estimator for the circuit we are attached on.
267
    rtt: RoundtripTimeEstimator,
268
    /// The congestion control algorithm.
269
    algorithm: Box<dyn CongestionControlAlgorithm>,
270
}
271

            
272
impl CongestionControl {
273
    /// Construct a new CongestionControl
274
408
    pub(crate) fn new(params: &CongestionControlParams) -> Self {
275
408
        let state = State::default();
276
        // Use what the consensus tells us to use.
277
408
        let algorithm: Box<dyn CongestionControlAlgorithm> = match params.alg() {
278
400
            Algorithm::FixedWindow(p) => Box::new(fixed::FixedWindow::new(p.circ_window_start())),
279
8
            Algorithm::Vegas(ref p) => {
280
8
                let cwnd = CongestionWindow::new(params.cwnd_params());
281
8
                Box::new(vegas::Vegas::new(p, &state, cwnd))
282
            }
283
        };
284
408
        Self {
285
408
            algorithm,
286
408
            rtt: RoundtripTimeEstimator::new(params.rtt_params()),
287
408
            sendme_validator: SendmeValidator::new(),
288
408
            state,
289
408
        }
290
408
    }
291

            
292
    /// Return true iff the underlying algorithm uses stream level SENDMEs.
293
    /// At the moment, only FixedWindow uses it. It has been eliminated with Vegas.
294
80
    pub(crate) fn uses_stream_sendme(&self) -> bool {
295
80
        self.algorithm.uses_stream_sendme()
296
80
    }
297

            
298
    /// Return true iff a DATA cell is allowed to be sent based on the congestion control state.
299
9978
    pub(crate) fn can_send(&self) -> bool {
300
9978
        self.algorithm.can_send()
301
9978
    }
302

            
303
    /// Called when a SENDME cell is received.
304
    ///
305
    /// An error is returned if there is a protocol violation with regards to congestion control.
306
8
    pub(crate) fn note_sendme_received(
307
8
        &mut self,
308
8
        tag: SendmeTag,
309
8
        signals: CongestionSignals,
310
8
    ) -> Result<()> {
311
8
        // This MUST be the first thing that we do that is validate the SENDME. Any error leads to
312
8
        // closing the circuit.
313
8
        self.sendme_validator.validate(Some(tag))?;
314

            
315
        // Update our RTT estimate if the algorithm yields back a congestion window. RTT
316
        // measurements only make sense for a congestion window. For example, FixedWindow here
317
        // doesn't use it and so no need for the RTT.
318
4
        if let Some(cwnd) = self.algorithm.cwnd() {
319
            self.rtt
320
                .update(Instant::now(), &self.state, cwnd)
321
                .map_err(|e| Error::CircProto(e.to_string()))?;
322
4
        }
323

            
324
        // Notify the algorithm that we've received a SENDME.
325
4
        self.algorithm
326
4
            .sendme_received(&mut self.state, &mut self.rtt, signals)
327
8
    }
328

            
329
    /// Called when a SENDME cell is sent.
330
    pub(crate) fn note_sendme_sent(&mut self) -> Result<()> {
331
        self.algorithm.sendme_sent()
332
    }
333

            
334
    /// Called when a DATA cell is received.
335
    ///
336
    /// Returns true iff a SENDME should be sent false otherwise. An error is returned if there is
337
    /// a protocol violation with regards to flow or congestion control.
338
24
    pub(crate) fn note_data_received(&mut self) -> Result<bool> {
339
24
        self.algorithm.data_received()
340
24
    }
341

            
342
    /// Called when a DATA cell is sent.
343
    ///
344
    /// An error is returned if there is a protocol violation with regards to flow or congestion
345
    /// control.
346
2728
    pub(crate) fn note_data_sent<U>(&mut self, tag: &U) -> Result<()>
347
2728
    where
348
2728
        U: Clone + Into<SendmeTag>,
349
2728
    {
350
2728
        // Inform the algorithm that the data was just sent. This is important to be the very first
351
2728
        // thing so the congestion window can be updated accordingly making the following calls
352
2728
        // using the latest data.
353
2728
        self.algorithm.data_sent()?;
354

            
355
        // If next cell is a SENDME, we need to record the tag of this cell in order to validate
356
        // the next SENDME when it arrives.
357
2728
        if self.algorithm.is_next_cell_sendme() {
358
24
            self.sendme_validator.record(tag);
359
24
            // Only keep the SENDME timestamp if the algorithm has a congestion window.
360
24
            if self.algorithm.cwnd().is_some() {
361
                self.rtt.expect_sendme(Instant::now());
362
24
            }
363
2704
        }
364

            
365
2728
        Ok(())
366
2728
    }
367

            
368
    /// Return the number of in-flight cells (sent but awaiting SENDME ack).
369
    ///
370
    /// Optional, because not all algorithms track this.
371
    #[cfg(feature = "conflux")]
372
    pub(crate) fn inflight(&self) -> Option<u32> {
373
        self.algorithm.inflight()
374
    }
375

            
376
    /// Return the congestion window object.
377
    ///
378
    /// Optional, because not all algorithms track this.
379
    #[cfg(feature = "conflux")]
380
    pub(crate) fn cwnd(&self) -> Option<&CongestionWindow> {
381
        self.algorithm.cwnd()
382
    }
383

            
384
    /// Return a reference to the RTT estimator.
385
    ///
386
    /// Used for conflux, for choosing the best circuit to send on.
387
    #[cfg(feature = "conflux")]
388
    pub(crate) fn rtt(&self) -> &RoundtripTimeEstimator {
389
        &self.rtt
390
    }
391
}
392

            
393
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use crate::congestion::test_utils::new_cwnd;

    use super::CongestionControl;
    use tor_cell::relaycell::msg::SendmeTag;

    impl CongestionControl {
        /// Test helper: snapshot the current send window along with the queue of SENDME tags
        /// we still expect to receive.
        pub(crate) fn send_window_and_expected_tags(&self) -> (u32, Vec<SendmeTag>) {
            let window = self.algorithm.send_window();
            let tags = self.sendme_validator.expected_tags();
            (window, tags)
        }
    }

    #[test]
    fn test_cwnd() {
        let mut window = new_cwnd();
        let init = window.params().cwnd_init();
        let step = window.params().cwnd_inc();

        // A freshly built window: every getter mirrors the parameters it was built from.
        assert_eq!(window.get(), init);
        assert_eq!(window.min(), window.params().cwnd_min());
        assert_eq!(window.increment(), step);
        assert_eq!(window.increment_rate(), window.params().cwnd_inc_rate());
        assert_eq!(window.sendme_inc(), window.params().sendme_inc());
        assert!(!window.is_full());

        // One increment followed by one decrement must bring the window back to its start.
        window.inc();
        assert_eq!(window.get(), init + step);
        window.dec();
        assert_eq!(window.get(), init);
    }
}