tor_basic_utils/iter.rs
//! Iterator helpers for Arti.
2
3/// Iterator extension trait to implement a counting filter.
4pub trait IteratorExt: Iterator {
5 /// Return an iterator that contains every member of this iterator, and
6 /// which records its progress in `count`.
7 ///
8 /// The values in `count` are initially set to zero. Then, every time the
9 /// filter considers an item, it will either increment `count.n_accepted` or
10 /// `count.n_rejected`.
11 ///
12 /// Note that if the iterator is dropped before it is exhausted, the count will not
13 /// be complete.
14 ///
15 /// # Examples
16 ///
17 /// ```
18 /// use tor_basic_utils::iter::{IteratorExt, FilterCount};
19 ///
20 /// let mut count = FilterCount::default();
21 /// let emoji : String = "Hello 🙂 World 🌏!"
22 /// .chars()
23 /// .filter_cnt(&mut count, |ch| !ch.is_ascii())
24 /// .collect();
25 /// assert_eq!(emoji, "🙂🌏");
26 /// assert_eq!(count, FilterCount { n_accepted: 2, n_rejected: 14});
27 /// ```
28 //
29 // In Arti, we mostly use this iterator for reporting issues when we're
30 // unable to find a suitable relay for some purpose: it makes it easy to
31 // tabulate which filters in a chain of filters rejected how many of the
32 // potential candidates.
33 fn filter_cnt<P>(self, count: &mut FilterCount, pred: P) -> CountingFilter<'_, P, Self>
34 where
35 Self: Sized,
36 P: FnMut(&Self::Item) -> bool,
37 {
38 *count = FilterCount::default();
39 CountingFilter {
40 inner: self,
41 pred,
42 count,
43 }
44 }
45}
46
47impl<I> IteratorExt for I where I: Iterator {}
48
/// A record of how many items a [`CountingFilter`] returned by
/// [`IteratorExt::filter_cnt`] accepted and rejected.
///
/// In `tor-guardmgr` we use this type to keep track of which filters reject which guards.
//
// SEMVER NOTE: This type has public members, is exhaustive, and is re-exposed
// from various error types elsewhere in arti. Probably you should not change
// its members. If you do, you will need to mark it as a breaking change
// everywhere that it is re-exported.
#[derive(Copy, Clone, Default, Debug, Eq, PartialEq)]
#[allow(clippy::exhaustive_structs)]
pub struct FilterCount {
    /// The number of items that the filter considered and accepted.
    pub n_accepted: usize,
    /// The number of items that the filter considered and rejected.
    pub n_rejected: usize,
}
66
67/// An iterator to implement [`IteratorExt::filter_cnt`].
68pub struct CountingFilter<'a, P, I> {
69 /// The inner iterator that we're taking items from.
70 inner: I,
71 /// The predicate we're using to decide which items are accepted.
72 pred: P,
73 /// The count of the number of items accepted and rejected so far.
74 count: &'a mut FilterCount,
75}
76
77impl<'a, P, I> Iterator for CountingFilter<'a, P, I>
78where
79 P: FnMut(&I::Item) -> bool,
80 I: Iterator,
81{
82 type Item = I::Item;
83
84 fn next(&mut self) -> Option<Self::Item> {
85 for item in &mut self.inner {
86 if (self.pred)(&item) {
87 self.count.n_accepted += 1;
88 return Some(item);
89 } else {
90 self.count.n_rejected += 1;
91 }
92 }
93 None
94 }
95}
96
97impl FilterCount {
98 /// Return a wrapper that can be displayed as the fraction of rejected items.
99 ///
100 /// # Example
101 ///
102 /// ```
103 /// # use tor_basic_utils::iter::{IteratorExt, FilterCount};
104 /// let mut count = FilterCount::default();
105 /// let sum_of_evens : u32 = (1..=10)
106 /// .filter_cnt(&mut count, |x| *x % 2 == 0)
107 /// .sum();
108 /// assert_eq!(format!("Rejected {} as odd", count.display_frac_rejected()),
109 /// "Rejected 5/10 as odd".to_string());
110 /// ```
111 pub fn display_frac_rejected(&self) -> DisplayFracRejected<'_> {
112 DisplayFracRejected(self)
113 }
114
115 /// Count and return the provided boolean value.
116 ///
117 /// This is an alternative way to use `FilterCount` when you have to provide
118 /// a function that takes a predicate rather than a member of an iterator
119 /// chain.
120 ///
121 /// # Example
122 ///
123 /// ```
124 /// # use tor_basic_utils::iter::FilterCount;
125 /// let mut count = FilterCount::default();
126 /// let mut emoji = "Hello 🙂 World 🌏!".to_string();
127 /// emoji.retain(|ch| count.count(!ch.is_ascii()));
128 /// assert_eq!(emoji, "🙂🌏");
129 /// assert_eq!(count, FilterCount { n_accepted: 2, n_rejected: 14});
130 /// ```
131 pub fn count(&mut self, accept: bool) -> bool {
132 if accept {
133 self.n_accepted += 1;
134 } else {
135 self.n_rejected += 1;
136 }
137 accept
138 }
139}
140
141/// Return value from [`FilterCount::display_frac_rejected`].
142#[derive(Debug, Clone)]
143pub struct DisplayFracRejected<'a>(&'a FilterCount);
144
145impl<'a> std::fmt::Display for DisplayFracRejected<'a> {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 write!(
148 f,
149 "{}/{}",
150 self.0.n_rejected,
151 self.0.n_accepted + self.0.n_rejected
152 )
153 }
154}
155
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;

    /// `filter_cnt` tallies only the items it examined, and resets the count
    /// when reused.
    #[test]
    fn counting_filter() {
        let mut count = FilterCount::default();
        let v = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        let first_even = v
            .iter()
            .filter_cnt(&mut count, |val| **val % 2 == 0)
            .next()
            .unwrap();
        assert_eq!(*first_even, 2);
        assert_eq!(count.n_accepted, 1);
        assert_eq!(count.n_rejected, 1);

        let sum_even: usize = v.iter().filter_cnt(&mut count, |val| **val % 2 == 0).sum();
        assert_eq!(sum_even, 20);
        assert_eq!(count.n_accepted, 4);
        assert_eq!(count.n_rejected, 5);
    }

    /// `FilterCount::count` works as a drop-in predicate wrapper.
    #[test]
    fn counting_with_predicates() {
        let mut count = FilterCount::default();
        let v = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        let first_even = v.iter().find(|val| count.count(**val % 2 == 0)).unwrap();
        assert_eq!(*first_even, 2);
        assert_eq!(count.n_accepted, 1);
        assert_eq!(count.n_rejected, 1);

        let mut count = FilterCount::default();
        let sum_even: usize = v.iter().filter(|val| count.count(**val % 2 == 0)).sum();
        assert_eq!(sum_even, 20);
        assert_eq!(count.n_accepted, 4);
        assert_eq!(count.n_rejected, 5);
    }

    /// `display_frac_rejected` renders as `REJECTED/TOTAL`.
    #[test]
    fn display_frac_rejected() {
        let count = FilterCount {
            n_accepted: 3,
            n_rejected: 7,
        };
        assert_eq!(count.display_frac_rejected().to_string(), "7/10");
        assert_eq!(
            FilterCount::default().display_frac_rejected().to_string(),
            "0/0"
        );
    }
}