1use crate::params::NetParameters;
16use crate::ConsensusRelays;
17use bitflags::bitflags;
18use tor_netdoc::doc::netstatus::{self, MdConsensus, MdConsensusRouterStatus, NetParams};
19
20fn pick_bandwidth_fn<'a, I>(mut weights: I) -> BandwidthFn
23where
24 I: Clone + Iterator<Item = &'a netstatus::RelayWeight>,
25{
26 let has_measured = weights.clone().any(|w| w.is_measured());
27 let has_nonzero = weights.clone().any(|w| w.is_nonzero());
28 let has_nonzero_measured = weights.any(|w| w.is_measured() && w.is_nonzero());
29
30 if !has_nonzero {
31 BandwidthFn::Uniform
34 } else if !has_measured {
35 BandwidthFn::IncludeUnmeasured
38 } else if has_nonzero_measured {
39 BandwidthFn::MeasuredOnly
42 } else {
43 BandwidthFn::Uniform
47 }
48}
49
/// Strategy for turning a relay's declared consensus weight into the
/// bandwidth value used during path selection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BandwidthFn {
    /// Every relay counts as having a bandwidth of exactly 1.
    Uniform,
    /// Measured and unmeasured values are both used as-is.
    IncludeUnmeasured,
    /// Measured values are used as-is; unmeasured values count as 0.
    MeasuredOnly,
}
64
65impl BandwidthFn {
66 fn apply(&self, w: &netstatus::RelayWeight) -> u32 {
69 use netstatus::RelayWeight::*;
70 use BandwidthFn::*;
71 match (self, w) {
72 (Uniform, _) => 1,
73 (IncludeUnmeasured, Unmeasured(u)) => *u,
74 (IncludeUnmeasured, Measured(m)) => *m,
75 (MeasuredOnly, Unmeasured(_)) => 0,
76 (MeasuredOnly, Measured(m)) => *m,
77 (_, _) => 0,
78 }
79 }
80}
81
/// The position a relay is being selected for, which determines which
/// consensus weight multiplier applies to it.
#[derive(Copy, Clone, Debug)]
#[non_exhaustive]
pub enum WeightRole {
    /// Selecting a relay for the guard position.
    Guard,
    /// Selecting a relay for a middle position.
    Middle,
    /// Selecting a relay for the exit position.
    Exit,
    /// Selecting a relay to use for a BEGIN_DIR-style directory request.
    BeginDir,
    /// Selecting a relay by bandwidth alone, with a multiplier of 1.
    Unweighted,
    /// Selecting an onion-service introduction point (weighted like a middle).
    HsIntro,
    /// Selecting an onion-service rendezvous point (weighted like a middle).
    HsRend,
}
106
/// Consensus weight multipliers for one class of relay, one per position
/// that a relay of that class might be chosen for.
#[derive(Copy, Clone, Debug)]
struct RelayWeight {
    /// Multiplier applied when choosing a guard.
    as_guard: u32,
    /// Multiplier applied when choosing a middle relay.
    as_middle: u32,
    /// Multiplier applied when choosing an exit.
    as_exit: u32,
    /// Multiplier applied when choosing a directory relay.
    as_dir: u32,
}
119
120impl std::ops::Mul<u32> for RelayWeight {
121 type Output = Self;
122 fn mul(self, rhs: u32) -> Self {
123 RelayWeight {
124 as_guard: self.as_guard * rhs,
125 as_middle: self.as_middle * rhs,
126 as_exit: self.as_exit * rhs,
127 as_dir: self.as_dir * rhs,
128 }
129 }
130}
131impl std::ops::Div<u32> for RelayWeight {
132 type Output = Self;
133 fn div(self, rhs: u32) -> Self {
134 RelayWeight {
135 as_guard: self.as_guard / rhs,
136 as_middle: self.as_middle / rhs,
137 as_exit: self.as_exit / rhs,
138 as_dir: self.as_dir / rhs,
139 }
140 }
141}
142
143impl RelayWeight {
144 #[allow(clippy::unwrap_used)]
147 fn max_weight(&self) -> u32 {
148 [self.as_guard, self.as_middle, self.as_exit, self.as_dir]
149 .iter()
150 .max()
151 .copied()
152 .unwrap()
153 }
154 fn for_role(&self, role: WeightRole) -> u32 {
157 match role {
158 WeightRole::Guard => self.as_guard,
159 WeightRole::Middle => self.as_middle,
160 WeightRole::Exit => self.as_exit,
161 WeightRole::BeginDir => self.as_dir,
162 WeightRole::HsIntro => self.as_middle, WeightRole::HsRend => self.as_middle, WeightRole::Unweighted => 1,
165 }
166 }
167}
168
169bitflags! {
170 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
175 struct WeightKind: u8 {
176 const GUARD = 1 << 0;
178 const EXIT = 1 << 1;
180 const DIR = 1 << 2;
182 }
183}
184
185impl WeightKind {
186 fn for_rs(rs: &MdConsensusRouterStatus) -> Self {
188 let mut r = WeightKind::empty();
189 if rs.is_flagged_guard() {
190 r |= WeightKind::GUARD;
191 }
192 if rs.is_flagged_exit() {
193 r |= WeightKind::EXIT;
194 }
195 if rs.is_flagged_v2dir() {
196 r |= WeightKind::DIR;
197 }
198 r
199 }
200 fn idx(self) -> usize {
202 self.bits() as usize
203 }
204}
205
206#[derive(Debug, Clone)]
209pub(crate) struct WeightSet {
210 bandwidth_fn: BandwidthFn,
216 shift: u8,
227 w: [RelayWeight; 8],
230}
231
232impl WeightSet {
233 pub(crate) fn weight_rs_for_role(&self, rs: &MdConsensusRouterStatus, role: WeightRole) -> u64 {
240 self.weight_bw_for_role(WeightKind::for_rs(rs), rs.weight(), role)
241 }
242
243 fn weight_bw_for_role(
246 &self,
247 kind: WeightKind,
248 relay_weight: &netstatus::RelayWeight,
249 role: WeightRole,
250 ) -> u64 {
251 let ws = &self.w[kind.idx()];
252
253 let router_bw = self.bandwidth_fn.apply(relay_weight);
254 let router_weight = u64::from(router_bw) * u64::from(ws.for_role(role));
259 router_weight >> self.shift
260 }
261
262 pub(crate) fn from_consensus(consensus: &MdConsensus, params: &NetParameters) -> Self {
264 let bandwidth_fn = pick_bandwidth_fn(consensus.c_relays().iter().map(|rs| rs.weight()));
265 let weight_scale = params.bw_weight_scale.into();
266
267 let total_bw = consensus
268 .c_relays()
269 .iter()
270 .map(|rs| u64::from(bandwidth_fn.apply(rs.weight())))
271 .sum();
272 let p = consensus.bandwidth_weights();
273
274 Self::from_parts(bandwidth_fn, total_bw, weight_scale, p).validate(consensus)
275 }
276
277 fn from_parts(
281 bandwidth_fn: BandwidthFn,
282 total_bw: u64,
283 weight_scale: u32,
284 p: &NetParams<i32>,
285 ) -> Self {
286 #[allow(clippy::many_single_char_names)]
292 fn single(p: &NetParams<i32>, g: &str, m: &str, e: &str, d: &str) -> RelayWeight {
293 RelayWeight {
294 as_guard: w_param(p, g),
295 as_middle: w_param(p, m),
296 as_exit: w_param(p, e),
297 as_dir: w_param(p, d),
298 }
299 }
300
301 let weight_scale = weight_scale.max(1);
304
305 let w_none = single(p, "Wgm", "Wmm", "Wem", "Wbm");
311 let w_guard = single(p, "Wgg", "Wmg", "Weg", "Wbg");
312 let w_exit = single(p, "---", "Wme", "Wee", "Wbe");
313 let w_both = single(p, "Wgd", "Wmd", "Wed", "Wbd");
314
315 let w = [
318 w_none,
319 w_guard,
320 w_exit,
321 w_both,
322 (w_none * w_param(p, "Wmb")) / weight_scale,
328 (w_guard * w_param(p, "Wgb")) / weight_scale,
329 (w_exit * w_param(p, "Web")) / weight_scale,
330 (w_both * w_param(p, "Wdb")) / weight_scale,
331 ];
332
333 #[allow(clippy::unwrap_used)]
336 let w_max = w.iter().map(RelayWeight::max_weight).max().unwrap();
337
338 let shift = calculate_shift(total_bw, u64::from(w_max)) as u8;
340
341 WeightSet {
342 bandwidth_fn,
343 shift,
344 w,
345 }
346 }
347
348 fn validate(self, consensus: &MdConsensus) -> Self {
351 use WeightRole::*;
352 for role in [Guard, Middle, Exit, BeginDir, Unweighted] {
353 let _: u64 = consensus
354 .c_relays()
355 .iter()
356 .map(|rs| self.weight_rs_for_role(rs, role))
357 .fold(0_u64, |a, b| {
358 a.checked_add(b)
359 .expect("Incorrect relay weight calculation: total exceeded u64::MAX!")
360 });
361 }
362 self
363 }
364}
365
/// Value used for any bandwidth-weight parameter that the consensus does
/// not specify.
const DFLT_WEIGHT: i32 = 1;
373
374fn w_param(p: &NetParams<i32>, kwd: &str) -> u32 {
379 if kwd == "---" {
380 0
381 } else {
382 clamp_to_pos(*p.get(kwd).unwrap_or(&DFLT_WEIGHT))
383 }
384}
385
/// Convert `inp` to a `u32`, mapping every negative value to 0.
fn clamp_to_pos(inp: i32) -> u32 {
    // After `max(0)` the value is a non-negative i32, which always fits in
    // a u32, so this cast is lossless.
    inp.max(0) as u32
}
396
/// Return how many bits `a` must be right-shifted by so that multiplying
/// the result by `b` cannot overflow a `u64`.
///
/// This is a conservative bound based only on the bit lengths of `a` and `b`.
fn calculate_shift(a: u64, b: u64) -> u32 {
    // Number of bits needed to represent `x` (0 for 0).
    let bit_len = |x: u64| u64::BITS - x.leading_zeros();
    (bit_len(a) + bit_len(b)).saturating_sub(u64::BITS)
}
403
/// Return the number of bits needed to represent `n`.
///
/// Returns 0 for 0, and otherwise `floor(log2(n)) + 1`.
fn log2_upper(n: u64) -> u32 {
    u64::BITS - n.leading_zeros()
}
411
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    // Unit tests for bandwidth-weight computation.
    use super::*;
    use netstatus::RelayWeight as RW;
    use std::net::SocketAddr;
    use std::time::{Duration, SystemTime};
    use tor_basic_utils::test_rng::testing_rng;
    use tor_netdoc::doc::netstatus::{Lifetime, RelayFlags, RouterStatusBuilder};

    #[test]
    fn t_clamp() {
        assert_eq!(clamp_to_pos(32), 32);
        assert_eq!(clamp_to_pos(i32::MAX), i32::MAX as u32);
        assert_eq!(clamp_to_pos(0), 0);
        assert_eq!(clamp_to_pos(-1), 0);
        assert_eq!(clamp_to_pos(i32::MIN), 0);
    }

    #[test]
    fn t_log2() {
        assert_eq!(log2_upper(u64::MAX), 64);
        assert_eq!(log2_upper(0), 0);
        assert_eq!(log2_upper(1), 1);
        assert_eq!(log2_upper(63), 6);
        // 64 needs 7 bits to represent.
        assert_eq!(log2_upper(64), 7);
    }

    #[test]
    fn t_calc_shift() {
        assert_eq!(calculate_shift(1 << 20, 1 << 20), 0);
        assert_eq!(calculate_shift(1 << 50, 1 << 10), 0);
        assert_eq!(calculate_shift(1 << 32, 1 << 33), 3);
        assert!(((1_u64 << 32) >> 3).checked_mul(1_u64 << 33).is_some());
        assert_eq!(calculate_shift(432 << 40, 7777 << 40), 38);
        assert!(((432_u64 << 40) >> 38)
            .checked_mul(7777_u64 << 40)
            .is_some());
    }

    #[test]
    fn t_pick_bwfunc() {
        let empty = [];
        assert_eq!(pick_bandwidth_fn(empty.iter()), BandwidthFn::Uniform);

        let all_zero = [RW::Unmeasured(0), RW::Measured(0), RW::Unmeasured(0)];
        assert_eq!(pick_bandwidth_fn(all_zero.iter()), BandwidthFn::Uniform);

        let all_unmeasured = [RW::Unmeasured(9), RW::Unmeasured(2222)];
        assert_eq!(
            pick_bandwidth_fn(all_unmeasured.iter()),
            BandwidthFn::IncludeUnmeasured
        );

        let some_measured = [
            RW::Unmeasured(10),
            RW::Measured(7),
            RW::Measured(4),
            RW::Unmeasured(0),
        ];
        assert_eq!(
            pick_bandwidth_fn(some_measured.iter()),
            BandwidthFn::MeasuredOnly
        );

        // Measured values exist, but all of them are zero.
        let measured_all_zero = [RW::Unmeasured(10), RW::Measured(0)];
        assert_eq!(
            pick_bandwidth_fn(measured_all_zero.iter()),
            BandwidthFn::Uniform
        );
    }

    #[test]
    fn t_apply_bwfn() {
        use netstatus::RelayWeight::*;
        use BandwidthFn::*;

        assert_eq!(Uniform.apply(&Measured(7)), 1);
        assert_eq!(Uniform.apply(&Unmeasured(0)), 1);

        assert_eq!(IncludeUnmeasured.apply(&Measured(7)), 7);
        assert_eq!(IncludeUnmeasured.apply(&Unmeasured(8)), 8);

        assert_eq!(MeasuredOnly.apply(&Measured(9)), 9);
        assert_eq!(MeasuredOnly.apply(&Unmeasured(10)), 0);
    }

    const TESTVEC_PARAMS: &str =
        "Wbd=0 Wbe=0 Wbg=4096 Wbm=10000 Wdb=10000 Web=10000 Wed=10000 Wee=10000 Weg=10000 Wem=10000 Wgb=10000 Wgd=0 Wgg=5904 Wgm=5904 Wmb=10000 Wmd=0 Wme=0 Wmg=4096 Wmm=10000";

    #[test]
    fn t_weightset_basic() {
        let total_bandwidth = 1_000_000_000;
        let params = TESTVEC_PARAMS.parse().unwrap();
        let ws = WeightSet::from_parts(BandwidthFn::MeasuredOnly, total_bandwidth, 10000, &params);

        assert_eq!(ws.bandwidth_fn, BandwidthFn::MeasuredOnly);
        assert_eq!(ws.shift, 0);

        assert_eq!(ws.w[0].as_guard, 5904);
        assert_eq!(ws.w[(WeightKind::GUARD.bits()) as usize].as_guard, 5904);
        assert_eq!(ws.w[(WeightKind::EXIT.bits()) as usize].as_exit, 10000);
        assert_eq!(
            ws.w[(WeightKind::EXIT | WeightKind::GUARD).bits() as usize].as_dir,
            0
        );
        assert_eq!(
            ws.w[(WeightKind::GUARD | WeightKind::DIR).bits() as usize].as_dir,
            4096
        );

        // Unmeasured weights count as zero under MeasuredOnly.
        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Unmeasured(7777),
                WeightRole::Guard
            ),
            0
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Guard
            ),
            7777 * 5904
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Middle
            ),
            7777 * 4096
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Exit
            ),
            7777 * 10000
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::BeginDir
            ),
            7777 * 4096
        );

        assert_eq!(
            ws.weight_bw_for_role(
                WeightKind::GUARD | WeightKind::DIR,
                &RW::Measured(7777),
                WeightRole::Unweighted
            ),
            7777
        );

        // Same checks, but going through a real router status.
        let rs = rs_builder()
            .set_flags(RelayFlags::GUARD | RelayFlags::V2DIR)
            .weight(RW::Measured(7777))
            .build()
            .unwrap();
        assert_eq!(ws.weight_rs_for_role(&rs, WeightRole::Exit), 7777 * 10000);
        assert_eq!(
            ws.weight_rs_for_role(&rs, WeightRole::BeginDir),
            7777 * 4096
        );
        assert_eq!(ws.weight_rs_for_role(&rs, WeightRole::Unweighted), 7777);
    }

    /// Return a router-status builder with fixed identity, address, digest,
    /// and empty protocols.
    fn rs_builder() -> RouterStatusBuilder<[u8; 32]> {
        MdConsensus::builder()
            .rs()
            .identity([9; 20].into())
            .add_or_port(SocketAddr::from(([127, 0, 0, 1], 9001)))
            .doc_digest([9; 32])
            .protos("".parse().unwrap())
            .clone()
    }

    #[test]
    fn weight_flags() {
        let rs1 = rs_builder().set_flags(RelayFlags::EXIT).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::EXIT);

        let rs1 = rs_builder().set_flags(RelayFlags::GUARD).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::GUARD);

        let rs1 = rs_builder().set_flags(RelayFlags::V2DIR).build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::DIR);

        let rs1 = rs_builder().build().unwrap();
        assert_eq!(WeightKind::for_rs(&rs1), WeightKind::empty());

        let rs1 = rs_builder().set_flags(RelayFlags::all()).build().unwrap();
        assert_eq!(
            WeightKind::for_rs(&rs1),
            WeightKind::EXIT | WeightKind::GUARD | WeightKind::DIR
        );
    }

    #[test]
    fn weightset_from_consensus() {
        use rand::Rng;
        let now = SystemTime::now();
        let one_hour = Duration::new(3600, 0);
        let mut rng = testing_rng();
        let mut bld = MdConsensus::builder();
        bld.consensus_method(34)
            .lifetime(Lifetime::new(now, now + one_hour, now + 2 * one_hour).unwrap())
            .weights(TESTVEC_PARAMS.parse().unwrap());

        // Relays with unmeasured weights; ignored once measured ones exist.
        for _ in 0..10 {
            rs_builder()
                .identity(rng.random::<[u8; 20]>().into())
                .weight(RW::Unmeasured(1_000_000))
                .set_flags(RelayFlags::GUARD | RelayFlags::EXIT)
                .build_into(&mut bld)
                .unwrap();
        }
        // Relays with measured weights.
        for n in 0..30 {
            rs_builder()
                .identity(rng.random::<[u8; 20]>().into())
                .weight(RW::Measured(1_000 * n))
                .set_flags(RelayFlags::GUARD | RelayFlags::EXIT)
                .build_into(&mut bld)
                .unwrap();
        }

        let consensus = bld.testing_consensus().unwrap();
        let params = NetParameters::default();
        let ws = WeightSet::from_consensus(&consensus, &params);

        assert_eq!(ws.bandwidth_fn, BandwidthFn::MeasuredOnly);
        assert_eq!(ws.shift, 0);
        assert_eq!(ws.w[0].as_guard, 5904);
        assert_eq!(ws.w[5].as_guard, 5904);
        assert_eq!(ws.w[5].as_middle, 4096);
    }
}