tor_circmgr/timeouts/
pareto.rs

1//! Implement Tor's sort-of-Pareto estimator for circuit build timeouts.
2//!
3//! Our build times don't truly follow a
4//! [Pareto](https://en.wikipedia.org/wiki/Pareto_distribution)
5//! distribution; instead they seem to be closer to a
6//! [Fréchet](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distribution)
7//! distribution.  But those are hard to work with, and we only care
8//! about the right tail, so we're using Pareto instead.
9//!
10//! This estimator also includes several heuristics and kludges to
11//! try to behave better on unreliable networks.
12//! For more information on the exact algorithms and their rationales,
13//! see [`path-spec.txt`](https://gitlab.torproject.org/tpo/core/torspec/-/blob/master/path-spec.txt).
14
15use bounded_vec_deque::BoundedVecDeque;
16use serde::{Deserialize, Serialize};
17use std::collections::{BTreeMap, HashMap};
18use std::time::Duration;
19use tor_netdir::params::NetParameters;
20
21use super::Action;
22use tor_persist::JsonValue;
23
/// Maximum number of circuit build time observations we keep.
const TIME_HISTORY_LEN: usize = 1_000;

/// Default number of circuit success-versus-timeout observations we keep.
///
/// (The length actually used can be changed at runtime from the
/// consensus-derived parameters.)
const SUCCESS_HISTORY_DEFAULT_LEN: usize = 20;

/// Width, in milliseconds, of each bucket in our build-time histogram.
const BUCKET_WIDTH_MSEC: u32 = 10;
33
/// A circuit build time or timeout duration, measured in milliseconds.
///
/// Requires that we don't care about tracking timeouts above u32::MAX
/// milliseconds (about 49 days): longer durations are clamped to that
/// maximum by [`MsecDuration::new_saturating`].
///
/// Serialized transparently, i.e. as a bare integer count of milliseconds.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(transparent)]
struct MsecDuration(u32);
41
42impl MsecDuration {
43    /// Convert a Duration into a MsecDuration, saturating
44    /// extra-high values to u32::MAX milliseconds.
45    fn new_saturating(d: &Duration) -> Self {
46        let msec = std::cmp::min(d.as_millis(), u128::from(u32::MAX)) as u32;
47        MsecDuration(msec)
48    }
49}
50
/// Module to hold calls to const_assert.
///
/// This is a separate module so we can change the clippy warnings on it
/// (the `u16::MAX as usize` comparison would otherwise trip
/// `clippy::checked_conversions`) without relaxing them elsewhere.
#[allow(clippy::checked_conversions)]
mod assertion {
    use static_assertions::const_assert;
    // If this assertion is untrue, then we can't safely use u16 fields in
    // time_histogram.  The histogram is derived from `time_history`, which
    // holds at most TIME_HISTORY_LEN entries, so no bucket count can exceed
    // TIME_HISTORY_LEN.
    const_assert!(super::TIME_HISTORY_LEN <= u16::MAX as usize);
}
61
/// A history of circuit timeout observations, used to estimate our
/// likely circuit timeouts.
///
/// It keeps two independent records: recent circuit build times (plus a
/// histogram summarizing them), and recent success-versus-timeout outcomes.
#[derive(Debug, Clone)]
struct History {
    /// Our most recent observed circuit construction times.
    ///
    /// For the purpose of this estimator, a circuit counts as
    /// "constructed" when a certain "significant" hop (typically the third)
    /// is completed.
    ///
    /// Bounded to `TIME_HISTORY_LEN` entries; once full, each new
    /// observation evicts the oldest one.
    time_history: BoundedVecDeque<MsecDuration>,

    /// A histogram representation of the values in [`History::time_history`].
    ///
    /// This histogram is implemented as a sparse map from the center
    /// value of each histogram bucket to the number of entries in
    /// that bucket.  It is completely derivable from time_history; we
    /// keep it separate here for efficiency.
    ///
    /// Buckets whose count drops to zero are removed from the map.
    time_histogram: BTreeMap<MsecDuration, u16>,

    /// Our most recent circuit timeout statuses.
    ///
    /// Each `true` value represents a successfully completed circuit
    /// (all hops).  Each `false` value represents a circuit that
    /// timed out after having completed at least one hop.
    ///
    /// Bounded in length (`SUCCESS_HISTORY_DEFAULT_LEN` by default); once
    /// full, the oldest status is discarded.
    success_history: BoundedVecDeque<bool>,
}
88
89impl History {
90    /// Initialize a new empty `History` with no observations.
91    fn new_empty() -> Self {
92        History {
93            time_history: BoundedVecDeque::new(TIME_HISTORY_LEN),
94            time_histogram: BTreeMap::new(),
95            success_history: BoundedVecDeque::new(SUCCESS_HISTORY_DEFAULT_LEN),
96        }
97    }
98
99    /// Remove all observations from this `History`.
100    fn clear(&mut self) {
101        self.time_history.clear();
102        self.time_histogram.clear();
103        self.success_history.clear();
104    }
105
106    /// Change the number of successes to record in our success
107    /// history to `n`.
108    fn set_success_history_len(&mut self, n: usize) {
109        if n < self.success_history.len() {
110            self.success_history
111                .drain(0..(self.success_history.len() - n));
112        }
113        self.success_history.set_max_len(n);
114    }
115
116    /// Change the number of circuit time observations to record in
117    /// our time history to `n`.
118    ///
119    /// This is a testing-only function.
120    #[cfg(test)]
121    fn set_time_history_len(&mut self, n: usize) {
122        self.time_history.set_max_len(n);
123    }
124
125    /// Construct a new `History` from an iterator representing a sparse
126    /// histogram of values.
127    ///
128    /// The input must be a sequence of `(D,N)` tuples, where each `D`
129    /// represents a circuit build duration, and `N` represents the
130    /// number of observations with that duration.
131    ///
132    /// These observations are shuffled into a random order, then
133    /// added to a new History.
134    fn from_sparse_histogram<I>(iter: I) -> Self
135    where
136        I: Iterator<Item = (MsecDuration, u16)>,
137    {
138        use rand::seq::{IteratorRandom, SliceRandom};
139        let mut rng = rand::rng();
140
141        // We want to build a vector with the elements of the old histogram in
142        // random order, but we want to defend ourselves against bogus inputs
143        // that would take too much RAM.
144        let mut observations = iter
145            .take(TIME_HISTORY_LEN) // limit number of bins
146            .flat_map(|(dur, n)| std::iter::repeat_n(dur, n as usize))
147            .choose_multiple(&mut rng, TIME_HISTORY_LEN);
148        // IteratorRand::choose_multiple doesn't guarantee anything about the order of its output.
149        observations.shuffle(&mut rng);
150
151        let mut result = History::new_empty();
152        for obs in observations {
153            result.add_time(obs);
154        }
155
156        result
157    }
158
159    /// Return an iterator yielding a sparse histogram of the circuit build
160    /// time values in this `History`.
161    ///
162    /// Each histogram entry is a `(D,N)` tuple, where `D` is the
163    /// center of a histogram bucket, and `N` is the number of
164    /// observations in that bucket.
165    ///
166    /// Buckets with `N=0` are omitted.  Buckets are yielded in order.
167    fn sparse_histogram(&self) -> impl Iterator<Item = (MsecDuration, u16)> + '_ {
168        self.time_histogram.iter().map(|(d, n)| (*d, *n))
169    }
170
171    /// Return the center value for the bucket containing `time`.
172    fn bucket_center(time: MsecDuration) -> MsecDuration {
173        let idx = time.0 / BUCKET_WIDTH_MSEC;
174        let msec = (idx * BUCKET_WIDTH_MSEC) + (BUCKET_WIDTH_MSEC) / 2;
175        MsecDuration(msec)
176    }
177
178    /// Increment the histogram bucket containing `time` by one.
179    fn inc_bucket(&mut self, time: MsecDuration) {
180        let center = History::bucket_center(time);
181        *self.time_histogram.entry(center).or_insert(0) += 1;
182    }
183
184    /// Decrement the histogram bucket containing `time` by one, removing
185    /// it if it becomes 0.
186    fn dec_bucket(&mut self, time: MsecDuration) {
187        use std::collections::btree_map::Entry;
188        let center = History::bucket_center(time);
189        match self.time_histogram.entry(center) {
190            Entry::Vacant(_) => {
191                // this is a bug.
192            }
193            Entry::Occupied(e) if e.get() <= &1 => {
194                e.remove();
195            }
196            Entry::Occupied(mut e) => {
197                *e.get_mut() -= 1;
198            }
199        }
200    }
201
202    /// Add `time` to our list of circuit build time observations, and
203    /// adjust the histogram accordingly.
204    fn add_time(&mut self, time: MsecDuration) {
205        match self.time_history.push_back(time) {
206            None => {}
207            Some(removed_time) => {
208                // `removed_time` just fell off the end of the deque:
209                // remove it from the histogram.
210                self.dec_bucket(removed_time);
211            }
212        }
213        self.inc_bucket(time);
214    }
215
216    /// Return the number of observations in our time history.
217    ///
218    /// This will always be `<= TIME_HISTORY_LEN`.
219    fn n_times(&self) -> usize {
220        self.time_history.len()
221    }
222
223    /// Record a success (true) or timeout (false) in our record of whether
224    /// circuits timed out or not.
225    fn add_success(&mut self, succeeded: bool) {
226        self.success_history.push_back(succeeded);
227    }
228
229    /// Return the number of timeouts recorded in our success history.
230    fn n_recent_timeouts(&self) -> usize {
231        self.success_history.iter().filter(|x| !**x).count()
232    }
233
234    /// Helper: return the `n` most frequent histogram bins.
235    fn n_most_frequent_bins(&self, n: usize) -> Vec<(MsecDuration, u16)> {
236        use itertools::Itertools;
237        // we use cmp::Reverse here so that we can use k_smallest as
238        // if it were "k_largest".
239        use std::cmp::Reverse;
240
241        // We want the buckets that have the _largest_ counts; we want
242        // to break ties in favor of the _smallest_ values.  So we
243        // apply Reverse only to the counts before passing the tuples
244        // to k_smallest.
245
246        self.sparse_histogram()
247            .map(|(center, count)| (Reverse(count), center))
248            // (k_smallest runs in O(n_bins * lg(n))
249            .k_smallest(n)
250            .map(|(Reverse(count), center)| (center, count))
251            .collect()
252    }
253
254    /// Return an estimator for the `X_m` of our Pareto distribution,
255    /// by looking at the `n_modes` most frequently filled histogram
256    /// bins.
257    ///
258    /// It is not a true `X_m` value, since there are definitely
259    /// values less than this, but it seems to work as a decent
260    /// heuristic.
261    ///
262    /// Return `None` if we have no observations.
263    fn estimate_xm(&self, n_modes: usize) -> Option<u32> {
264        // From path-spec:
265        //   Tor clients compute the Xm parameter using the weighted
266        //   average of the midpoints of the 'cbtnummodes' (10)
267        //   most frequently occurring 10ms histogram bins.
268
269        // The most frequently used bins.
270        let bins = self.n_most_frequent_bins(n_modes);
271        // Total number of observations in these bins.
272        let n_observations: u16 = bins.iter().map(|(_, n)| n).sum();
273        // Sum of all observations in these bins.
274        let total_observations: u64 = bins
275            .iter()
276            .map(|(d, n)| u64::from(d.0 * u32::from(*n)))
277            .sum();
278
279        if n_observations == 0 {
280            None
281        } else {
282            Some((total_observations / u64::from(n_observations)) as u32)
283        }
284    }
285
286    /// Compute a maximum-likelihood pareto distribution based on this
287    /// history, computing `X_m` based on the `n_modes` most frequent
288    /// histograms.
289    ///
290    /// Return None if we have no observations.
291    fn pareto_estimate(&self, n_modes: usize) -> Option<ParetoDist> {
292        let xm = self.estimate_xm(n_modes)?;
293
294        // From path-spec:
295        //     alpha = n/(Sum_n{ln(MAX(Xm, x_i))} - n*ln(Xm))
296
297        let n = self.time_history.len();
298        let sum_of_log_observations: f64 = self
299            .time_history
300            .iter()
301            .map(|m| f64::from(std::cmp::max(m.0, xm)).ln())
302            .sum();
303        let sum_of_log_xm = (n as f64) * f64::from(xm).ln();
304
305        // We're computing 1/alpha here, instead of alpha.  This avoids
306        // division by zero, and has the advantage of being what our
307        // quantile estimator actually needs.
308        let inv_alpha = (sum_of_log_observations - sum_of_log_xm) / (n as f64);
309
310        Some(ParetoDist {
311            x_m: f64::from(xm),
312            inv_alpha,
313        })
314    }
315}
316
/// A Pareto distribution, for use in estimating timeouts.
///
/// Values are represented by a number of milliseconds.
#[derive(Debug)]
struct ParetoDist {
    /// The lower bound (the `X_m` scale parameter) for the pareto distribution.
    x_m: f64,
    /// The inverse of the alpha (shape) parameter in the pareto distribution.
    ///
    /// (We use 1/alpha here to save a step in [`ParetoDist::quantile`].)
    inv_alpha: f64,
}
329
330impl ParetoDist {
331    /// Compute an inverse CDF for this distribution.
332    ///
333    /// Given a `q` value between 0 and 1, compute a distribution `v`
334    /// value such that `q` of the Pareto Distribution is expected to
335    /// be less than `v`.
336    ///
337    /// If `q` is out of bounds, it is clamped to [0.0, 1.0].
338    fn quantile(&self, q: f64) -> f64 {
339        let q = q.clamp(0.0, 1.0);
340        self.x_m / ((1.0 - q).powf(self.inv_alpha))
341    }
342}
343
/// A set of parameters determining the behavior of a ParetoTimeoutEstimator.
///
/// These are typically derived from a set of consensus parameters.
#[derive(Clone, Debug)]
pub(crate) struct Params {
    /// Should we use our estimates when deciding on circuit timeouts?
    ///
    /// When this is false, our timeouts are fixed to the default.
    use_estimates: bool,
    /// How many observations must we have made before we can use our
    /// Pareto estimators to guess a good set of timeouts?
    min_observations: u16,
    /// Which hop is the "significant hop" we should use when recording circuit
    /// build times?  (Watch out! This is zero-indexed.)
    significant_hop: u8,
    /// A quantile (in range [0.0,1.0]) describing a point in the
    /// Pareto distribution to use when determining when a circuit
    /// should be treated as having "timed out".
    ///
    /// (A "timed out" circuit continues building for measurement
    /// purposes, but can't be used for traffic.)
    timeout_quantile: f64,
    /// A quantile (in range [0.0,1.0]) describing a point in the Pareto
    /// distribution to use when determining when a circuit should be
    /// "abandoned".
    ///
    /// (An "abandoned" circuit is stopped entirely, and not included
    /// in measurements.)
    abandon_quantile: f64,
    /// Default `(timeout, abandon)` thresholds, for use when we can't or
    /// won't compute estimates.
    ///
    /// (When `use_estimates` is false, these are returned as-is.)
    default_thresholds: (Duration, Duration),
    /// Number of histogram buckets to use when determining the Xm estimate.
    ///
    /// (See [`History::estimate_xm`] for details.)
    n_modes_for_xm: usize,
    /// How many entries do we record in our success/timeout history?
    success_history_len: usize,
    /// How many timeouts should we allow in our success/timeout history
    /// before we assume that network has changed in a way that makes
    /// our estimates completely wrong?
    ///
    /// (Once the history holds *more* than this many timeouts, all of our
    /// observations are discarded.)
    reset_after_timeouts: usize,
    /// Minimum base timeout to ever infer.
    ///
    /// (This is a lower bound on estimated timeouts; fallback and default
    /// thresholds are not adjusted by it.)
    min_timeout: Duration,
}
389
390impl Default for Params {
391    fn default() -> Self {
392        Params {
393            use_estimates: true,
394            min_observations: 100,
395            significant_hop: 2,
396            timeout_quantile: 0.80,
397            abandon_quantile: 0.99,
398            default_thresholds: (Duration::from_secs(60), Duration::from_secs(60)),
399            n_modes_for_xm: 10,
400            success_history_len: SUCCESS_HISTORY_DEFAULT_LEN,
401            reset_after_timeouts: 18,
402            min_timeout: Duration::from_millis(10),
403        }
404    }
405}
406
407impl From<&NetParameters> for Params {
408    fn from(p: &NetParameters) -> Params {
409        // Because of the underlying bounds, the "unwrap_or_else"
410        // conversions here should be impossible, and the "as"
411        // conversions should always be in-range.
412
413        let timeout = p
414            .cbt_initial_timeout
415            .try_into()
416            .unwrap_or_else(|_| Duration::from_secs(60));
417        let learning_disabled: bool = p.cbt_learning_disabled.into();
418        Params {
419            use_estimates: !learning_disabled,
420            min_observations: p.cbt_min_circs_for_estimate.get() as u16,
421            significant_hop: 2,
422            timeout_quantile: p.cbt_timeout_quantile.as_fraction(),
423            abandon_quantile: p.cbt_abandon_quantile.as_fraction(),
424            default_thresholds: (timeout, timeout),
425            n_modes_for_xm: p.cbt_num_xm_modes.get() as usize,
426            success_history_len: p.cbt_success_count.get() as usize,
427            reset_after_timeouts: p.cbt_max_timeouts.get() as usize,
428            min_timeout: p
429                .cbt_min_timeout
430                .try_into()
431                .unwrap_or_else(|_| Duration::from_millis(10)),
432        }
433    }
434}
435
/// Tor's default circuit build timeout estimator.
///
/// This object records a set of observed circuit build times, and
/// uses it to determine good values for how long we should allow
/// circuits to build.
///
/// For full details of the algorithms used, see
/// [`path-spec.txt`](https://gitlab.torproject.org/tpo/core/torspec/-/blob/master/path-spec.txt).
pub(crate) struct ParetoTimeoutEstimator {
    /// Our observations for circuit build times and success/failure
    /// history.
    history: History,

    /// Our most recent `(timeout, abandon)` estimate, if we have one that is
    /// up-to-date.
    ///
    /// (We reset this to None whenever we get a new build-time observation,
    /// or discard our history.)
    timeouts: Option<(Duration, Duration)>,

    /// The timeouts that we use when we do not have sufficient observations
    /// to conclude anything about our circuit build times.
    ///
    /// These start out as `p.default_thresholds`, but can be adjusted
    /// depending on how many timeouts we've been seeing: each time too many
    /// recent timeouts force us to discard our history, they may be doubled
    /// (up to a fixed cap).
    fallback_timeouts: (Duration, Duration),

    /// A set of parameters to use in computing circuit build timeout
    /// estimates.
    p: Params,
}
466
467impl Default for ParetoTimeoutEstimator {
468    fn default() -> Self {
469        Self::from_history(History::new_empty())
470    }
471}
472
/// An object used to serialize our timeout history for persistent state.
///
/// Every field falls back to its default when it is missing from the stored
/// data (`#[serde(default)]`); unrecognized fields are captured in
/// `unknown_fields`.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
#[serde(default)]
pub(crate) struct ParetoTimeoutState {
    /// A version field used to help encoding and decoding.
    ///
    /// (`build_state` writes `1` here.)
    #[allow(dead_code)]
    version: usize,
    /// A sparse histogram of observed circuit build times, in the form
    /// returned by `sparse_histogram()`: `(bucket center, count)` pairs.
    histogram: Vec<(MsecDuration, u16)>,
    /// The base timeout estimate at the time this state was built: kept for
    /// reference (see `latest_estimate`), not used to rebuild the history.
    current_timeout: Option<MsecDuration>,

    /// Fields from the state file that was used to make this `ParetoTimeoutState` that
    /// this version of Arti doesn't understand.
    #[serde(flatten)]
    unknown_fields: HashMap<String, JsonValue>,
}
490
491impl ParetoTimeoutState {
492    /// Return the latest base timeout estimate, as recorded in this state.
493    pub(crate) fn latest_estimate(&self) -> Option<Duration> {
494        self.current_timeout
495            .map(|m| Duration::from_millis(m.0.into()))
496    }
497}
498
499impl ParetoTimeoutEstimator {
500    /// Construct a new ParetoTimeoutEstimator from the provided history
501    /// object.
502    fn from_history(history: History) -> Self {
503        let p = Params::default();
504        ParetoTimeoutEstimator {
505            history,
506            timeouts: None,
507            fallback_timeouts: p.default_thresholds,
508            p,
509        }
510    }
511
512    /// Create a new ParetoTimeoutEstimator based on a loaded
513    /// ParetoTimeoutState.
514    pub(crate) fn from_state(state: ParetoTimeoutState) -> Self {
515        let history = History::from_sparse_histogram(state.histogram.into_iter());
516        Self::from_history(history)
517    }
518
519    /// Compute an unscaled basic pair of timeouts for a circuit of
520    /// the "normal" length.
521    ///
522    /// Return a cached value if we have no observations since the
523    /// last time this function was called.
524    fn base_timeouts(&mut self) -> (Duration, Duration) {
525        if let Some(x) = self.timeouts {
526            // Great; we have a cached value.
527            return x;
528        }
529
530        if self.history.n_times() < self.p.min_observations as usize {
531            // We don't have enough values to estimate.
532            return self.fallback_timeouts;
533        }
534
535        // Here we're going to compute the timeouts, cache them, and
536        // return them.
537        let dist = match self.history.pareto_estimate(self.p.n_modes_for_xm) {
538            Some(dist) => dist,
539            None => {
540                return self.fallback_timeouts;
541            }
542        };
543        let timeout_threshold = dist.quantile(self.p.timeout_quantile);
544        let abandon_threshold = dist
545            .quantile(self.p.abandon_quantile)
546            .max(timeout_threshold);
547
548        let timeouts = (
549            Duration::from_secs_f64(timeout_threshold / 1000.0).max(self.p.min_timeout),
550            Duration::from_secs_f64(abandon_threshold / 1000.0).max(self.p.min_timeout),
551        );
552        self.timeouts = Some(timeouts);
553
554        timeouts
555    }
556}
557
558impl super::TimeoutEstimator for ParetoTimeoutEstimator {
559    fn update_params(&mut self, p: &NetParameters) {
560        let parameters = p.into();
561        self.p = parameters;
562        let new_success_len = self.p.success_history_len;
563        self.history.set_success_history_len(new_success_len);
564    }
565
566    fn note_hop_completed(&mut self, hop: u8, delay: Duration, is_last: bool) {
567        if hop == self.p.significant_hop {
568            let time = MsecDuration::new_saturating(&delay);
569            self.history.add_time(time);
570            self.timeouts.take();
571        }
572        if is_last {
573            self.history.add_success(true);
574        }
575    }
576
577    fn note_circ_timeout(&mut self, hop: u8, delay: Duration) {
578        // Only record this timeout if we have seen some network activity since
579        // we launched the circuit.
580        let have_seen_recent_activity =
581            if let Some(last_traffic) = tor_proto::time_since_last_incoming_traffic() {
582                last_traffic < delay
583            } else {
584                // TODO: Is this the correct behavior in this case?
585                true
586            };
587
588        tracing::trace!(%hop, ?delay, %have_seen_recent_activity, "Circuit timeout");
589
590        if hop > 0 && have_seen_recent_activity {
591            self.history.add_success(false);
592            if self.history.n_recent_timeouts() > self.p.reset_after_timeouts {
593                let base_timeouts = self.base_timeouts();
594                self.history.clear();
595                self.timeouts.take();
596                // If we already had a timeout that was at least the
597                // length of our fallback timeouts, we should double
598                // those fallback timeouts, up to a maximum.
599                if base_timeouts.0 >= self.fallback_timeouts.0 {
600                    /// Largest value we'll allow a fallback timeout
601                    /// (the one we return when we have insufficient data)
602                    /// to reach.
603                    ///
604                    /// TODO: This is a ridiculous over-estimate.
605                    const MAX_FALLBACK_TIMEOUT: Duration = Duration::from_secs(7200);
606                    self.fallback_timeouts.0 =
607                        (self.fallback_timeouts.0 * 2).min(MAX_FALLBACK_TIMEOUT);
608                    self.fallback_timeouts.1 =
609                        (self.fallback_timeouts.1 * 2).min(MAX_FALLBACK_TIMEOUT);
610                }
611            }
612        }
613    }
614
615    fn timeouts(&mut self, action: &Action) -> (Duration, Duration) {
616        let (base_t, base_a) = if self.p.use_estimates {
617            self.base_timeouts()
618        } else {
619            // If we aren't using this estimator, then just return the
620            // default thresholds from our parameters.
621            return self.p.default_thresholds;
622        };
623
624        let reference_action = Action::BuildCircuit {
625            length: self.p.significant_hop as usize + 1,
626        };
627        debug_assert!(reference_action.timeout_scale() > 0);
628
629        let multiplier =
630            (action.timeout_scale() as f64) / (reference_action.timeout_scale() as f64);
631
632        // TODO-SPEC The spec doesn't define any of this
633        // action-based-multiplier stuff.  Tor doesn't multiply the
634        // abandon timeout.
635        use super::mul_duration_f64_saturating as mul;
636        (mul(base_t, multiplier), mul(base_a, multiplier))
637    }
638
639    fn learning_timeouts(&self) -> bool {
640        self.p.use_estimates && self.history.n_times() < usize::from(self.p.min_observations)
641    }
642
643    fn build_state(&mut self) -> Option<ParetoTimeoutState> {
644        let cur_timeout = MsecDuration::new_saturating(&self.base_timeouts().0);
645        Some(ParetoTimeoutState {
646            version: 1,
647            histogram: self.history.sparse_histogram().collect(),
648            current_timeout: Some(cur_timeout),
649            unknown_fields: Default::default(),
650        })
651    }
652}
653
654#[cfg(test)]
655mod test {
656    // @@ begin test lint list maintained by maint/add_warning @@
657    #![allow(clippy::bool_assert_comparison)]
658    #![allow(clippy::clone_on_copy)]
659    #![allow(clippy::dbg_macro)]
660    #![allow(clippy::mixed_attributes_style)]
661    #![allow(clippy::print_stderr)]
662    #![allow(clippy::print_stdout)]
663    #![allow(clippy::single_char_pattern)]
664    #![allow(clippy::unwrap_used)]
665    #![allow(clippy::unchecked_duration_subtraction)]
666    #![allow(clippy::useless_vec)]
667    #![allow(clippy::needless_pass_by_value)]
668    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
669    use super::*;
670    use crate::timeouts::TimeoutEstimator;
671    use tor_basic_utils::test_rng::testing_rng;
672    use tor_basic_utils::RngExt as _;
673
674    /// Return an action to build a 3-hop circuit.
675    fn b3() -> Action {
676        Action::BuildCircuit { length: 3 }
677    }
678
679    impl From<u32> for MsecDuration {
680        fn from(v: u32) -> Self {
681            Self(v)
682        }
683    }
684
685    #[test]
686    fn ms_partial_cmp() {
687        #![allow(clippy::eq_op)]
688        let myriad: MsecDuration = 10_000.into();
689        let lakh: MsecDuration = 100_000.into();
690        let crore: MsecDuration = 10_000_000.into();
691
692        assert!(myriad < lakh);
693        assert!(myriad == myriad);
694        assert!(crore > lakh);
695        assert!(crore >= crore);
696        assert!(crore <= crore);
697    }
698
699    #[test]
700    fn history_lowlev() {
701        assert_eq!(History::bucket_center(1.into()), 5.into());
702        assert_eq!(History::bucket_center(903.into()), 905.into());
703        assert_eq!(History::bucket_center(0.into()), 5.into());
704        assert_eq!(History::bucket_center(u32::MAX.into()), 4294967295.into());
705
706        let mut h = History::new_empty();
707        h.inc_bucket(7.into());
708        h.inc_bucket(8.into());
709        h.inc_bucket(9.into());
710        h.inc_bucket(10.into());
711        h.inc_bucket(11.into());
712        h.inc_bucket(12.into());
713        h.inc_bucket(13.into());
714        h.inc_bucket(299.into());
715        assert_eq!(h.time_histogram.get(&5.into()), Some(&3));
716        assert_eq!(h.time_histogram.get(&15.into()), Some(&4));
717        assert_eq!(h.time_histogram.get(&25.into()), None);
718        assert_eq!(h.time_histogram.get(&295.into()), Some(&1));
719
720        h.dec_bucket(299.into());
721        h.dec_bucket(24.into());
722        h.dec_bucket(12.into());
723
724        assert_eq!(h.time_histogram.get(&15.into()), Some(&3));
725        assert_eq!(h.time_histogram.get(&25.into()), None);
726        assert_eq!(h.time_histogram.get(&295.into()), None);
727
728        h.add_success(true);
729        h.add_success(false);
730        assert_eq!(h.success_history.len(), 2);
731
732        h.clear();
733        assert_eq!(h.time_histogram.len(), 0);
734        assert_eq!(h.time_history.len(), 0);
735        assert_eq!(h.success_history.len(), 0);
736    }
737
738    #[test]
739    fn time_observation_management() {
740        let mut h = History::new_empty();
741        h.set_time_history_len(8); // to make it easier to overflow.
742
743        h.add_time(300.into());
744        h.add_time(500.into());
745        h.add_time(542.into());
746        h.add_time(305.into());
747        h.add_time(543.into());
748        h.add_time(307.into());
749
750        assert_eq!(h.n_times(), 6);
751        let v = h.n_most_frequent_bins(10);
752        assert_eq!(&v[..], [(305.into(), 3), (545.into(), 2), (505.into(), 1)]);
753        let v = h.n_most_frequent_bins(2);
754        assert_eq!(&v[..], [(305.into(), 3), (545.into(), 2)]);
755
756        let v: Vec<_> = h.sparse_histogram().collect();
757        assert_eq!(&v[..], [(305.into(), 3), (505.into(), 1), (545.into(), 2)]);
758
759        h.add_time(212.into());
760        h.add_time(203.into());
761        // now we replace the first couple of older elements.
762        h.add_time(617.into());
763        h.add_time(413.into());
764
765        assert_eq!(h.n_times(), 8);
766
767        let v: Vec<_> = h.sparse_histogram().collect();
768        assert_eq!(
769            &v[..],
770            [
771                (205.into(), 1),
772                (215.into(), 1),
773                (305.into(), 2),
774                (415.into(), 1),
775                (545.into(), 2),
776                (615.into(), 1)
777            ]
778        );
779
780        let h2 = History::from_sparse_histogram(v.clone().into_iter());
781        let v2: Vec<_> = h2.sparse_histogram().collect();
782        assert_eq!(v, v2);
783    }
784
785    #[test]
786    fn success_observation_mechanism() {
787        let mut h = History::new_empty();
788        h.set_success_history_len(20);
789
790        assert_eq!(h.n_recent_timeouts(), 0);
791        h.add_success(true);
792        assert_eq!(h.n_recent_timeouts(), 0);
793        h.add_success(false);
794        assert_eq!(h.n_recent_timeouts(), 1);
795        for _ in 0..200 {
796            h.add_success(false);
797        }
798        assert_eq!(h.n_recent_timeouts(), 20);
799        h.add_success(true);
800        h.add_success(true);
801        h.add_success(true);
802        assert_eq!(h.n_recent_timeouts(), 20 - 3);
803
804        h.set_success_history_len(10);
805        assert_eq!(h.n_recent_timeouts(), 10 - 3);
806    }
807
808    #[test]
809    fn xm_calculation() {
810        let mut h = History::new_empty();
811        assert_eq!(h.estimate_xm(2), None);
812
813        for n in &[300, 500, 542, 305, 543, 307, 212, 203, 617, 413] {
814            h.add_time(MsecDuration(*n));
815        }
816
817        let v = h.n_most_frequent_bins(2);
818        assert_eq!(&v[..], [(305.into(), 3), (545.into(), 2)]);
819        let est = (305 * 3 + 545 * 2) / 5;
820        assert_eq!(h.estimate_xm(2), Some(est));
821        assert_eq!(est, 401);
822    }
823
824    #[test]
825    fn pareto_estimate() {
826        let mut h = History::new_empty();
827        assert!(h.pareto_estimate(2).is_none());
828
829        for n in &[300, 500, 542, 305, 543, 307, 212, 203, 617, 413] {
830            h.add_time(MsecDuration(*n));
831        }
832        let expected_log_sum: f64 = [401, 500, 542, 401, 543, 401, 401, 401, 617, 413]
833            .iter()
834            .map(|x| f64::from(*x).ln())
835            .sum();
836        let expected_log_xm: f64 = (401_f64).ln() * 10.0;
837        let expected_alpha = 10.0 / (expected_log_sum - expected_log_xm);
838        let expected_inv_alpha = 1.0 / expected_alpha;
839
840        let p = h.pareto_estimate(2).unwrap();
841
842        // We can't do "eq" with floats, so we'll do "very close".
843        assert!((401.0 - p.x_m).abs() < 1.0e-9);
844        assert!((expected_inv_alpha - p.inv_alpha).abs() < 1.0e-9);
845
846        let q60 = p.quantile(0.60);
847        let q99 = p.quantile(0.99);
848
849        assert!((q60 - 451.127) < 0.001);
850        assert!((q99 - 724.841) < 0.001);
851    }
852
853    #[test]
854    fn pareto_estimate_timeout() {
855        let mut est = ParetoTimeoutEstimator::default();
856
857        assert_eq!(
858            est.timeouts(&b3()),
859            (Duration::from_secs(60), Duration::from_secs(60))
860        );
861        // Set the parameters up to mimic the situation in
862        // `pareto_estimate` above.
863        est.p.min_observations = 0;
864        est.p.n_modes_for_xm = 2;
865        assert_eq!(
866            est.timeouts(&b3()),
867            (Duration::from_secs(60), Duration::from_secs(60))
868        );
869
870        for msec in &[300, 500, 542, 305, 543, 307, 212, 203, 617, 413] {
871            let d = Duration::from_millis(*msec);
872            est.note_hop_completed(2, d, true);
873        }
874
875        let t = est.timeouts(&b3());
876        assert_eq!(t.0.as_micros(), 493_169);
877        assert_eq!(t.1.as_micros(), 724_841);
878
879        let t2 = est.timeouts(&b3());
880        assert_eq!(t2, t);
881
882        let t2 = est.timeouts(&Action::BuildCircuit { length: 4 });
883        assert_eq!(t2.0, t.0.mul_f64(10.0 / 6.0));
884        assert_eq!(t2.1, t.1.mul_f64(10.0 / 6.0));
885    }
886
887    #[test]
888    fn pareto_estimate_clear() {
889        let mut est = ParetoTimeoutEstimator::default();
890
891        // Set the parameters up to mimic the situation in
892        // `pareto_estimate` above.
893        let params = NetParameters::from_map(&"cbtmincircs=1 cbtnummodes=2".parse().unwrap());
894        est.update_params(&params);
895
896        assert_eq!(est.timeouts(&b3()).0.as_micros(), 60_000_000);
897        assert!(est.learning_timeouts());
898
899        for msec in &[300, 500, 542, 305, 543, 307, 212, 203, 617, 413] {
900            let d = Duration::from_millis(*msec);
901            est.note_hop_completed(2, d, true);
902        }
903        assert_ne!(est.timeouts(&b3()).0.as_micros(), 60_000_000);
904        assert!(!est.learning_timeouts());
905        assert_eq!(est.history.n_recent_timeouts(), 0);
906
907        // 17 timeouts happen and we're still getting real numbers...
908        for _ in 0..18 {
909            est.note_circ_timeout(2, Duration::from_secs(2000));
910        }
911        assert_ne!(est.timeouts(&b3()).0.as_micros(), 60_000_000);
912
913        // ... but 18 means "reset".
914        est.note_circ_timeout(2, Duration::from_secs(2000));
915        assert_eq!(est.timeouts(&b3()).0.as_micros(), 60_000_000);
916
917        // And if we fail 18 bunch more times, it doubles.
918        for _ in 0..20 {
919            est.note_circ_timeout(2, Duration::from_secs(2000));
920        }
921        assert_eq!(est.timeouts(&b3()).0.as_micros(), 120_000_000);
922    }
923
924    #[test]
925    fn default_params() {
926        let p1 = Params::default();
927        let p2 = Params::from(&tor_netdir::params::NetParameters::default());
928        // discount version of derive(eq)
929        assert_eq!(format!("{:?}", p1), format!("{:?}", p2));
930    }
931
932    #[test]
933    fn state_conversion() {
934        // We have tests elsewhere for converting to and from
935        // histograms, so all we really need to ddo here is make sure
936        // that the histogram conversion happens.
937
938        let mut est = ParetoTimeoutEstimator::default();
939        let mut rng = testing_rng();
940        for _ in 0..1000 {
941            let d = Duration::from_millis(rng.gen_range_checked(10..3_000).unwrap());
942            est.note_hop_completed(2, d, true);
943        }
944
945        let state = est.build_state().unwrap();
946        assert_eq!(state.version, 1);
947        assert!(state.current_timeout.is_some());
948
949        let mut est2 = ParetoTimeoutEstimator::from_state(state);
950        let act = Action::BuildCircuit { length: 3 };
951        // This isn't going to be exact, since we're recording histogram bins
952        // instead of exact timeouts.
953        let ms1 = est.timeouts(&act).0.as_millis() as i32;
954        let ms2 = est2.timeouts(&act).0.as_millis() as i32;
955        assert!((ms1 - ms2).abs() < 50);
956    }
957
958    #[test]
959    fn validate_iterator_choose_multiple() {
960        // The documentation for IteratorRandom::choose_multiple says that it
961        // returns fewer than N elements if the iterators has fewer than N elements.
962        // But rand has changed behavior in the past, so let's make sure this doesn't
963        // change in the future.
964        use rand::seq::IteratorRandom as _;
965        let mut rng = testing_rng();
966        let mut ten_elements = (1..=10).choose_multiple(&mut rng, 100);
967        ten_elements.sort();
968        assert_eq!(ten_elements.len(), 10);
969        assert_eq!(ten_elements, (1..=10).collect::<Vec<_>>());
970    }
971
972    // TODO: add tests from Tor.
973}