1
//! Implement [`RelayIdSet`], a set of [`RelayId`]s.
2

            
3
use std::collections::HashSet;
4

            
5
use serde::de::Visitor;
6
use tor_llcrypto::pk::{ed25519::Ed25519Identity, rsa::RsaIdentity};
7

            
8
use crate::{RelayId, RelayIdRef};
9

            
10
/// A set of relay identities, backed by `HashSet`.
11
///
12
/// # Note
13
///
14
/// I'd rather use `HashSet` entirely, but that doesn't let us index by
15
/// RelayIdRef.
16
#[derive(Clone, Debug, Default, Eq, PartialEq)]
17
pub struct RelayIdSet {
18
    /// The Ed25519 members of this set.
19
    ed25519: HashSet<Ed25519Identity>,
20
    /// The RSA members of this set.
21
    rsa: HashSet<RsaIdentity>,
22
}
23

            
24
impl RelayIdSet {
25
    /// Construct a new empty RelayIdSet.
26
2186445
    pub fn new() -> Self {
27
2186445
        Self::default()
28
2186445
    }
29

            
30
    /// Insert `key` into this set.  
31
    ///
32
    /// Return true if it was not already there.
33
249830
    pub fn insert<T: Into<RelayId>>(&mut self, key: T) -> bool {
34
249830
        let key: RelayId = key.into();
35
249830
        match key {
36
124464
            RelayId::Ed25519(key) => self.ed25519.insert(key),
37
125366
            RelayId::Rsa(key) => self.rsa.insert(key),
38
        }
39
249830
    }
40

            
41
    /// Remove `key` from the set.
42
    ///
43
    /// Return true if `key` was present.
44
8
    pub fn remove<'a, T: Into<RelayIdRef<'a>>>(&mut self, key: T) -> bool {
45
8
        let key: RelayIdRef<'a> = key.into();
46
8
        match key {
47
4
            RelayIdRef::Ed25519(key) => self.ed25519.remove(key),
48
4
            RelayIdRef::Rsa(key) => self.rsa.remove(key),
49
        }
50
8
    }
51

            
52
    /// Return true if `key` is a member of this set.
53
40764868
    pub fn contains<'a, T: Into<RelayIdRef<'a>>>(&self, key: T) -> bool {
54
40764868
        let key: RelayIdRef<'a> = key.into();
55
40764868
        match key {
56
20435568
            RelayIdRef::Ed25519(key) => self.ed25519.contains(key),
57
20329300
            RelayIdRef::Rsa(key) => self.rsa.contains(key),
58
        }
59
40764868
    }
60

            
61
    /// Return an iterator over the members of this set.
62
    ///
63
    /// The ordering of the iterator is undefined; do not rely on it.
64
1073380
    pub fn iter(&self) -> impl Iterator<Item = RelayIdRef<'_>> {
65
1073380
        self.ed25519
66
1073380
            .iter()
67
1073584
            .map(|id| id.into())
68
1073604
            .chain(self.rsa.iter().map(|id| id.into()))
69
1073380
    }
70

            
71
    /// Return the number of keys in this set.
72
94
    pub fn len(&self) -> usize {
73
94
        self.ed25519.len() + self.rsa.len()
74
94
    }
75

            
76
    /// Return true if there are not keys in this set.
77
264
    pub fn is_empty(&self) -> bool {
78
264
        self.ed25519.is_empty() && self.rsa.is_empty()
79
264
    }
80
}
81

            
82
impl<ID: Into<RelayId>> Extend<ID> for RelayIdSet {
    /// Add every identity yielded by `iter` to this set.
    ///
    /// Identities that are already members are left as they are; whether each
    /// item was newly added is not reported.
    fn extend<T: IntoIterator<Item = ID>>(&mut self, iter: T) {
        iter.into_iter().for_each(|id| {
            let _newly_added = self.insert(id);
        });
    }
}
89

            
90
impl FromIterator<RelayId> for RelayIdSet {
91
23931
    fn from_iter<T: IntoIterator<Item = RelayId>>(iter: T) -> Self {
92
23931
        let mut set = RelayIdSet::new();
93
23931
        set.extend(iter);
94
23931
        set
95
23931
    }
96
}
97

            
98
impl serde::Serialize for RelayIdSet {
99
6
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
100
6
    where
101
6
        S: serde::Serializer,
102
6
    {
103
6
        serializer.collect_seq(self.iter())
104
6
    }
105
}
106

            
107
impl<'de> serde::Deserialize<'de> for RelayIdSet {
108
12
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
109
12
    where
110
12
        D: serde::Deserializer<'de>,
111
12
    {
112
        /// A serde visitor to deserialize a sequence of RelayIds into a
113
        /// RelayIdSet.
114
        struct IdSetVisitor;
115
        impl<'de> Visitor<'de> for IdSetVisitor {
116
            type Value = RelayIdSet;
117

            
118
            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
119
                write!(f, "a list of relay identities")
120
            }
121

            
122
12
            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
123
12
            where
124
12
                A: serde::de::SeqAccess<'de>,
125
12
            {
126
12
                let mut set = RelayIdSet::new();
127
20
                while let Some(key) = seq.next_element::<RelayId>()? {
128
8
                    set.insert(key);
129
8
                }
130
12
                Ok(set)
131
12
            }
132
        }
133
12
        deserializer.deserialize_seq(IdSetVisitor)
134
12
    }
135
}
136

            
137
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use hex_literal::hex;
    use serde_test::{assert_tokens, Token};

    /// Exercise insert/contains/remove/iter/len/is_empty, plus the
    /// `FromIterator` and `Extend` impls, on a mix of RSA and Ed25519 keys.
    #[test]
    fn basic_usage() {
        #![allow(clippy::cognitive_complexity)]
        let rsa1 = RsaIdentity::from(hex!("42656c6f7665642c207768617420617265206e61"));
        let rsa2 = RsaIdentity::from(hex!("6d657320627574206169723f43686f6f73652074"));
        let rsa3 = RsaIdentity::from(hex!("686f752077686174657665722073756974732074"));

        let ed1 = Ed25519Identity::from(hex!(
            "6865206c696e653a43616c6c206d652053617070686f2c2063616c6c206d6520"
        ));
        let ed2 = Ed25519Identity::from(hex!(
            "43686c6f7269732c2043616c6c206d65204c616c6167652c206f7220446f7269"
        ));
        let ed3 = Ed25519Identity::from(hex!(
            "732c204f6e6c792c206f6e6c792c2063616c6c206d65207468696e652e000000"
        ));

        // A fresh set has nothing in it.
        let mut ids = RelayIdSet::new();
        assert!(ids.is_empty());
        assert_eq!(ids.len(), 0);

        // Two RSA keys and one Ed25519 key.
        ids.insert(rsa1);
        ids.insert(rsa2);
        ids.insert(ed1);

        assert!(!ids.is_empty());
        assert_eq!(ids.len(), 3);
        assert!(ids.contains(&rsa1));
        assert!(ids.contains(&rsa2));
        assert!(!ids.contains(&rsa3));
        assert!(ids.contains(&ed1));
        assert!(!ids.contains(&ed2));
        assert!(!ids.contains(&ed3));

        // Iteration yields exactly the inserted members, in some order.
        let seen: HashSet<_> = ids.iter().collect();
        assert_eq!(seen.len(), ids.len());
        for expected in [
            RelayIdRef::from(&rsa1),
            RelayIdRef::from(&rsa2),
            RelayIdRef::from(&ed1),
        ] {
            assert!(seen.contains(&expected));
        }

        // Removing an absent key reports false; removing a present key reports true.
        assert!(!ids.remove(&ed2));
        assert!(ids.remove(&ed1));
        assert!(!ids.remove(&rsa3));
        assert!(ids.remove(&rsa1));
        assert!(!ids.is_empty());
        assert_eq!(ids.len(), 1);
        assert!(!ids.contains(&ed1));
        assert!(!ids.contains(&rsa1));
        assert!(ids.contains(&rsa2));

        let remaining: Vec<_> = ids.iter().collect();
        assert_eq!(remaining, vec![RelayIdRef::from(&rsa2)]);

        // Rebuilding via FromIterator and via Extend gives equal sets.
        let collected: RelayIdSet = ids.iter().map(|id| id.to_owned()).collect();
        assert_eq!(ids, collected);

        let mut extended = RelayIdSet::new();
        extended.extend(ids.iter().map(|id| id.to_owned()));
        assert_eq!(collected, extended);
    }

    /// An empty set round-trips as an empty sequence.
    #[test]
    fn serde_empty() {
        let empty = RelayIdSet::new();

        assert_tokens(&empty, &[Token::Seq { len: Some(0) }, Token::SeqEnd]);
    }

    /// A lone RSA identity round-trips as a `$`-prefixed hex string.
    #[test]
    fn serde_singleton_rsa() {
        let rsa = RsaIdentity::from(hex!("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"));
        let mut ids = RelayIdSet::new();
        ids.insert(rsa);

        let expected = [
            Token::Seq { len: Some(1) },
            Token::Str("$aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"),
            Token::SeqEnd,
        ];
        assert_tokens(&ids, &expected);
    }

    /// A lone Ed25519 identity round-trips as an `ed25519:`-prefixed string.
    #[test]
    fn serde_singleton_ed25519() {
        let ed = Ed25519Identity::from(hex!(
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"
        ));
        let mut ids = RelayIdSet::new();
        ids.insert(ed);

        let expected = [
            Token::Seq { len: Some(1) },
            Token::String("ed25519:u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7u7s"),
            Token::SeqEnd,
        ];
        assert_tokens(&ids, &expected);
    }
}