//! Iterator helpers for Arti.

            
3
/// Iterator extension trait to implement a counting filter.
4
pub trait IteratorExt: Iterator {
5
    /// Return an iterator that contains every member of this iterator, and
6
    /// which records its progress in `count`.
7
    ///
8
    /// The values in `count` are initially set to zero.  Then, every time the
9
    /// filter considers an item, it will either increment `count.n_accepted` or
10
    /// `count.n_rejected`.
11
    ///
12
    /// Note that if the iterator is dropped before it is exhausted, the count will not
13
    /// be complete.
14
    ///
15
    /// # Examples
16
    ///
17
    /// ```
18
    /// use tor_basic_utils::iter::{IteratorExt, FilterCount};
19
    ///
20
    /// let mut count = FilterCount::default();
21
    /// let emoji : String = "Hello 🙂 World 🌏!"
22
    ///     .chars()
23
    ///     .filter_cnt(&mut count, |ch| !ch.is_ascii())
24
    ///     .collect();
25
    /// assert_eq!(emoji, "🙂🌏");
26
    /// assert_eq!(count, FilterCount { n_accepted: 2, n_rejected: 14});
27
    /// ```
28
    //
29
    // In Arti, we mostly use this iterator for reporting issues when we're
30
    // unable to find a suitable relay for some purpose: it makes it easy to
31
    // tabulate which filters in a chain of filters rejected how many of the
32
    // potential candidates.
33
1433192
    fn filter_cnt<P>(self, count: &mut FilterCount, pred: P) -> CountingFilter<'_, P, Self>
34
1433192
    where
35
1433192
        Self: Sized,
36
1433192
        P: FnMut(&Self::Item) -> bool,
37
1433192
    {
38
1433192
        *count = FilterCount::default();
39
1433192
        CountingFilter {
40
1433192
            inner: self,
41
1433192
            pred,
42
1433192
            count,
43
1433192
        }
44
1433192
    }
45
}
46

            
47
// Blanket implementation: every iterator gets `filter_cnt` for free.
impl<I: Iterator> IteratorExt for I {}
48

            
49
/// A record of how many items a [`CountingFilter`] returned by
/// [`IteratorExt::filter_cnt`] accepted and rejected.
///
/// In `tor-guardmgr` we use this type to keep track of which filters reject which guards.
//
// SEMVER NOTE: This type has public members, is exhaustive, and is re-exposed
// from various error types elsewhere in arti.  Probably you should not change
// its members.  If you do, you will need to mark it as a breaking change
// everywhere that it is re-exported.
#[derive(Copy, Clone, Default, Debug, Eq, PartialEq)]
#[allow(clippy::exhaustive_structs)]
pub struct FilterCount {
    /// The number of items that the filter considered and accepted.
    pub n_accepted: usize,
    /// The number of items that the filter considered and rejected.
    pub n_rejected: usize,
}
66

            
67
/// An iterator to implement [`IteratorExt::filter_cnt`].
68
pub struct CountingFilter<'a, P, I> {
69
    /// The inner iterator that we're taking items from.
70
    inner: I,
71
    /// The predicate we're using to decide which items are accepted.
72
    pred: P,
73
    /// The count of the number of items accepted and rejected so far.
74
    count: &'a mut FilterCount,
75
}
76

            
77
impl<'a, P, I> Iterator for CountingFilter<'a, P, I>
78
where
79
    P: FnMut(&I::Item) -> bool,
80
    I: Iterator,
81
{
82
    type Item = I::Item;
83

            
84
1491564
    fn next(&mut self) -> Option<Self::Item> {
85
1492511
        for item in &mut self.inner {
86
1453977
            if (self.pred)(&item) {
87
1453030
                self.count.n_accepted += 1;
88
1453030
                return Some(item);
89
947
            } else {
90
947
                self.count.n_rejected += 1;
91
947
            }
92
        }
93
38534
        None
94
1491564
    }
95
}
96

            
97
impl FilterCount {
98
    /// Return a wrapper that can be displayed as the fraction of rejected items.
99
    ///
100
    /// # Example
101
    ///
102
    /// ```
103
    /// # use tor_basic_utils::iter::{IteratorExt, FilterCount};
104
    /// let mut count = FilterCount::default();
105
    /// let sum_of_evens : u32 = (1..=10)
106
    ///     .filter_cnt(&mut count, |x| *x % 2 == 0)
107
    ///     .sum();
108
    /// assert_eq!(format!("Rejected {} as odd", count.display_frac_rejected()),
109
    ///     "Rejected 5/10 as odd".to_string());
110
    /// ```
111
1548
    pub fn display_frac_rejected(&self) -> DisplayFracRejected<'_> {
112
1548
        DisplayFracRejected(self)
113
1548
    }
114

            
115
    /// Count and return the provided boolean value.
116
    ///
117
    /// This is an alternative way to use `FilterCount` when you have to provide
118
    /// a function that takes a predicate rather than a member of an iterator
119
    /// chain.
120
    ///
121
    /// # Example
122
    ///
123
    /// ```
124
    /// # use tor_basic_utils::iter::FilterCount;
125
    /// let mut count = FilterCount::default();
126
    /// let mut emoji = "Hello 🙂 World 🌏!".to_string();
127
    /// emoji.retain(|ch| count.count(!ch.is_ascii()));
128
    /// assert_eq!(emoji, "🙂🌏");
129
    /// assert_eq!(count, FilterCount { n_accepted: 2, n_rejected: 14});
130
    /// ```
131
49032406
    pub fn count(&mut self, accept: bool) -> bool {
132
49032406
        if accept {
133
32501689
            self.n_accepted += 1;
134
32501691
        } else {
135
16530717
            self.n_rejected += 1;
136
16530717
        }
137
49032406
        accept
138
49032406
    }
139
}
140

            
141
/// Return value from [`FilterCount::display_frac_rejected`].
142
#[derive(Debug, Clone)]
143
pub struct DisplayFracRejected<'a>(&'a FilterCount);
144

            
145
impl<'a> std::fmt::Display for DisplayFracRejected<'a> {
146
1548
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147
1548
        write!(
148
1548
            f,
149
1548
            "{}/{}",
150
1548
            self.0.n_rejected,
151
1548
            self.0.n_accepted + self.0.n_rejected
152
1548
        )
153
1548
    }
154
}
155

            
156
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;

    /// Exercise `filter_cnt`, both partially consumed and fully consumed.
    #[test]
    fn counting_filter() {
        let mut count = FilterCount::default();
        let v = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        // Stop at the first accepted item: only `1` (rejected) and `2`
        // (accepted) have been considered.
        let first_even = v
            .iter()
            .filter_cnt(&mut count, |val| **val % 2 == 0)
            .next()
            .unwrap();
        assert_eq!(*first_even, 2);
        assert_eq!(count.n_accepted, 1);
        assert_eq!(count.n_rejected, 1);

        // Reusing `count` is fine: `filter_cnt` resets it to zero first.
        let sum_even: usize = v.iter().filter_cnt(&mut count, |val| **val % 2 == 0).sum();
        assert_eq!(sum_even, 20);
        assert_eq!(count.n_accepted, 4);
        assert_eq!(count.n_rejected, 5);
    }

    /// Exercise `FilterCount::count` as a predicate wrapper.
    #[test]
    fn counting_with_predicates() {
        let mut count = FilterCount::default();
        let v = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        let first_even = v.iter().find(|val| count.count(**val % 2 == 0)).unwrap();
        assert_eq!(*first_even, 2);
        assert_eq!(count.n_accepted, 1);
        assert_eq!(count.n_rejected, 1);

        // `count()` never resets the tallies, so start from a fresh value.
        let mut count = FilterCount::default();
        let sum_even: usize = v.iter().filter(|val| count.count(**val % 2 == 0)).sum();
        assert_eq!(sum_even, 20);
        assert_eq!(count.n_accepted, 4);
        assert_eq!(count.n_rejected, 5);
    }

    /// Check the `rejected/total` rendering of `display_frac_rejected`.
    #[test]
    fn display_frac_rejected() {
        let count = FilterCount {
            n_accepted: 5,
            n_rejected: 5,
        };
        assert_eq!(count.display_frac_rejected().to_string(), "5/10");
        assert_eq!(
            FilterCount::default().display_frac_rejected().to_string(),
            "0/0"
        );
    }

    // NOTE(review): empty placeholder test; it asserts nothing. Confirm
    // whether it is still needed or can be removed.
    #[test]
    fn fooz() {}
}