1
//! Functions for applying the correct weights to relays when choosing
2
//! a relay at random.
3
//!
4
//! The weight to use when picking a relay depends on several factors:
5
//!
6
//! - The relay's *apparent bandwidth*.  (This is ideally measured by a set of
7
//!   bandwidth authorities, but if no bandwidth authorities are running (as on
8
//!   a test network), we might fall back either to relays' self-declared
9
//!   values, or we might treat all relays as having equal bandwidth.)
10
//! - The role that we're selecting a relay to play.  (See [`WeightRole`]).
11
//! - The flags that a relay has in the consensus, and their scarcity.  If a
12
//!   relay provides particularly scarce functionality, we might choose not to
13
//!   use it for other roles, or to use it less commonly for them.
14

            
15
use crate::params::NetParameters;
16
use crate::ConsensusRelays;
17
use bitflags::bitflags;
18
use tor_netdoc::doc::netstatus::{self, MdConsensus, MdConsensusRouterStatus, NetParams};
19

            
20
/// Helper: Calculate the function we should use to find initial relay
21
/// bandwidths.
22
3290
fn pick_bandwidth_fn<'a, I>(mut weights: I) -> BandwidthFn
23
3290
where
24
3290
    I: Clone + Iterator<Item = &'a netstatus::RelayWeight>,
25
3290
{
26
3850
    let has_measured = weights.clone().any(|w| w.is_measured());
27
3477
    let has_nonzero = weights.clone().any(|w| w.is_nonzero());
28
3852
    let has_nonzero_measured = weights.any(|w| w.is_measured() && w.is_nonzero());
29
3290

            
30
3290
    if !has_nonzero {
31
        // If every value is zero, we should just pretend everything has
32
        // bandwidth == 1.
33
41
        BandwidthFn::Uniform
34
3249
    } else if !has_measured {
35
        // If there are no measured values, then we can look at unmeasured
36
        // weights.
37
97
        BandwidthFn::IncludeUnmeasured
38
3152
    } else if has_nonzero_measured {
39
        // Otherwise, there are measured values; we should look at those only, if
40
        // any of them is nonzero.
41
3150
        BandwidthFn::MeasuredOnly
42
    } else {
43
        // This is a bit of an ugly case: We have measured values, but they're
44
        // all zero.  If this happens, the bandwidth authorities exist but they
45
        // very confused: we should fall back to uniform weighting.
46
2
        BandwidthFn::Uniform
47
    }
48
3290
}
49

            
50
/// Internal: how should we find the base bandwidth of each relay?  This
/// value is global over a whole directory, and depends on the bandwidth
/// weights in the consensus.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum BandwidthFn {
    /// Weight every relay as 1.
    ///
    /// Used when no relay in the consensus has a nonzero weight, and also
    /// when measured weights exist but all of them are zero.
    Uniform,
    /// There are no measured weights in the consensus: use each relay's
    /// listed weight (unmeasured or measured) as its bandwidth.
    IncludeUnmeasured,
    /// There are nonzero measured weights in the consensus: use only those,
    /// treating every unmeasured relay as having zero bandwidth.
    MeasuredOnly,
}
64

            
65
impl BandwidthFn {
66
    /// Apply this function to the measured or unmeasured bandwidth
67
    /// of a single relay.
68
7616755
    fn apply(&self, w: &netstatus::RelayWeight) -> u32 {
69
7616755
        use netstatus::RelayWeight::*;
70
7616755
        use BandwidthFn::*;
71
7616755
        match (self, w) {
72
2002
            (Uniform, _) => 1,
73
8236
            (IncludeUnmeasured, Unmeasured(u)) => *u,
74
2
            (IncludeUnmeasured, Measured(m)) => *m,
75
124
            (MeasuredOnly, Unmeasured(_)) => 0,
76
7606391
            (MeasuredOnly, Measured(m)) => *m,
77
            (_, _) => 0,
78
        }
79
7616755
    }
80
}
81

            
82
/// Possible ways to weight relays when selecting them at random.
///
/// Relays are weighted by a function of their bandwidth that
/// depends on how scarce that "kind" of bandwidth is.  For
/// example, if Exit bandwidth is rare, then Exits should be
/// less likely to get chosen for the middle hop of a path.
#[derive(Clone, Debug, Copy)]
#[non_exhaustive]
pub enum WeightRole {
    /// Selecting a relay to use as a guard.
    Guard,
    /// Selecting a relay to use as a middle relay in a circuit.
    Middle,
    /// Selecting a relay to use to deliver traffic to the internet.
    Exit,
    /// Selecting a relay for a one-hop BEGIN_DIR directory request.
    BeginDir,
    /// Selecting a relay with no additional weight beyond its bandwidth.
    Unweighted,
    /// Selecting a relay for use as a hidden service introduction point.
    HsIntro,
    // Note: There is no `HsRend` role, since in practice when we want to pick a
    // rendezvous point we use a pre-built circuit from our circuit-pool, the
    // last hop of which was selected with the `Middle` weight.  Fortunately,
    // the weighting rules for picking rendezvous points are the same as for
    // picking middle relays.
}
109

            
110
/// Description for how to weight a single kind of relay for each WeightRole.
#[derive(Clone, Debug, Copy)]
struct RelayWeight {
    /// How to weight this kind of relay when picking a guard relay.
    as_guard: u32,
    /// How to weight this kind of relay when picking a middle relay.
    as_middle: u32,
    /// How to weight this kind of relay when picking an exit relay.
    as_exit: u32,
    /// How to weight this kind of relay when picking a one-hop BEGIN_DIR.
    as_dir: u32,
}
122

            
123
impl std::ops::Mul<u32> for RelayWeight {
124
    type Output = Self;
125
13128
    fn mul(self, rhs: u32) -> Self {
126
13128
        RelayWeight {
127
13128
            as_guard: self.as_guard * rhs,
128
13128
            as_middle: self.as_middle * rhs,
129
13128
            as_exit: self.as_exit * rhs,
130
13128
            as_dir: self.as_dir * rhs,
131
13128
        }
132
13128
    }
133
}
134
impl std::ops::Div<u32> for RelayWeight {
135
    type Output = Self;
136
13128
    fn div(self, rhs: u32) -> Self {
137
13128
        RelayWeight {
138
13128
            as_guard: self.as_guard / rhs,
139
13128
            as_middle: self.as_middle / rhs,
140
13128
            as_exit: self.as_exit / rhs,
141
13128
            as_dir: self.as_dir / rhs,
142
13128
        }
143
13128
    }
144
}
145

            
146
impl RelayWeight {
147
    /// Return the largest weight that we give for this kind of relay.
148
    // The unwrap() is safe because array is nonempty.
149
    #[allow(clippy::unwrap_used)]
150
26256
    fn max_weight(&self) -> u32 {
151
26256
        [self.as_guard, self.as_middle, self.as_exit, self.as_dir]
152
26256
            .iter()
153
26256
            .max()
154
26256
            .copied()
155
26256
            .unwrap()
156
26256
    }
157
    /// Return the weight we should give this kind of relay's
158
    /// bandwidth for a given role.
159
7490211
    fn for_role(&self, role: WeightRole) -> u32 {
160
7490211
        match role {
161
2300152
            WeightRole::Guard => self.as_guard,
162
3076708
            WeightRole::Middle => self.as_middle,
163
1099633
            WeightRole::Exit => self.as_exit,
164
867276
            WeightRole::BeginDir => self.as_dir,
165
19906
            WeightRole::HsIntro => self.as_middle, // TODO SPEC is this right?
166
126536
            WeightRole::Unweighted => 1,
167
        }
168
7490211
    }
169
}
170

            
171
bitflags! {
    /// A relay's "kind", for the purposes of choosing relays by weight.
    ///
    /// A kind records whether a relay has each of three consensus flags:
    /// Guard, Exit, and V2Dir.  That gives 2^3 = 8 possible kinds.  The
    /// kind's raw bits also serve as an index into an 8-entry weight table.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    struct WeightKind: u8 {
        /// Set for relays that have the Guard flag.
        const GUARD = 0b001;
        /// Set for relays that have the Exit flag.
        const EXIT = 0b010;
        /// Set for relays that have the V2Dir flag.
        const DIR = 0b100;
    }
}
186

            
187
impl WeightKind {
188
    /// Return the appropriate WeightKind for a relay.
189
7490209
    fn for_rs(rs: &MdConsensusRouterStatus) -> Self {
190
7490209
        let mut r = WeightKind::empty();
191
7490209
        if rs.is_flagged_guard() {
192
4764283
            r |= WeightKind::GUARD;
193
4764283
        }
194
7490209
        if rs.is_flagged_exit() {
195
4296961
            r |= WeightKind::EXIT;
196
4296961
        }
197
7490209
        if rs.is_flagged_v2dir() {
198
7489803
            r |= WeightKind::DIR;
199
7489803
        }
200
7490209
        r
201
7490209
    }
202
    /// Return the index to use for this kind of a relay within a WeightSet.
203
7490211
    fn idx(self) -> usize {
204
7490211
        self.bits() as usize
205
7490211
    }
206
}
207

            
208
/// Information derived from a consensus to use when picking relays by
209
/// weighted bandwidth.
210
56
#[derive(Debug, Clone)]
211
pub(crate) struct WeightSet {
212
    /// How to find the bandwidth to use when picking a relay by weighted
213
    /// bandwidth.
214
    ///
215
    /// (This tells us us whether to count unmeasured relays, whether
216
    /// to look at bandwidths at all, etc.)
217
    bandwidth_fn: BandwidthFn,
218
    /// Number of bits that we need to right-shift our weighted products
219
    /// so that their sum won't overflow u64::MAX.
220
    shift: u8,
221
    /// A set of RelayWeight values, indexed by [`WeightKind::idx`], used
222
    /// to weight different kinds of relays.
223
    w: [RelayWeight; 8],
224
}
225

            
226
impl WeightSet {
227
    /// Find the actual 64-bit weight to use for a given routerstatus when
228
    /// considering it for a given role.
229
    ///
230
    /// NOTE: This function _does not_ consider whether the relay in question
231
    /// actually matches the given role.  For example, if `role` is Guard
232
    /// we don't check whether or not `rs` actually has the Guard flag.
233
7490199
    pub(crate) fn weight_rs_for_role(&self, rs: &MdConsensusRouterStatus, role: WeightRole) -> u64 {
234
7490199
        self.weight_bw_for_role(WeightKind::for_rs(rs), rs.weight(), role)
235
7490199
    }
236

            
237
    /// Find the 64-bit weight to report for a relay of `kind` whose weight in
238
    /// the consensus is `relay_weight` when using it for `role`.
239
7490211
    fn weight_bw_for_role(
240
7490211
        &self,
241
7490211
        kind: WeightKind,
242
7490211
        relay_weight: &netstatus::RelayWeight,
243
7490211
        role: WeightRole,
244
7490211
    ) -> u64 {
245
7490211
        let ws = &self.w[kind.idx()];
246
7490211

            
247
7490211
        let router_bw = self.bandwidth_fn.apply(relay_weight);
248
7490211
        // Note a subtlety here: we multiply the two values _before_
249
7490211
        // we shift, to improve accuracy.  We know that this will be
250
7490211
        // safe, since the inputs are both u32, and so cannot overflow
251
7490211
        // a u64.
252
7490211
        let router_weight = u64::from(router_bw) * u64::from(ws.for_role(role));
253
7490211
        router_weight >> self.shift
254
7490211
    }
255

            
256
    /// Compute the correct WeightSet for a provided MdConsensus.
257
3280
    pub(crate) fn from_consensus(consensus: &MdConsensus, params: &NetParameters) -> Self {
258
11241
        let bandwidth_fn = pick_bandwidth_fn(consensus.c_relays().iter().map(|rs| rs.weight()));
259
3280
        let weight_scale = params.bw_weight_scale.into();
260
3280

            
261
3280
        let total_bw = consensus
262
3280
            .c_relays()
263
3280
            .iter()
264
126640
            .map(|rs| u64::from(bandwidth_fn.apply(rs.weight())))
265
3280
            .sum();
266
3280
        let p = consensus.bandwidth_weights();
267
3280

            
268
3280
        Self::from_parts(bandwidth_fn, total_bw, weight_scale, p).validate(consensus)
269
3280
    }
270

            
271
    /// Compute the correct WeightSet given a bandwidth function, a
272
    /// weight-scaling parameter, a total amount of bandwidth for all
273
    /// relays in the consensus, and a set of bandwidth parameters.
274
3282
    fn from_parts(
275
3282
        bandwidth_fn: BandwidthFn,
276
3282
        total_bw: u64,
277
3282
        weight_scale: u32,
278
3282
        p: &NetParams<i32>,
279
3282
    ) -> Self {
280
3282
        /// Find a single RelayWeight, given the names that its bandwidth
281
3282
        /// parameters have. The `g` parameter is the weight as a guard, the
282
3282
        /// `m` parameter is the weight as a middle relay, the `e` parameter is
283
3282
        /// the weight as an exit, and the `d` parameter is the weight as a
284
3282
        /// directory.
285
3282
        #[allow(clippy::many_single_char_names)]
286
13128
        fn single(p: &NetParams<i32>, g: &str, m: &str, e: &str, d: &str) -> RelayWeight {
287
13128
            RelayWeight {
288
13128
                as_guard: w_param(p, g),
289
13128
                as_middle: w_param(p, m),
290
13128
                as_exit: w_param(p, e),
291
13128
                as_dir: w_param(p, d),
292
13128
            }
293
13128
        }
294
3282

            
295
3282
        // Prevent division by zero in case we're called with a bogus
296
3282
        // input.  (That shouldn't be possible.)
297
3282
        let weight_scale = weight_scale.max(1);
298
3282

            
299
3282
        // For non-V2Dir relays, we have names for most of their weights.
300
3282
        //
301
3282
        // (There is no Wge, since we only use Guard relays as guards.  By the
302
3282
        // same logic, Wme has no reason to exist, but according to the spec it
303
3282
        // does.)
304
3282
        let w_none = single(p, "Wgm", "Wmm", "Wem", "Wbm");
305
3282
        let w_guard = single(p, "Wgg", "Wmg", "Weg", "Wbg");
306
3282
        let w_exit = single(p, "---", "Wme", "Wee", "Wbe");
307
3282
        let w_both = single(p, "Wgd", "Wmd", "Wed", "Wbd");
308
3282

            
309
3282
        // Note that the positions of the elements in this array need to
310
3282
        // match the values returned by WeightKind.as_idx().
311
3282
        let w = [
312
3282
            w_none,
313
3282
            w_guard,
314
3282
            w_exit,
315
3282
            w_both,
316
3282
            // The V2Dir values are the same as the non-V2Dir values, except
317
3282
            // each is multiplied by an additional factor.
318
3282
            //
319
3282
            // (We don't need to check for overflow here, since the
320
3282
            // authorities make sure that the inputs don't get too big.)
321
3282
            (w_none * w_param(p, "Wmb")) / weight_scale,
322
3282
            (w_guard * w_param(p, "Wgb")) / weight_scale,
323
3282
            (w_exit * w_param(p, "Web")) / weight_scale,
324
3282
            (w_both * w_param(p, "Wdb")) / weight_scale,
325
3282
        ];
326
3282

            
327
3282
        // This is the largest weight value.
328
3282
        // The unwrap() is safe because `w` is nonempty.
329
3282
        #[allow(clippy::unwrap_used)]
330
3282
        let w_max = w.iter().map(RelayWeight::max_weight).max().unwrap();
331
3282

            
332
3282
        // We want "shift" such that (total * w_max) >> shift <= u64::max
333
3282
        let shift = calculate_shift(total_bw, u64::from(w_max)) as u8;
334
3282

            
335
3282
        WeightSet {
336
3282
            bandwidth_fn,
337
3282
            shift,
338
3282
            w,
339
3282
        }
340
3282
    }
341

            
342
    /// Assert that we have correctly computed our shift values so that
343
    /// our total weighted bws do not exceed u64::MAX.
344
3280
    fn validate(self, consensus: &MdConsensus) -> Self {
345
        use WeightRole::*;
346
16400
        for role in [Guard, Middle, Exit, BeginDir, Unweighted] {
347
16400
            let _: u64 = consensus
348
16400
                .c_relays()
349
16400
                .iter()
350
633200
                .map(|rs| self.weight_rs_for_role(rs, role))
351
633200
                .fold(0_u64, |a, b| {
352
632660
                    a.checked_add(b)
353
632660
                        .expect("Incorrect relay weight calculation: total exceeded u64::MAX!")
354
633200
                });
355
16400
        }
356
3280
        self
357
3280
    }
358
}
359

            
360
/// Fallback value for a bandwidth-weight parameter that the consensus omits.
///
/// If the consensus lists no weights at all, setting every weight to 1 is
/// correct: selection then depends on bandwidth alone.  If only some weights
/// are missing, the spec says nothing, and using 1 for them seems reasonable.
const DFLT_WEIGHT: i32 = 1;
367

            
368
/// Return the weight param named 'kwd' in p.
369
///
370
/// Returns DFLT_WEIGHT if there is no such parameter, and 0
371
/// if `kwd` is "---".
372
65640
fn w_param(p: &NetParams<i32>, kwd: &str) -> u32 {
373
65640
    if kwd == "---" {
374
3282
        0
375
    } else {
376
62358
        clamp_to_pos(*p.get(kwd).unwrap_or(&DFLT_WEIGHT))
377
    }
378
65640
}
379

            
380
/// Convert `inp` to a `u32`, mapping every negative value to 0.
fn clamp_to_pos(inp: i32) -> u32 {
    // The spec allows negative weights to appear, although nobody currently
    // generates them or plans to.  After `max(0)` the value is nonnegative,
    // so `unsigned_abs` is a lossless conversion.
    inp.max(0).unsigned_abs()
}
390

            
391
/// Compute a 'shift' value such that `(a * b) >> shift` will be contained
392
/// inside 64 bits.
393
3290
fn calculate_shift(a: u64, b: u64) -> u32 {
394
3290
    let bits_for_product = log2_upper(a) + log2_upper(b);
395
3290
    if bits_for_product < 64 {
396
3286
        0
397
    } else {
398
4
        bits_for_product - 64
399
    }
400
3290
}
401

            
402
/// Return the number of bits needed to represent `n` (0 when `n` is 0).
///
/// This is an upper bound for log2(n).  It overestimates by one when `n` is
/// a power of two, which is harmless for how we use it.
fn log2_upper(n: u64) -> u32 {
    u64::BITS - n.leading_zeros()
}
409

            
410
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use netstatus::RelayWeight as RW;
    use std::net::SocketAddr;
    use std::time::{Duration, SystemTime};
    use tor_basic_utils::test_rng::testing_rng;
    use tor_netdoc::doc::netstatus::{Lifetime, RelayFlags, RouterStatusBuilder};

    #[test]
    fn t_clamp() {
        assert_eq!(clamp_to_pos(32), 32);
        assert_eq!(clamp_to_pos(i32::MAX), i32::MAX as u32);
        assert_eq!(clamp_to_pos(0), 0);
        assert_eq!(clamp_to_pos(-1), 0);
        assert_eq!(clamp_to_pos(i32::MIN), 0);
    }

    #[test]
    fn t_log2() {
        assert_eq!(log2_upper(u64::MAX), 64);
        assert_eq!(log2_upper(0), 0);
        assert_eq!(log2_upper(1), 1);
        assert_eq!(log2_upper(63), 6);
        assert_eq!(log2_upper(64), 7); // overestimates for powers of two; harmless.
    }

    #[test]
    fn t_calc_shift() {
        assert_eq!(calculate_shift(1 << 20, 1 << 20), 0);
        assert_eq!(calculate_shift(1 << 50, 1 << 10), 0);
        assert_eq!(calculate_shift(1 << 32, 1 << 33), 3);
        assert!(((1_u64 << 32) >> 3).checked_mul(1_u64 << 33).is_some());
        assert_eq!(calculate_shift(432 << 40, 7777 << 40), 38);
        assert!(((432_u64 << 40) >> 38)
            .checked_mul(7777_u64 << 40)
            .is_some());
    }

    #[test]
    fn t_pick_bwfunc() {
        let empty = [];
        assert_eq!(pick_bandwidth_fn(empty.iter()), BandwidthFn::Uniform);

        let all_zero = [RW::Unmeasured(0), RW::Measured(0), RW::Unmeasured(0)];
        assert_eq!(pick_bandwidth_fn(all_zero.iter()), BandwidthFn::Uniform);

        let all_unmeasured = [RW::Unmeasured(9), RW::Unmeasured(2222)];
        assert_eq!(
            pick_bandwidth_fn(all_unmeasured.iter()),
            BandwidthFn::IncludeUnmeasured
        );

        let some_measured = [
            RW::Unmeasured(10),
            RW::Measured(7),
            RW::Measured(4),
            RW::Unmeasured(0),
        ];
        assert_eq!(
            pick_bandwidth_fn(some_measured.iter()),
            BandwidthFn::MeasuredOnly
        );

        // This covers the awkward case in `pick_bandwidth_fn` where measured
        // weights exist but are all zero, and the only nonzero weights are
        // unmeasured: we fall back to uniform weighting.
        let measured_all_zero = [RW::Unmeasured(10), RW::Measured(0)];
        assert_eq!(
            pick_bandwidth_fn(measured_all_zero.iter()),
            BandwidthFn::Uniform
        );
    }

    #[test]
    fn t_apply_bwfn() {
        use netstatus::RelayWeight::*;
        use BandwidthFn::*;

        assert_eq!(Uniform.apply(&Measured(7)), 1);
        assert_eq!(Uniform.apply(&Unmeasured(0)), 1);

        assert_eq!(IncludeUnmeasured.apply(&Measured(7)), 7);
        assert_eq!(IncludeUnmeasured.apply(&Unmeasured(8)), 8);

        assert_eq!(MeasuredOnly.apply(&Measured(9)), 9);
        assert_eq!(MeasuredOnly.apply(&Unmeasured(10)), 0);
    }

    // From a fairly recent Tor consensus.
    const TESTVEC_PARAMS: &str =
        "Wbd=0 Wbe=0 Wbg=4096 Wbm=10000 Wdb=10000 Web=10000 Wed=10000 Wee=10000 Weg=10000 Wem=10000 Wgb=10000 Wgd=0 Wgg=5904 Wgm=5904 Wmb=10000 Wmd=0 Wme=0 Wmg=4096 Wmm=10000";

    #[test]
    fn t_weightset_basic() {
        let total_bandwidth = 1_000_000_000;
        let params = TESTVEC_PARAMS.parse().unwrap();
        let ws = WeightSet::from_parts(BandwidthFn::MeasuredOnly, total_bandwidth, 10000, &params);

        assert_eq!(ws.bandwidth_fn, BandwidthFn::MeasuredOnly);
        assert_eq!(ws.shift, 0);

        assert_eq!(ws.w[0].as_guard, 5904);
        assert_eq!(ws.w[(WeightKind::GUARD.bits()) as usize].as_guard, 5904);
        assert_eq!(ws.w[(WeightKind::EXIT.bits()) as usize].as_exit, 10000);
        assert_eq!(
            ws.w[(WeightKind::EXIT | WeightKind::GUARD).bits() as usize].as_dir,
            0
        );
        assert_eq!(
            ws.w[(WeightKind::GUARD | WeightKind::DIR).bits() as usize].as_dir,
            4096
        );
        // Exit+V2Dir relays get Wme * Web / scale == 0 as middle relays.
        assert_eq!(
            ws.w[(WeightKind::EXIT | WeightKind::DIR).bits() as usize].as_middle,
            0
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Unmeasured(7777),
                WeightRole::Guard
            ),
            0
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Guard
            ),
            7777 * 5904
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Middle
            ),
            7777 * 4096
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Exit
            ),
            7777 * 10000
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::BeginDir
            ),
            7777 * 4096
        );

        // HsIntro shares the middle-relay weights.
        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::HsIntro
            ),
            7777 * 4096
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Unweighted
            ),
            7777
        );

        // Now try those last few with routerstatuses.
        let rs = rs_builder()
            .set_flags(RelayFlags::GUARD | RelayFlags::V2DIR)
            .weight(RW::Measured(7777))
            .build()
            .unwrap();
        assert_eq!(ws.weight_rs_for_role(&rs, WeightRole::Exit), 7777 * 10000);
        assert_eq!(
            ws.weight_rs_for_role(&rs, WeightRole::BeginDir),
            7777 * 4096
        );
        assert_eq!(ws.weight_rs_for_role(&rs, WeightRole::Unweighted), 7777);
    }

    /// Return a routerstatus builder set up to deliver a routerstatus
    /// with most features disabled.
    fn rs_builder() -> RouterStatusBuilder<[u8; 32]> {
        MdConsensus::builder()
            .rs()
            .identity([9; 20].into())
            .add_or_port(SocketAddr::from(([127, 0, 0, 1], 9001)))
            .doc_digest([9; 32])
            .protos("".parse().unwrap())
            .clone()
    }

    #[test]
    fn weight_flags() {
        let rs1 = rs_builder().set_flags(RelayFlags::EXIT).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::EXIT);

        let rs1 = rs_builder().set_flags(RelayFlags::GUARD).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::GUARD);

        let rs1 = rs_builder().set_flags(RelayFlags::V2DIR).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::DIR);

        let rs1 = rs_builder().build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::empty());

        let rs1 = rs_builder().set_flags(RelayFlags::all()).build().unwrap();
        assert_eq!(
            WeightKind::for_rs(&rs1),
            WeightKind::EXIT | WeightKind::GUARD | WeightKind::DIR
        );
    }

    #[test]
    fn weightset_from_consensus() {
        use rand::Rng;
        let now = SystemTime::now();
        let one_hour = Duration::new(3600, 0);
        let mut rng = testing_rng();
        let mut bld = MdConsensus::builder();
        bld.consensus_method(34)
            .lifetime(Lifetime::new(now, now + one_hour, now + 2 * one_hour).unwrap())
            .weights(TESTVEC_PARAMS.parse().unwrap());

        // We're going to add a huge amount of unmeasured bandwidth,
        // and a reasonable amount of measured bandwidth.
        for _ in 0..10 {
            rs_builder()
                .identity(rng.gen::<[u8; 20]>().into()) // random id
                .weight(RW::Unmeasured(1_000_000))
                .set_flags(RelayFlags::GUARD | RelayFlags::EXIT)
                .build_into(&mut bld)
                .unwrap();
        }
        for n in 0..30 {
            rs_builder()
                .identity(rng.gen::<[u8; 20]>().into()) // random id
                .weight(RW::Measured(1_000 * n))
                .set_flags(RelayFlags::GUARD | RelayFlags::EXIT)
                .build_into(&mut bld)
                .unwrap();
        }

        let consensus = bld.testing_consensus().unwrap();
        let params = NetParameters::default();
        let ws = WeightSet::from_consensus(&consensus, &params);

        assert_eq!(ws.bandwidth_fn, BandwidthFn::MeasuredOnly);
        assert_eq!(ws.shift, 0);
        assert_eq!(ws.w[0].as_guard, 5904);
        assert_eq!(ws.w[5].as_guard, 5904);
        assert_eq!(ws.w[5].as_middle, 4096);
    }
}