1use std::cmp::{max, min};
4use std::collections::VecDeque;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::time::{Duration, Instant};
7
8use super::params::RoundTripEstimatorParams;
9use super::{CongestionWindow, State};
10
11use thiserror::Error;
12use tor_error::{ErrorKind, HasKind};
13
14#[derive(Error, Debug, Clone)]
16#[non_exhaustive]
17pub(crate) enum Error {
18 #[error("Informed of a SENDME we weren't expecting")]
21 MismatchedEstimationCall,
22}
23
24impl HasKind for Error {
25 fn kind(&self) -> ErrorKind {
26 use Error as E;
27 match self {
28 E::MismatchedEstimationCall => ErrorKind::TorProtocolViolation,
29 }
30 }
31}
32
33#[derive(Debug)]
35#[allow(dead_code)]
36pub(crate) struct RoundtripTimeEstimator {
37 sendme_expected_from: VecDeque<Instant>,
45 last_rtt: Duration,
47 ewma_rtt: Duration,
51 min_rtt: Duration,
53 max_rtt: Duration,
55 params: RoundTripEstimatorParams,
57 clock_stalled: AtomicBool,
60}
61
62#[allow(dead_code)]
63impl RoundtripTimeEstimator {
64 pub(crate) fn new(params: &RoundTripEstimatorParams) -> Self {
67 Self {
68 sendme_expected_from: Default::default(),
69 last_rtt: Default::default(),
70 ewma_rtt: Default::default(),
71 min_rtt: Duration::ZERO,
72 max_rtt: Default::default(),
73 params: params.clone(),
74 clock_stalled: AtomicBool::default(),
75 }
76 }
77
78 pub(crate) fn is_ready(&self) -> bool {
80 !self.clock_stalled() && !self.last_rtt.is_zero()
81 }
82
83 pub(crate) fn clock_stalled(&self) -> bool {
85 self.clock_stalled.load(Ordering::SeqCst)
86 }
87
88 pub(crate) fn ewma_rtt_usec(&self) -> u32 {
90 u32::try_from(self.ewma_rtt.as_micros()).unwrap_or(u32::MAX)
91 }
92
93 pub(crate) fn min_rtt_usec(&self) -> u32 {
95 u32::try_from(self.min_rtt.as_micros()).unwrap_or(u32::MAX)
96 }
97
98 pub(crate) fn expect_sendme(&mut self, now: Instant) {
101 self.sendme_expected_from.push_back(now);
102 }
103
104 fn can_crosscheck_with_current_estimate(&self, in_slow_start: bool) -> bool {
110 !(in_slow_start || self.ewma_rtt.is_zero())
114 }
115
116 fn is_clock_stalled(&self, raw_rtt: Duration, in_slow_start: bool) -> bool {
119 if raw_rtt.is_zero() {
120 self.clock_stalled.store(true, Ordering::SeqCst);
122 true
123 } else if self.can_crosscheck_with_current_estimate(in_slow_start) {
124 const DELTA_DISCREPANCY_RATIO_MAX: u32 = 5000;
128 if raw_rtt > self.ewma_rtt * DELTA_DISCREPANCY_RATIO_MAX {
130 true
137 } else if self.ewma_rtt > raw_rtt * DELTA_DISCREPANCY_RATIO_MAX {
138 self.clock_stalled.load(Ordering::SeqCst)
141 } else {
142 self.clock_stalled.store(false, Ordering::SeqCst);
144 false
145 }
146 } else {
147 false
149 }
150 }
151
152 pub(crate) fn update(
166 &mut self,
167 now: Instant,
168 state: &State,
169 cwnd: &CongestionWindow,
170 ) -> Result<(), Error> {
171 let data_sent_at = self
172 .sendme_expected_from
173 .pop_front()
174 .ok_or(Error::MismatchedEstimationCall)?;
175 let raw_rtt = now.saturating_duration_since(data_sent_at);
176
177 if self.is_clock_stalled(raw_rtt, state.in_slow_start()) {
178 return Ok(());
179 }
180
181 self.max_rtt = self.max_rtt.max(raw_rtt);
182 self.last_rtt = raw_rtt;
183
184 let ewma_n = u64::from(if state.in_slow_start() {
186 self.params.ewma_ss_max()
187 } else {
188 min(
189 (cwnd.update_rate(state) * (self.params.ewma_cwnd_pct().as_percent())) / 100,
190 self.params.ewma_max(),
191 )
192 });
193 let ewma_n = max(ewma_n, 2);
194
195 let raw_rtt_usec = raw_rtt.as_micros() as u64;
197 let prev_ewma_rtt_usec = self.ewma_rtt.as_micros() as u64;
198
199 let new_ewma_rtt_usec = if prev_ewma_rtt_usec == 0 {
207 raw_rtt_usec
208 } else {
209 ((raw_rtt_usec * 2) + ((ewma_n - 1) * prev_ewma_rtt_usec)) / (ewma_n + 1)
210 };
211 self.ewma_rtt = Duration::from_micros(new_ewma_rtt_usec);
212
213 if self.min_rtt.is_zero() {
214 self.min_rtt = self.ewma_rtt;
215 } else if cwnd.get() == cwnd.min() && !state.in_slow_start() {
216 let max = max(self.ewma_rtt, self.min_rtt).as_micros() as u64;
218 let min = min(self.ewma_rtt, self.min_rtt).as_micros() as u64;
219 let rtt_reset_pct = u64::from(self.params.rtt_reset_pct().as_percent());
220 self.min_rtt = Duration::from_micros(
221 (rtt_reset_pct * max / 100) + (100 - rtt_reset_pct) * min / 100,
222 );
223 } else if self.ewma_rtt < self.min_rtt {
224 self.min_rtt = self.ewma_rtt;
225 }
226
227 Ok(())
228 }
229}
230
#[cfg(test)]
#[allow(clippy::print_stderr)]
mod test {
    // Lint allowances for test code.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use std::time::{Duration, Instant};

    use crate::congestion::test_utils::{new_cwnd, new_rtt_estimator};

    use super::*;

    /// One row of the RTT test table: the inputs fed to the estimator
    /// (`*_in`) and the estimator state expected afterwards (`*_out`).
    #[derive(Debug)]
    struct RttTestSample {
        /// Send time, in microseconds after the test's start instant.
        sent_usec_in: u64,
        /// SENDME arrival time, in microseconds after the start instant.
        sendme_received_usec_in: u64,
        /// Congestion window to configure before the update.
        cwnd_in: u32,
        /// Whether the update happens in slow start.
        ss_in: bool,
        /// Expected `last_rtt`, in microseconds.
        last_rtt_usec_out: u64,
        /// Expected `ewma_rtt`, in microseconds.
        ewma_rtt_usec_out: u64,
        /// Expected `min_rtt`, in microseconds.
        min_rtt_usec_out: u64,
    }

    impl From<[u64; 7]> for RttTestSample {
        /// Unpack a table row; the array order matches the field order, and
        /// the slow-start column is true only when it equals 1.
        fn from(arr: [u64; 7]) -> Self {
            let [sent, received, cwnd, slow_start, last_rtt, ewma_rtt, min_rtt] = arr;
            Self {
                sent_usec_in: sent,
                sendme_received_usec_in: received,
                cwnd_in: cwnd as u32,
                ss_in: slow_start == 1,
                last_rtt_usec_out: last_rtt,
                ewma_rtt_usec_out: ewma_rtt,
                min_rtt_usec_out: min_rtt,
            }
        }
    }

    impl RttTestSample {
        /// Feed this sample to `estimator` (timestamps are offsets from
        /// `start`) and check the resulting last/EWMA/min RTT values.
        fn test(&self, estimator: &mut RoundtripTimeEstimator, start: Instant) {
            // Offset from the shared start instant, in microseconds.
            let at = |usec: u64| start + Duration::from_micros(usec);

            let state = if self.ss_in {
                State::SlowStart
            } else {
                State::Steady
            };
            let mut cwnd = new_cwnd();
            cwnd.set(self.cwnd_in);
            let sent_at = at(self.sent_usec_in);
            let received_at = at(self.sendme_received_usec_in);

            estimator.expect_sendme(sent_at);
            estimator
                .update(received_at, &state, &cwnd)
                .expect("Error on RTT update");

            let expected_last = Duration::from_micros(self.last_rtt_usec_out);
            let expected_ewma = Duration::from_micros(self.ewma_rtt_usec_out);
            let expected_min = Duration::from_micros(self.min_rtt_usec_out);
            assert_eq!(estimator.last_rtt, expected_last);
            assert_eq!(estimator.ewma_rtt, expected_ewma);
            assert_eq!(estimator.min_rtt, expected_min);
        }
    }

    /// Run the estimator through a fixed sequence of samples, sharing one
    /// estimator so that each row builds on the state left by the previous.
    #[test]
    fn test_vectors() {
        let mut rtt = new_rtt_estimator();
        let now = Instant::now();
        // Columns: sent, received, cwnd, slow-start flag,
        //          expected last RTT, expected EWMA RTT, expected min RTT.
        let vectors: [[u64; 7]; 9] = [
            [100000, 200000, 124, 1, 100000, 100000, 100000],
            [200000, 300000, 124, 1, 100000, 100000, 100000],
            [350000, 500000, 124, 1, 150000, 133333, 100000],
            [500000, 550000, 124, 1, 50000, 77777, 77777],
            [600000, 700000, 124, 1, 100000, 92592, 77777],
            [700000, 750000, 124, 1, 50000, 64197, 64197],
            [750000, 875000, 124, 0, 125000, 104732, 104732],
            [875000, 900000, 124, 0, 25000, 51577, 104732],
            [900000, 950000, 200, 0, 50000, 50525, 50525],
        ];
        for row in vectors {
            let sample = RttTestSample::from(row);
            eprintln!("Testing vector: {:?}", sample);
            sample.test(&mut rtt, now);
        }
    }
}