1
//! Implement Tor's sort-of-Pareto estimator for circuit build timeouts.
2
//!
3
//! Our build times don't truly follow a
4
//! [Pareto](https://en.wikipedia.org/wiki/Pareto_distribution)
5
//! distribution; instead they seem to be closer to a
6
//! [Fréchet](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distribution)
7
//! distribution.  But those are hard to work with, and we only care
8
//! about the right tail, so we're using Pareto instead.
9
//!
10
//! This estimator also includes several heuristics and kludges to
11
//! try to behave better on unreliable networks.
12
//! For more information on the exact algorithms and their rationales,
13
//! see [`path-spec.txt`](https://gitlab.torproject.org/tpo/core/torspec/-/blob/master/path-spec.txt).
14

            
15
use bounded_vec_deque::BoundedVecDeque;
16
use serde::{Deserialize, Serialize};
17
use std::collections::{BTreeMap, HashMap};
18
use std::time::Duration;
19
use tor_netdir::params::NetParameters;
20

            
21
use super::Action;
22
use tor_persist::JsonValue;
23

            
24
/// Capacity of the circuit build time history: the number of build time
/// observations we keep before the oldest ones are discarded.
const TIME_HISTORY_LEN: usize = 1_000;

/// Default capacity of the success-versus-timeout history.
///
/// (The capacity actually in use can be changed later via `Params`.)
const SUCCESS_HISTORY_DEFAULT_LEN: usize = 20;

/// Width of a single histogram bucket, in milliseconds.
const BUCKET_WIDTH_MSEC: u32 = 10;
33

            
34
/// A circuit build time or timeout duration, measured in milliseconds.
35
///
36
/// Requires that we don't care about tracking timeouts above u32::MAX
37
/// milliseconds (about 49 days).
38
14
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
39
#[serde(transparent)]
40
struct MsecDuration(u32);
41

            
42
impl MsecDuration {
43
    /// Convert a Duration into a MsecDuration, saturating
44
    /// extra-high values to u32::MAX milliseconds.
45
4548
    fn new_saturating(d: &Duration) -> Self {
46
4548
        let msec = std::cmp::min(d.as_millis(), u128::from(u32::MAX)) as u32;
47
4548
        MsecDuration(msec)
48
4548
    }
49
}
50

            
51
/// Module to hold calls to const_assert.
52
///
53
/// This is a separate module so we can change the clippy warnings on it.
54
#[allow(clippy::checked_conversions)]
55
mod assertion {
56
    use static_assertions::const_assert;
57
    // If this assertion is untrue, then we can't safely use u16 fields in
58
    // time_histogram.
59
    const_assert!(super::TIME_HISTORY_LEN <= u16::MAX as usize);
60
}
61

            
62
/// A history of circuit timeout observations, used to estimate our
63
/// likely circuit timeouts.
64
#[derive(Debug, Clone)]
65
struct History {
66
    /// Our most recent observed circuit construction times.
67
    ///
68
    /// For the purpose of this estimator, a circuit counts as
69
    /// "constructed" when a certain "significant" hop (typically the third)
70
    /// is completed.
71
    time_history: BoundedVecDeque<MsecDuration>,
72

            
73
    /// A histogram representation of the values in [`History::time_history`].
74
    ///
75
    /// This histogram is implemented as a sparse map from the center
76
    /// value of each histogram bucket to the number of entries in
77
    /// that bucket.  It is completely derivable from time_history; we
78
    /// keep it separate here for efficiency.
79
    time_histogram: BTreeMap<MsecDuration, u16>,
80

            
81
    /// Our most recent circuit timeout statuses.
82
    ///
83
    /// Each `true` value represents a successfully completed circuit
84
    /// (all hops).  Each `false` value represents a circuit that
85
    /// timed out after having completed at least one hop.
86
    success_history: BoundedVecDeque<bool>,
87
}
88

            
89
impl History {
90
    /// Initialize a new empty `History` with no observations.
91
126
    fn new_empty() -> Self {
92
126
        History {
93
126
            time_history: BoundedVecDeque::new(TIME_HISTORY_LEN),
94
126
            time_histogram: BTreeMap::new(),
95
126
            success_history: BoundedVecDeque::new(SUCCESS_HISTORY_DEFAULT_LEN),
96
126
        }
97
126
    }
98

            
99
    /// Remove all observations from this `History`.
100
6
    fn clear(&mut self) {
101
6
        self.time_history.clear();
102
6
        self.time_histogram.clear();
103
6
        self.success_history.clear();
104
6
    }
105

            
106
    /// Change the number of successes to record in our success
107
    /// history to `n`.
108
8
    fn set_success_history_len(&mut self, n: usize) {
109
8
        if n < self.success_history.len() {
110
2
            self.success_history
111
2
                .drain(0..(self.success_history.len() - n));
112
6
        }
113
8
        self.success_history.set_max_len(n);
114
8
    }
115

            
116
    /// Set the capacity of the build time history to `n` observations.
    ///
    /// Only available in tests.  Note that this adjusts the deque only;
    /// the histogram is left untouched.
    #[cfg(test)]
    fn set_time_history_len(&mut self, n: usize) {
        let times = &mut self.time_history;
        times.set_max_len(n);
    }
124

            
125
    /// Construct a new `History` from an iterator representing a sparse
126
    /// histogram of values.
127
    ///
128
    /// The input must be a sequence of `(D,N)` tuples, where each `D`
129
    /// represents a circuit build duration, and `N` represents the
130
    /// number of observations with that duration.
131
    ///
132
    /// These observations are shuffled into a random order, then
133
    /// added to a new History.
134
110
    fn from_sparse_histogram<I>(iter: I) -> Self
135
110
    where
136
110
        I: Iterator<Item = (MsecDuration, u16)>,
137
110
    {
138
        use rand::seq::{IteratorRandom, SliceRandom};
139
        use std::iter;
140
110
        let mut rng = rand::thread_rng();
141
110

            
142
110
        // We want to build a vector with the elements of the old histogram in
143
110
        // random order, but we want to defend ourselves against bogus inputs
144
110
        // that would take too much RAM.
145
110
        let mut observations = iter
146
110
            .take(TIME_HISTORY_LEN) // limit number of bins
147
700
            .flat_map(|(dur, n)| iter::repeat(dur).take(n as usize))
148
110
            .choose_multiple(&mut rng, TIME_HISTORY_LEN);
149
110
        // choose_multiple doesn't guarantee anything about the order of its output.
150
110
        observations.shuffle(&mut rng);
151
110

            
152
110
        let mut result = History::new_empty();
153
4126
        for obs in observations {
154
4016
            result.add_time(obs);
155
4016
        }
156

            
157
110
        result
158
110
    }
159

            
160
    /// Return an iterator yielding a sparse histogram of the circuit build
161
    /// time values in this `History`.
162
    ///
163
    /// Each histogram entry is a `(D,N)` tuple, where `D` is the
164
    /// center of a histogram bucket, and `N` is the number of
165
    /// observations in that bucket.
166
    ///
167
    /// Buckets with `N=0` are omitted.  Buckets are yielded in order.
168
144
    fn sparse_histogram(&self) -> impl Iterator<Item = (MsecDuration, u16)> + '_ {
169
1999
        self.time_histogram.iter().map(|(d, n)| (*d, *n))
170
144
    }
171

            
172
    /// Return the center value for the bucket containing `time`.
173
8950
    fn bucket_center(time: MsecDuration) -> MsecDuration {
174
8950
        let idx = time.0 / BUCKET_WIDTH_MSEC;
175
8950
        let msec = (idx * BUCKET_WIDTH_MSEC) + (BUCKET_WIDTH_MSEC) / 2;
176
8950
        MsecDuration(msec)
177
8950
    }
178

            
179
    /// Increment the histogram bucket containing `time` by one.
180
8532
    fn inc_bucket(&mut self, time: MsecDuration) {
181
8532
        let center = History::bucket_center(time);
182
8532
        *self.time_histogram.entry(center).or_insert(0) += 1;
183
8532
    }
184

            
185
    /// Decrement the histogram bucket containing `time` by one, removing
186
    /// it if it becomes 0.
187
410
    fn dec_bucket(&mut self, time: MsecDuration) {
188
        use std::collections::btree_map::Entry;
189
410
        let center = History::bucket_center(time);
190
410
        match self.time_histogram.entry(center) {
191
2
            Entry::Vacant(_) => {
192
2
                // this is a bug.
193
2
            }
194
408
            Entry::Occupied(e) if e.get() <= &1 => {
195
4
                e.remove();
196
4
            }
197
404
            Entry::Occupied(mut e) => {
198
404
                *e.get_mut() -= 1;
199
404
            }
200
        }
201
410
    }
202

            
203
    /// Add `time` to our list of circuit build time observations, and
204
    /// adjust the histogram accordingly.
205
8516
    fn add_time(&mut self, time: MsecDuration) {
206
8516
        match self.time_history.push_back(time) {
207
8112
            None => {}
208
404
            Some(removed_time) => {
209
404
                // `removed_time` just fell off the end of the deque:
210
404
                // remove it from the histogram.
211
404
                self.dec_bucket(removed_time);
212
404
            }
213
        }
214
8516
        self.inc_bucket(time);
215
8516
    }
216

            
217
    /// Return the number of observations in our time history.
218
    ///
219
    /// This will always be `<= TIME_HISTORY_LEN`.
220
144
    fn n_times(&self) -> usize {
221
144
        self.time_history.len()
222
144
    }
223

            
224
    /// Record a success (true) or timeout (false) in our record of whether
225
    /// circuits timed out or not.
226
4932
    fn add_success(&mut self, succeeded: bool) {
227
4932
        self.success_history.push_back(succeeded);
228
4932
    }
229

            
230
    /// Return the number of timeouts recorded in our success history.
231
92
    fn n_recent_timeouts(&self) -> usize {
232
1224
        self.success_history.iter().filter(|x| !**x).count()
233
92
    }
234

            
235
    /// Helper: return the `n` most frequent histogram bins.
236
30
    fn n_most_frequent_bins(&self, n: usize) -> Vec<(MsecDuration, u16)> {
237
        use itertools::Itertools;
238
        // we use cmp::Reverse here so that we can use k_smallest as
239
        // if it were "k_largest".
240
        use std::cmp::Reverse;
241

            
242
        // We want the buckets that have the _largest_ counts; we want
243
        // to break ties in favor of the _smallest_ values.  So we
244
        // apply Reverse only to the counts before passing the tuples
245
        // to k_smallest.
246

            
247
30
        self.sparse_histogram()
248
1275
            .map(|(center, count)| (Reverse(count), center))
249
30
            // (k_smallest runs in O(n_bins * lg(n))
250
30
            .k_smallest(n)
251
99
            .map(|(Reverse(count), center)| (center, count))
252
30
            .collect()
253
30
    }
254

            
255
    /// Return an estimator for the `X_m` of our Pareto distribution,
256
    /// by looking at the `n_modes` most frequently filled histogram
257
    /// bins.
258
    ///
259
    /// It is not a true `X_m` value, since there are definitely
260
    /// values less than this, but it seems to work as a decent
261
    /// heuristic.
262
    ///
263
    /// Return `None` if we have no observations.
264
24
    fn estimate_xm(&self, n_modes: usize) -> Option<u32> {
265
24
        // From path-spec:
266
24
        //   Tor clients compute the Xm parameter using the weighted
267
24
        //   average of the midpoints of the 'cbtnummodes' (10)
268
24
        //   most frequently occurring 10ms histogram bins.
269
24

            
270
24
        // The most frequently used bins.
271
24
        let bins = self.n_most_frequent_bins(n_modes);
272
24
        // Total number of observations in these bins.
273
82
        let n_observations: u16 = bins.iter().map(|(_, n)| n).sum();
274
24
        // Sum of all observations in these bins.
275
24
        let total_observations: u64 = bins
276
24
            .iter()
277
82
            .map(|(d, n)| u64::from(d.0 * u32::from(*n)))
278
24
            .sum();
279
24

            
280
24
        if n_observations == 0 {
281
6
            None
282
        } else {
283
18
            Some((total_observations / u64::from(n_observations)) as u32)
284
        }
285
24
    }
286

            
287
    /// Compute a maximum-likelihood pareto distribution based on this
288
    /// history, computing `X_m` based on the `n_modes` most frequent
289
    /// histograms.
290
    ///
291
    /// Return None if we have no observations.
292
20
    fn pareto_estimate(&self, n_modes: usize) -> Option<ParetoDist> {
293
20
        let xm = self.estimate_xm(n_modes)?;
294

            
295
        // From path-spec:
296
        //     alpha = n/(Sum_n{ln(MAX(Xm, x_i))} - n*ln(Xm))
297

            
298
16
        let n = self.time_history.len();
299
16
        let sum_of_log_observations: f64 = self
300
16
            .time_history
301
16
            .iter()
302
10068
            .map(|m| f64::from(std::cmp::max(m.0, xm)).ln())
303
16
            .sum();
304
16
        let sum_of_log_xm = (n as f64) * f64::from(xm).ln();
305
16

            
306
16
        // We're computing 1/alpha here, instead of alpha.  This avoids
307
16
        // division by zero, and has the advantage of being what our
308
16
        // quantile estimator actually needs.
309
16
        let inv_alpha = (sum_of_log_observations - sum_of_log_xm) / (n as f64);
310
16

            
311
16
        Some(ParetoDist {
312
16
            x_m: f64::from(xm),
313
16
            inv_alpha,
314
16
        })
315
20
    }
316
}
317

            
318
/// The parameters of a Pareto distribution that we use to estimate timeouts.
///
/// All values are expressed in milliseconds.
#[derive(Debug)]
struct ParetoDist {
    /// The distribution's lower bound (its `X_m` parameter).
    x_m: f64,
    /// The reciprocal of the distribution's alpha parameter.
    ///
    /// (Storing 1/alpha saves a step in [`ParetoDist::quantile`].)
    inv_alpha: f64,
}
330

            
331
impl ParetoDist {
332
    /// Compute an inverse CDF for this distribution.
333
    ///
334
    /// Given a `q` value between 0 and 1, compute a distribution `v`
335
    /// value such that `q` of the Pareto Distribution is expected to
336
    /// be less than `v`.
337
    ///
338
    /// If `q` is out of bounds, it is clamped to [0.0, 1.0].
339
32
    fn quantile(&self, q: f64) -> f64 {
340
32
        let q = q.clamp(0.0, 1.0);
341
32
        self.x_m / ((1.0 - q).powf(self.inv_alpha))
342
32
    }
343
}
344

            
345
/// Tunable settings that control how a ParetoTimeoutEstimator behaves.
///
/// These are typically derived from a set of consensus parameters.
#[derive(Clone, Debug)]
pub(crate) struct Params {
    /// Whether our estimates are used when deciding on circuit timeouts.
    ///
    /// When this is false, our timeouts are fixed to the default.
    use_estimates: bool,
    /// The number of observations we need before our Pareto estimators
    /// may be used to guess a good set of timeouts.
    min_observations: u16,
    /// The "significant hop": the hop whose completion time we record as
    /// the circuit build time.  (Watch out! This is zero-indexed.)
    significant_hop: u8,
    /// The quantile (in range [0.0,1.0]) of the Pareto distribution at
    /// which a circuit is treated as having "timed out".
    ///
    /// (A "timed out" circuit continues building for measurement
    /// purposes, but can't be used for traffic.)
    timeout_quantile: f64,
    /// The quantile (in range [0.0,1.0]) of the Pareto distribution at
    /// which a circuit is "abandoned".
    ///
    /// (An "abandoned" circuit is stopped entirely, and not included
    /// in measurements.)
    abandon_quantile: f64,
    /// The values that the `timeouts` function falls back on when we
    /// have no observations.
    default_thresholds: (Duration, Duration),
    /// The number of histogram buckets that contribute to the Xm estimate.
    ///
    /// (See [`History::estimate_xm`] for details.)
    n_modes_for_xm: usize,
    /// The number of entries to keep in our success/timeout history.
    success_history_len: usize,
    /// The number of timeouts we tolerate in our success/timeout history
    /// before we assume that the network has changed in a way that makes
    /// our estimates completely wrong.
    reset_after_timeouts: usize,
    /// The smallest base timeout that we will ever infer or return.
    min_timeout: Duration,
}
390

            
391
impl Default for Params {
392
116
    fn default() -> Self {
393
116
        Params {
394
116
            use_estimates: true,
395
116
            min_observations: 100,
396
116
            significant_hop: 2,
397
116
            timeout_quantile: 0.80,
398
116
            abandon_quantile: 0.99,
399
116
            default_thresholds: (Duration::from_secs(60), Duration::from_secs(60)),
400
116
            n_modes_for_xm: 10,
401
116
            success_history_len: SUCCESS_HISTORY_DEFAULT_LEN,
402
116
            reset_after_timeouts: 18,
403
116
            min_timeout: Duration::from_millis(10),
404
116
        }
405
116
    }
406
}
407

            
408
impl From<&NetParameters> for Params {
409
6
    fn from(p: &NetParameters) -> Params {
410
6
        // Because of the underlying bounds, the "unwrap_or_else"
411
6
        // conversions here should be impossible, and the "as"
412
6
        // conversions should always be in-range.
413
6

            
414
6
        let timeout = p
415
6
            .cbt_initial_timeout
416
6
            .try_into()
417
6
            .unwrap_or_else(|_| Duration::from_secs(60));
418
6
        let learning_disabled: bool = p.cbt_learning_disabled.into();
419
6
        Params {
420
6
            use_estimates: !learning_disabled,
421
6
            min_observations: p.cbt_min_circs_for_estimate.get() as u16,
422
6
            significant_hop: 2,
423
6
            timeout_quantile: p.cbt_timeout_quantile.as_fraction(),
424
6
            abandon_quantile: p.cbt_abandon_quantile.as_fraction(),
425
6
            default_thresholds: (timeout, timeout),
426
6
            n_modes_for_xm: p.cbt_num_xm_modes.get() as usize,
427
6
            success_history_len: p.cbt_success_count.get() as usize,
428
6
            reset_after_timeouts: p.cbt_max_timeouts.get() as usize,
429
6
            min_timeout: p
430
6
                .cbt_min_timeout
431
6
                .try_into()
432
6
                .unwrap_or_else(|_| Duration::from_millis(10)),
433
6
        }
434
6
    }
435
}
436

            
437
/// Tor's default circuit build timeout estimator.
438
///
439
/// This object records a set of observed circuit build times, and
440
/// uses it to determine good values for how long we should allow
441
/// circuits to build.
442
///
443
/// For full details of the algorithms used, see
444
/// [`path-spec.txt`](https://gitlab.torproject.org/tpo/core/torspec/-/blob/master/path-spec.txt).
445
pub(crate) struct ParetoTimeoutEstimator {
446
    /// Our observations for circuit build times and success/failure
447
    /// history.
448
    history: History,
449

            
450
    /// Our most recent timeout estimate, if we have one that is
451
    /// up-to-date.
452
    ///
453
    /// (We reset this to None whenever we get a new observation.)
454
    timeouts: Option<(Duration, Duration)>,
455

            
456
    /// The timeouts that we use when we do not have sufficient observations
457
    /// to conclude anything about our circuit build times.
458
    ///
459
    /// These start out as `p.default_thresholds`, but can be adjusted
460
    /// depending on how many timeouts we've been seeing.
461
    fallback_timeouts: (Duration, Duration),
462

            
463
    /// A set of parameters to use in computing circuit build timeout
464
    /// estimates.
465
    p: Params,
466
}
467

            
468
impl Default for ParetoTimeoutEstimator {
469
6
    fn default() -> Self {
470
6
        Self::from_history(History::new_empty())
471
6
    }
472
}
473

            
474
/// An object used to serialize our timeout history for persistent state.
475
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
476
#[serde(default)]
477
pub(crate) struct ParetoTimeoutState {
478
    /// A version field used to help encoding and decoding.
479
    #[allow(dead_code)]
480
    version: usize,
481
    /// A record of observed timeouts, as returned by `sparse_histogram()`.
482
    histogram: Vec<(MsecDuration, u16)>,
483
    /// The current timeout estimate: kept for reference.
484
    current_timeout: Option<MsecDuration>,
485

            
486
    /// Fields from the state file that was used to make this `ParetoTimeoutState` that
487
    /// this version of Arti doesn't understand.
488
    #[serde(flatten)]
489
    unknown_fields: HashMap<String, JsonValue>,
490
}
491

            
492
impl ParetoTimeoutState {
493
    /// Return the latest base timeout estimate, as recorded in this state.
494
152
    pub(crate) fn latest_estimate(&self) -> Option<Duration> {
495
152
        self.current_timeout
496
154
            .map(|m| Duration::from_millis(m.0.into()))
497
152
    }
498
}
499

            
500
impl ParetoTimeoutEstimator {
501
    /// Construct a new ParetoTimeoutEstimator from the provided history
502
    /// object.
503
114
    fn from_history(history: History) -> Self {
504
114
        let p = Params::default();
505
114
        ParetoTimeoutEstimator {
506
114
            history,
507
114
            timeouts: None,
508
114
            fallback_timeouts: p.default_thresholds,
509
114
            p,
510
114
        }
511
114
    }
512

            
513
    /// Create a new ParetoTimeoutEstimator based on a loaded
514
    /// ParetoTimeoutState.
515
108
    pub(crate) fn from_state(state: ParetoTimeoutState) -> Self {
516
108
        let history = History::from_sparse_histogram(state.histogram.into_iter());
517
108
        Self::from_history(history)
518
108
    }
519

            
520
    /// Compute an unscaled basic pair of timeouts for a circuit of
521
    /// the "normal" length.
522
    ///
523
    /// Return a cached value if we have no observations since the
524
    /// last time this function was called.
525
146
    fn base_timeouts(&mut self) -> (Duration, Duration) {
526
146
        if let Some(x) = self.timeouts {
527
            // Great; we have a cached value.
528
14
            return x;
529
132
        }
530
132

            
531
132
        if self.history.n_times() < self.p.min_observations as usize {
532
            // We don't have enough values to estimate.
533
116
            return self.fallback_timeouts;
534
16
        }
535

            
536
        // Here we're going to compute the timeouts, cache them, and
537
        // return them.
538
16
        let dist = match self.history.pareto_estimate(self.p.n_modes_for_xm) {
539
14
            Some(dist) => dist,
540
            None => {
541
2
                return self.fallback_timeouts;
542
            }
543
        };
544
14
        let timeout_threshold = dist.quantile(self.p.timeout_quantile);
545
14
        let abandon_threshold = dist
546
14
            .quantile(self.p.abandon_quantile)
547
14
            .max(timeout_threshold);
548
14

            
549
14
        let timeouts = (
550
14
            Duration::from_secs_f64(timeout_threshold / 1000.0).max(self.p.min_timeout),
551
14
            Duration::from_secs_f64(abandon_threshold / 1000.0).max(self.p.min_timeout),
552
14
        );
553
14
        self.timeouts = Some(timeouts);
554
14

            
555
14
        timeouts
556
146
    }
557
}
558

            
559
impl super::TimeoutEstimator for ParetoTimeoutEstimator {
560
4
    fn update_params(&mut self, p: &NetParameters) {
561
4
        let parameters = p.into();
562
4
        self.p = parameters;
563
4
        let new_success_len = self.p.success_history_len;
564
4
        self.history.set_success_history_len(new_success_len);
565
4
    }
566

            
567
4440
    fn note_hop_completed(&mut self, hop: u8, delay: Duration, is_last: bool) {
568
4440
        if hop == self.p.significant_hop {
569
4440
            let time = MsecDuration::new_saturating(&delay);
570
4440
            self.history.add_time(time);
571
4440
            self.timeouts.take();
572
4440
        }
573
4440
        if is_last {
574
4440
            self.history.add_success(true);
575
4440
        }
576
4440
    }
577

            
578
78
    fn note_circ_timeout(&mut self, hop: u8, delay: Duration) {
579
        // Only record this timeout if we have seen some network activity since
580
        // we launched the circuit.
581
78
        let have_seen_recent_activity =
582
78
            if let Some(last_traffic) = tor_proto::time_since_last_incoming_traffic() {
583
                last_traffic < delay
584
            } else {
585
                // TODO: Is this the correct behavior in this case?
586
78
                true
587
            };
588

            
589
78
        tracing::trace!(%hop, ?delay, %have_seen_recent_activity, "Circuit timeout");
590

            
591
78
        if hop > 0 && have_seen_recent_activity {
592
78
            self.history.add_success(false);
593
78
            if self.history.n_recent_timeouts() > self.p.reset_after_timeouts {
594
4
                let base_timeouts = self.base_timeouts();
595
4
                self.history.clear();
596
4
                self.timeouts.take();
597
4
                // If we already had a timeout that was at least the
598
4
                // length of our fallback timeouts, we should double
599
4
                // those fallback timeouts.
600
4
                if base_timeouts.0 >= self.fallback_timeouts.0 {
601
2
                    self.fallback_timeouts.0 *= 2;
602
2
                    self.fallback_timeouts.1 *= 2;
603
2
                }
604
74
            }
605
        }
606
78
    }
607

            
608
34
    fn timeouts(&mut self, action: &Action) -> (Duration, Duration) {
609
34
        let (base_t, base_a) = if self.p.use_estimates {
610
34
            self.base_timeouts()
611
        } else {
612
            // If we aren't using this estimator, then just return the
613
            // default thresholds from our parameters.
614
            return self.p.default_thresholds;
615
        };
616

            
617
34
        let reference_action = Action::BuildCircuit {
618
34
            length: self.p.significant_hop as usize + 1,
619
34
        };
620
34
        debug_assert!(reference_action.timeout_scale() > 0);
621

            
622
34
        let multiplier =
623
34
            (action.timeout_scale() as f64) / (reference_action.timeout_scale() as f64);
624

            
625
        // TODO-SPEC The spec doesn't define any of this
626
        // action-based-multiplier stuff.  Tor doesn't multiply the
627
        // abandon timeout.
628
        use super::mul_duration_f64_saturating as mul;
629
34
        (mul(base_t, multiplier), mul(base_a, multiplier))
630
34
    }
631

            
632
8
    fn learning_timeouts(&self) -> bool {
633
8
        self.p.use_estimates && self.history.n_times() < self.p.min_observations.into()
634
8
    }
635

            
636
108
    fn build_state(&mut self) -> Option<ParetoTimeoutState> {
637
108
        let cur_timeout = MsecDuration::new_saturating(&self.base_timeouts().0);
638
108
        Some(ParetoTimeoutState {
639
108
            version: 1,
640
108
            histogram: self.history.sparse_histogram().collect(),
641
108
            current_timeout: Some(cur_timeout),
642
108
            unknown_fields: Default::default(),
643
108
        })
644
108
    }
645
}
646

            
647
#[cfg(test)]
648
mod test {
649
    // @@ begin test lint list maintained by maint/add_warning @@
650
    #![allow(clippy::bool_assert_comparison)]
651
    #![allow(clippy::clone_on_copy)]
652
    #![allow(clippy::dbg_macro)]
653
    #![allow(clippy::mixed_attributes_style)]
654
    #![allow(clippy::print_stderr)]
655
    #![allow(clippy::print_stdout)]
656
    #![allow(clippy::single_char_pattern)]
657
    #![allow(clippy::unwrap_used)]
658
    #![allow(clippy::unchecked_duration_subtraction)]
659
    #![allow(clippy::useless_vec)]
660
    #![allow(clippy::needless_pass_by_value)]
661
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
662
    use super::*;
663
    use crate::timeouts::TimeoutEstimator;
664
    use tor_basic_utils::test_rng::testing_rng;
665
    use tor_basic_utils::RngExt as _;
666

            
667
    /// Return an action to build a 3-hop circuit.
668
    fn b3() -> Action {
669
        Action::BuildCircuit { length: 3 }
670
    }
671

            
672
    impl From<u32> for MsecDuration {
673
        fn from(v: u32) -> Self {
674
            Self(v)
675
        }
676
    }
677

            
678
    #[test]
679
    fn ms_partial_cmp() {
680
        #![allow(clippy::eq_op)]
681
        let myriad: MsecDuration = 10_000.into();
682
        let lakh: MsecDuration = 100_000.into();
683
        let crore: MsecDuration = 10_000_000.into();
684

            
685
        assert!(myriad < lakh);
686
        assert!(myriad == myriad);
687
        assert!(crore > lakh);
688
        assert!(crore >= crore);
689
        assert!(crore <= crore);
690
    }
691

            
692
    #[test]
693
    fn history_lowlev() {
694
        assert_eq!(History::bucket_center(1.into()), 5.into());
695
        assert_eq!(History::bucket_center(903.into()), 905.into());
696
        assert_eq!(History::bucket_center(0.into()), 5.into());
697
        assert_eq!(History::bucket_center(u32::MAX.into()), 4294967295.into());
698

            
699
        let mut h = History::new_empty();
700
        h.inc_bucket(7.into());
701
        h.inc_bucket(8.into());
702
        h.inc_bucket(9.into());
703
        h.inc_bucket(10.into());
704
        h.inc_bucket(11.into());
705
        h.inc_bucket(12.into());
706
        h.inc_bucket(13.into());
707
        h.inc_bucket(299.into());
708
        assert_eq!(h.time_histogram.get(&5.into()), Some(&3));
709
        assert_eq!(h.time_histogram.get(&15.into()), Some(&4));
710
        assert_eq!(h.time_histogram.get(&25.into()), None);
711
        assert_eq!(h.time_histogram.get(&295.into()), Some(&1));
712

            
713
        h.dec_bucket(299.into());
714
        h.dec_bucket(24.into());
715
        h.dec_bucket(12.into());
716

            
717
        assert_eq!(h.time_histogram.get(&15.into()), Some(&3));
718
        assert_eq!(h.time_histogram.get(&25.into()), None);
719
        assert_eq!(h.time_histogram.get(&295.into()), None);
720

            
721
        h.add_success(true);
722
        h.add_success(false);
723
        assert_eq!(h.success_history.len(), 2);
724

            
725
        h.clear();
726
        assert_eq!(h.time_histogram.len(), 0);
727
        assert_eq!(h.time_history.len(), 0);
728
        assert_eq!(h.success_history.len(), 0);
729
    }
730

            
731
    /// Check that `History` keeps track of build-time observations: the
    /// per-bucket counts it reports, and what happens to the oldest
    /// observations once the (shortened) history is full.
    #[test]
    fn time_observation_management() {
        let mut hist = History::new_empty();
        // Shrink the history so that a few observations suffice to overflow it.
        hist.set_time_history_len(8);

        for msec in &[300, 500, 542, 305, 543, 307] {
            hist.add_time((*msec).into());
        }
        assert_eq!(hist.n_times(), 6);

        // For this data, the most frequent bins come back highest-count first...
        let top_ten = hist.n_most_frequent_bins(10);
        assert_eq!(
            &top_ten[..],
            [(305.into(), 3), (545.into(), 2), (505.into(), 1)]
        );
        let top_two = hist.n_most_frequent_bins(2);
        assert_eq!(&top_two[..], [(305.into(), 3), (545.into(), 2)]);

        // ... whereas the sparse histogram lists them in ascending bucket order.
        let sparse_before: Vec<_> = hist.sparse_histogram().collect();
        assert_eq!(
            &sparse_before[..],
            [(305.into(), 3), (505.into(), 1), (545.into(), 2)]
        );

        // Two more observations bring us up to the limit of eight entries...
        for msec in &[212, 203] {
            hist.add_time((*msec).into());
        }
        // ... so these two replace the first couple of older elements.
        for msec in &[617, 413] {
            hist.add_time((*msec).into());
        }
        assert_eq!(hist.n_times(), 8);

        let sparse_after: Vec<_> = hist.sparse_histogram().collect();
        assert_eq!(
            &sparse_after[..],
            [
                (205.into(), 1),
                (215.into(), 1),
                (305.into(), 2),
                (415.into(), 1),
                (545.into(), 2),
                (615.into(), 1)
            ]
        );

        // Rebuilding a history from a sparse histogram must reproduce that
        // same histogram.
        let rebuilt = History::from_sparse_histogram(sparse_after.clone().into_iter());
        let sparse_rebuilt: Vec<_> = rebuilt.sparse_histogram().collect();
        assert_eq!(sparse_after, sparse_rebuilt);
    }
777

            
778
    /// Check that `History` counts recent timeouts within a bounded window
    /// of success/failure observations, including after the window shrinks.
    #[test]
    fn success_observation_mechanism() {
        let mut hist = History::new_empty();
        hist.set_success_history_len(20);

        // Nothing recorded yet, and a success doesn't count as a timeout.
        assert_eq!(hist.n_recent_timeouts(), 0);
        hist.add_success(true);
        assert_eq!(hist.n_recent_timeouts(), 0);

        // A single failure is counted...
        hist.add_success(false);
        assert_eq!(hist.n_recent_timeouts(), 1);

        // ... but no matter how many we add, the count is capped at the
        // configured history length.
        for outcome in std::iter::repeat(false).take(200) {
            hist.add_success(outcome);
        }
        assert_eq!(hist.n_recent_timeouts(), 20);

        // Each new success pushes one failure out of the window.
        let n_successes = 3;
        for _ in 0..n_successes {
            hist.add_success(true);
        }
        assert_eq!(hist.n_recent_timeouts(), 20 - n_successes);

        // Shrinking the window keeps the three successes in it.
        hist.set_success_history_len(10);
        assert_eq!(hist.n_recent_timeouts(), 10 - n_successes);
    }
800

            
801
    /// Check `History::estimate_xm` against a hand-computed weighted mean
    /// of the two most frequent histogram bins.
    #[test]
    fn xm_calculation() {
        let mut hist = History::new_empty();
        // With no observations there is nothing to estimate from.
        assert!(hist.estimate_xm(2).is_none());

        for msec in [300, 500, 542, 305, 543, 307, 212, 203, 617, 413]
            .iter()
            .copied()
        {
            hist.add_time(MsecDuration(msec));
        }

        let top_two = hist.n_most_frequent_bins(2);
        assert_eq!(&top_two[..], [(305.into(), 3), (545.into(), 2)]);

        // The expected value is the mean of those two bin centers, weighted
        // by how many observations fell into each.
        let weighted_sum = 305 * 3 + 545 * 2;
        let n_observations = 5;
        let expected_xm = weighted_sum / n_observations;
        assert_eq!(hist.estimate_xm(2), Some(expected_xm));
        assert_eq!(expected_xm, 401);
    }
816

            
817
    /// Check the Pareto parameters and quantiles that `History` derives
    /// from a small, fixed set of build times.
    #[test]
    fn pareto_estimate() {
        let mut h = History::new_empty();
        // With no observations, there is nothing to fit.
        assert!(h.pareto_estimate(2).is_none());

        for n in &[300, 500, 542, 305, 543, 307, 212, 203, 617, 413] {
            h.add_time(MsecDuration(*n));
        }
        // These are the observations above, with every value below the
        // expected x_m of 401 replaced by 401.
        let expected_log_sum: f64 = [401, 500, 542, 401, 543, 401, 401, 401, 617, 413]
            .iter()
            .map(|x| f64::from(*x).ln())
            .sum();
        let expected_log_xm: f64 = (401_f64).ln() * 10.0;
        let expected_alpha = 10.0 / (expected_log_sum - expected_log_xm);
        let expected_inv_alpha = 1.0 / expected_alpha;

        let p = h.pareto_estimate(2).unwrap();

        // We can't do "eq" with floats, so we'll do "very close".
        assert!((401.0 - p.x_m).abs() < 1.0e-9);
        assert!((expected_inv_alpha - p.inv_alpha).abs() < 1.0e-9);

        let q60 = p.quantile(0.60);
        let q99 = p.quantile(0.99);

        // Compare the *absolute* difference: without `abs()`, any quantile
        // far below the expected value would also satisfy the check.
        assert!((q60 - 451.127).abs() < 0.001);
        assert!((q99 - 724.841).abs() < 0.001);
    }
845

            
846
    /// Check the timeouts reported by `ParetoTimeoutEstimator` before and
    /// after it has observations, and how they scale for a longer circuit.
    #[test]
    fn pareto_estimate_timeout() {
        let mut estimator = ParetoTimeoutEstimator::default();
        let fallback = Duration::from_secs(60);

        // A fresh estimator reports 60 seconds for both timeouts.
        assert_eq!(estimator.timeouts(&b3()), (fallback, fallback));

        // Set the parameters up to mimic the situation in
        // `pareto_estimate` above.  With no observations recorded, the
        // reported timeouts stay the same.
        estimator.p.min_observations = 0;
        estimator.p.n_modes_for_xm = 2;
        assert_eq!(estimator.timeouts(&b3()), (fallback, fallback));

        for msec in [300, 500, 542, 305, 543, 307, 212, 203, 617, 413]
            .iter()
            .copied()
        {
            estimator.note_hop_completed(2, Duration::from_millis(msec), true);
        }

        let computed = estimator.timeouts(&b3());
        assert_eq!(computed.0.as_micros(), 493_169);
        assert_eq!(computed.1.as_micros(), 724_841);

        // Asking a second time gives the same answer.
        let repeated = estimator.timeouts(&b3());
        assert_eq!(repeated, computed);

        // A four-hop circuit gets both timeouts scaled by 10/6.
        let scale = 10.0 / 6.0;
        let four_hop = estimator.timeouts(&Action::BuildCircuit { length: 4 });
        assert_eq!(four_hop.0, computed.0.mul_f64(scale));
        assert_eq!(four_hop.1, computed.1.mul_f64(scale));
    }
879

            
880
    /// Check that a run of circuit timeouts eventually makes the estimator
    /// go back to its 60-second timeout, and that further timeouts double
    /// that value.
    #[test]
    fn pareto_estimate_clear() {
        let mut estimator = ParetoTimeoutEstimator::default();
        let fallback_usec = 60_000_000;
        let too_long = Duration::from_secs(2000);

        // Set the parameters up to mimic the situation in
        // `pareto_estimate` above.
        let overrides = "cbtmincircs=1 cbtnummodes=2";
        let params = NetParameters::from_map(&overrides.parse().unwrap());
        estimator.update_params(&params);

        assert_eq!(estimator.timeouts(&b3()).0.as_micros(), fallback_usec);
        assert!(estimator.learning_timeouts());

        for msec in [300, 500, 542, 305, 543, 307, 212, 203, 617, 413]
            .iter()
            .copied()
        {
            estimator.note_hop_completed(2, Duration::from_millis(msec), true);
        }
        assert_ne!(estimator.timeouts(&b3()).0.as_micros(), fallback_usec);
        assert!(!estimator.learning_timeouts());
        assert_eq!(estimator.history.n_recent_timeouts(), 0);

        // 18 timeouts happen and we're still getting real numbers...
        for _ in 0..18 {
            estimator.note_circ_timeout(2, too_long);
        }
        assert_ne!(estimator.timeouts(&b3()).0.as_micros(), fallback_usec);

        // ... but the 19th means "reset".
        estimator.note_circ_timeout(2, too_long);
        assert_eq!(estimator.timeouts(&b3()).0.as_micros(), fallback_usec);

        // And if we fail a bunch more times (20 here), it doubles.
        for _ in 0..20 {
            estimator.note_circ_timeout(2, too_long);
        }
        assert_eq!(estimator.timeouts(&b3()).0.as_micros(), 120_000_000);
    }
916

            
917
    /// Check that `Params::default()` agrees with the `Params` derived
    /// from default network parameters.
    #[test]
    fn default_params() {
        let net_defaults = tor_netdir::params::NetParameters::default();

        // Comparing the `Debug` renderings is a discount version of
        // derive(eq).
        let built_in = format!("{:?}", Params::default());
        let from_net = format!("{:?}", Params::from(&net_defaults));
        assert_eq!(built_in, from_net);
    }
924

            
925
    /// Check that an estimator rebuilt from its saved state reports nearly
    /// the same timeout as the estimator the state was taken from.
    #[test]
    fn state_conversion() {
        // We have tests elsewhere for converting to and from
        // histograms, so all we really need to do here is make sure
        // that the histogram conversion happens.

        let mut original = ParetoTimeoutEstimator::default();
        let mut rng = testing_rng();
        for _ in 0..1000 {
            let msec = rng.gen_range_checked(10..3_000).unwrap();
            original.note_hop_completed(2, Duration::from_millis(msec), true);
        }

        let state = original.build_state().unwrap();
        assert_eq!(state.version, 1);
        assert!(state.current_timeout.is_some());

        let mut restored = ParetoTimeoutEstimator::from_state(state);
        let action = Action::BuildCircuit { length: 3 };

        // This isn't going to be exact, since we're recording histogram bins
        // instead of exact timeouts.
        let original_ms = original.timeouts(&action).0.as_millis() as i32;
        let restored_ms = restored.timeouts(&action).0.as_millis() as i32;
        let drift = (original_ms - restored_ms).abs();
        assert!(drift < 50);
    }
950

            
951
    // TODO: add tests from Tor.
952
}