1use std::cmp::{max, min};
4use std::collections::VecDeque;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::time::{Duration, Instant};
7
8use super::params::RoundTripEstimatorParams;
9use super::{CongestionWindow, State};
10
11use thiserror::Error;
12use tor_error::{ErrorKind, HasKind};
13
14#[derive(Error, Debug, Clone)]
16#[non_exhaustive]
17pub(crate) enum Error {
18 #[error("Informed of a SENDME we weren't expecting")]
21 MismatchedEstimationCall,
22}
23
24impl HasKind for Error {
25 fn kind(&self) -> ErrorKind {
26 use Error as E;
27 match self {
28 E::MismatchedEstimationCall => ErrorKind::TorProtocolViolation,
29 }
30 }
31}
32
/// Round-trip-time estimator driven by SENDME messages.
///
/// Callers record when a SENDME becomes expected with `expect_sendme` and report its
/// arrival with `update`. The estimator then maintains the last, EWMA, minimum and
/// maximum RTT values.
#[derive(Debug)]
// NOTE(review): presumably allowed because some fields/methods are not used yet — confirm.
#[allow(dead_code)]
pub(crate) struct RoundtripTimeEstimator {
    /// Send timestamps for which a SENDME is still outstanding, oldest first.
    ///
    /// `expect_sendme` pushes to the back; `update` pops from the front.
    sendme_expected_from: VecDeque<Instant>,
    /// The most recent raw RTT sample that was not rejected by the clock-stall check.
    last_rtt: Option<Duration>,
    /// Exponentially weighted moving average of the accepted raw RTT samples.
    ewma_rtt: Option<Duration>,
    /// The tracked minimum RTT.
    ///
    /// It follows the EWMA value rather than raw samples. `update` may raise it when the
    /// congestion window is at its minimum outside slow start.
    min_rtt: Option<Duration>,
    /// The largest raw RTT sample accepted so far.
    max_rtt: Option<Duration>,
    /// Tuning parameters used by `update` (EWMA sample counts, reset percentage).
    params: RoundTripEstimatorParams,
    /// Whether the clock is currently believed to be stalled.
    ///
    /// Set when a zero-length RTT is observed. Cleared when a sample is consistent with
    /// the current EWMA estimate.
    clock_stalled: AtomicBool,
}
67
68#[allow(dead_code)]
69impl RoundtripTimeEstimator {
70 pub(crate) fn new(params: &RoundTripEstimatorParams) -> Self {
73 Self {
74 sendme_expected_from: Default::default(),
75 last_rtt: None,
76 ewma_rtt: None,
77 min_rtt: None,
78 max_rtt: None,
79 params: params.clone(),
80 clock_stalled: AtomicBool::default(),
81 }
82 }
83
84 pub(crate) fn is_ready(&self) -> bool {
86 !self.clock_stalled() && self.last_rtt.is_some()
87 }
88
89 pub(crate) fn clock_stalled(&self) -> bool {
91 self.clock_stalled.load(Ordering::SeqCst)
92 }
93
94 pub(crate) fn ewma_rtt_usec(&self) -> Option<u32> {
96 self.ewma_rtt
97 .map(|rtt| u32::try_from(rtt.as_micros()).ok().unwrap_or(u32::MAX))
98 }
99
100 pub(crate) fn min_rtt_usec(&self) -> Option<u32> {
102 self.min_rtt
103 .map(|rtt| u32::try_from(rtt.as_micros()).ok().unwrap_or(u32::MAX))
104 }
105
106 pub(crate) fn expect_sendme(&mut self, now: Instant) {
109 self.sendme_expected_from.push_back(now);
110 }
111
112 fn can_crosscheck_with_current_estimate(&self, in_slow_start: bool) -> bool {
118 !in_slow_start && self.ewma_rtt.is_some()
122 }
123
124 fn is_clock_stalled(&self, raw_rtt: Duration, in_slow_start: bool) -> bool {
127 if raw_rtt.is_zero() {
128 self.clock_stalled.store(true, Ordering::SeqCst);
130 true
131 } else if self.can_crosscheck_with_current_estimate(in_slow_start) {
132 let ewma_rtt = self
133 .ewma_rtt
134 .expect("ewma_rtt was not checked by can_crosscheck_with_current_estimate?!");
135
136 const DELTA_DISCREPANCY_RATIO_MAX: u32 = 5000;
140 if raw_rtt > ewma_rtt * DELTA_DISCREPANCY_RATIO_MAX {
142 true
149 } else if ewma_rtt > raw_rtt * DELTA_DISCREPANCY_RATIO_MAX {
150 self.clock_stalled.load(Ordering::SeqCst)
153 } else {
154 self.clock_stalled.store(false, Ordering::SeqCst);
156 false
157 }
158 } else {
159 false
161 }
162 }
163
164 pub(crate) fn update(
178 &mut self,
179 now: Instant,
180 state: &State,
181 cwnd: &CongestionWindow,
182 ) -> Result<(), Error> {
183 let data_sent_at = self
184 .sendme_expected_from
185 .pop_front()
186 .ok_or(Error::MismatchedEstimationCall)?;
187 let raw_rtt = now.saturating_duration_since(data_sent_at);
188
189 if self.is_clock_stalled(raw_rtt, state.in_slow_start()) {
190 return Ok(());
191 }
192
193 self.max_rtt = self.max_rtt.max(Some(raw_rtt));
194 self.last_rtt = Some(raw_rtt);
195
196 let ewma_n = u64::from(if state.in_slow_start() {
198 self.params.ewma_ss_max()
199 } else {
200 min(
201 (cwnd.update_rate(state) * (self.params.ewma_cwnd_pct().as_percent())) / 100,
202 self.params.ewma_max(),
203 )
204 });
205 let ewma_n = max(ewma_n, 2);
206
207 let raw_rtt_usec = raw_rtt.as_micros() as u64;
209 let prev_ewma_rtt_usec = self.ewma_rtt.map(|rtt| rtt.as_micros() as u64);
210
211 let new_ewma_rtt_usec = match prev_ewma_rtt_usec {
219 None => raw_rtt_usec,
220 Some(prev_ewma_rtt_usec) => {
221 ((raw_rtt_usec * 2) + ((ewma_n - 1) * prev_ewma_rtt_usec)) / (ewma_n + 1)
222 }
223 };
224 let ewma_rtt = Duration::from_micros(new_ewma_rtt_usec);
225 self.ewma_rtt = Some(ewma_rtt);
226
227 let Some(min_rtt) = self.min_rtt else {
228 self.min_rtt = self.ewma_rtt;
229 return Ok(());
230 };
231
232 if cwnd.get() == cwnd.min() && !state.in_slow_start() {
233 let max = max(ewma_rtt, min_rtt).as_micros() as u64;
235 let min = min(ewma_rtt, min_rtt).as_micros() as u64;
236 let rtt_reset_pct = u64::from(self.params.rtt_reset_pct().as_percent());
237 let min_rtt = Duration::from_micros(
238 (rtt_reset_pct * max / 100) + (100 - rtt_reset_pct) * min / 100,
239 );
240
241 self.min_rtt = Some(min_rtt);
242 } else if self.ewma_rtt < self.min_rtt {
243 self.min_rtt = self.ewma_rtt;
244 }
245
246 Ok(())
247 }
248}
249
#[cfg(test)]
#[allow(clippy::print_stderr)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use std::time::{Duration, Instant};

    use crate::congestion::test_utils::{new_cwnd, new_rtt_estimator};

    use super::*;

    /// One step of the RTT test vectors.
    ///
    /// Holds the inputs fed to the estimator and the estimates expected afterwards.
    /// Times are in microseconds.
    #[derive(Debug)]
    struct RttTestSample {
        sent_usec_in: u64,
        sendme_received_usec_in: u64,
        cwnd_in: u32,
        ss_in: bool,
        last_rtt_usec_out: u64,
        ewma_rtt_usec_out: u64,
        min_rtt_usec_out: u64,
    }

    impl From<[u64; 7]> for RttTestSample {
        /// Build a sample from one table row.
        ///
        /// Columns: sent, SENDME received, cwnd, slow start (1 = yes), expected last RTT,
        /// expected EWMA RTT, expected min RTT.
        fn from(
            [sent, received, cwnd, slow_start, last_rtt, ewma_rtt, min_rtt]: [u64; 7],
        ) -> Self {
            Self {
                sent_usec_in: sent,
                sendme_received_usec_in: received,
                cwnd_in: cwnd as u32,
                ss_in: slow_start == 1,
                last_rtt_usec_out: last_rtt,
                ewma_rtt_usec_out: ewma_rtt,
                min_rtt_usec_out: min_rtt,
            }
        }
    }

    impl RttTestSample {
        /// Feed this sample to `estimator`, with times measured from `start`, and assert
        /// that the resulting estimates match the expected outputs.
        fn test(&self, estimator: &mut RoundtripTimeEstimator, start: Instant) {
            let at = |usec: u64| start + Duration::from_micros(usec);
            let expected = |usec: u64| Some(Duration::from_micros(usec));

            let state = if self.ss_in {
                State::SlowStart
            } else {
                State::Steady
            };
            let mut cwnd = new_cwnd();
            cwnd.set(self.cwnd_in);

            estimator.expect_sendme(at(self.sent_usec_in));
            estimator
                .update(at(self.sendme_received_usec_in), &state, &cwnd)
                .expect("Error on RTT update");

            assert_eq!(estimator.last_rtt, expected(self.last_rtt_usec_out));
            assert_eq!(estimator.ewma_rtt, expected(self.ewma_rtt_usec_out));
            assert_eq!(estimator.min_rtt, expected(self.min_rtt_usec_out));
        }
    }

    #[test]
    fn test_vectors() {
        let mut estimator = new_rtt_estimator();
        let start = Instant::now();
        // Columns: sent, SENDME received, cwnd, slow start (1 = yes),
        // expected last RTT, expected EWMA RTT, expected min RTT.
        let vectors = [
            [100000, 200000, 124, 1, 100000, 100000, 100000],
            [200000, 300000, 124, 1, 100000, 100000, 100000],
            [350000, 500000, 124, 1, 150000, 133333, 100000],
            [500000, 550000, 124, 1, 50000, 77777, 77777],
            [600000, 700000, 124, 1, 100000, 92592, 77777],
            [700000, 750000, 124, 1, 50000, 64197, 64197],
            [750000, 875000, 124, 0, 125000, 104732, 104732],
            [875000, 900000, 124, 0, 25000, 51577, 104732],
            [900000, 950000, 200, 0, 50000, 50525, 50525],
        ];
        for sample in vectors.map(RttTestSample::from) {
            eprintln!("Testing vector: {:?}", sample);
            sample.test(&mut estimator, start);
        }
    }
}