1use tor_basic_utils::{n_key_list, n_key_set};
5use tor_llcrypto::pk::ed25519::Ed25519Identity;
6use tor_llcrypto::pk::rsa::RsaIdentity;
7
8use crate::{HasRelayIds, RelayIdRef};
9
n_key_list! {
    /// A collection of `H` values indexed by their relay identities.
    ///
    /// Two keys are declared for each element: its RSA identity (read through
    /// `rsa_identity()`) and its Ed25519 identity (read through
    /// `ed_identity()`).  The generated type supplies the per-key accessors
    /// that the inherent methods in this file dispatch to: `by_rsa`,
    /// `by_ed25519`, `remove_by_rsa` and `remove_by_ed25519`.
    ///
    /// Looking up a single identity yields an iterator rather than one value,
    /// so several elements may be stored under the same identity.
    //
    // NOTE(review): `(Option)` presumably marks a key that an individual
    // element is allowed to lack -- confirm against the `n_key_list!` docs.
    #[derive(Clone, Debug)]
    pub struct[H:HasRelayIds] ListByRelayIds[H] for H
    {
        (Option) rsa: RsaIdentity { rsa_identity() },
        (Option) ed25519: Ed25519Identity { ed_identity() },
    }
}
29
n_key_set! {
    /// A set of `H` values indexed by their relay identities.
    ///
    /// Two keys are declared for each element: its RSA identity (read through
    /// `rsa_identity()`) and its Ed25519 identity (read through
    /// `ed_identity()`).  The generated type supplies the per-key accessors
    /// that the inherent methods in this file dispatch to: `by_rsa`,
    /// `by_ed25519`, `remove_by_rsa`, `remove_by_ed25519`, `modify_by_rsa`
    /// and `modify_by_ed25519`.
    ///
    /// Looking up a single identity yields an `Option`, i.e. at most one
    /// element is reachable through any given identity.
    //
    // NOTE(review): `(Option)` presumably marks a key that an individual
    // element is allowed to lack -- confirm against the `n_key_set!` docs.
    #[derive(Clone, Debug)]
    pub struct[H:HasRelayIds] ByRelayIds[H] for H
    {
        (Option) rsa: RsaIdentity { rsa_identity() },
        (Option) ed25519: Ed25519Identity { ed_identity() },
    }
}
50
51impl<H: HasRelayIds> ByRelayIds<H> {
52 pub fn by_id<'a, T>(&self, key: T) -> Option<&H>
54 where
55 T: Into<RelayIdRef<'a>>,
56 {
57 match key.into() {
58 RelayIdRef::Ed25519(ed) => self.by_ed25519(ed),
59 RelayIdRef::Rsa(rsa) => self.by_rsa(rsa),
60 }
61 }
62
63 pub fn remove_by_id<'a, T>(&mut self, key: T) -> Option<H>
65 where
66 T: Into<RelayIdRef<'a>>,
67 {
68 match key.into() {
69 RelayIdRef::Ed25519(ed) => self.remove_by_ed25519(ed),
70 RelayIdRef::Rsa(rsa) => self.remove_by_rsa(rsa),
71 }
72 }
73
74 pub fn modify_by_id<'a, T, F>(&mut self, key: T, func: F) -> Vec<H>
78 where
79 T: Into<RelayIdRef<'a>>,
80 F: FnOnce(&mut H),
81 {
82 match key.into() {
83 RelayIdRef::Ed25519(ed) => self.modify_by_ed25519(ed, func),
84 RelayIdRef::Rsa(rsa) => self.modify_by_rsa(rsa, func),
85 }
86 }
87
88 pub fn by_all_ids<T>(&self, key: &T) -> Option<&H>
93 where
94 T: HasRelayIds,
95 {
96 let any_id = key.identities().next()?;
97 self.by_id(any_id)
98 .filter(|val| val.has_all_relay_ids_from(key))
99 }
100
101 pub fn modify_by_all_ids<T, F>(&mut self, key: &T, func: F) -> Vec<H>
106 where
107 T: HasRelayIds,
108 F: FnOnce(&mut H),
109 {
110 let any_id = match key.identities().next() {
111 Some(id) => id,
112 None => return Vec::new(),
113 };
114 self.modify_by_id(any_id, |val| {
115 if val.has_all_relay_ids_from(key) {
116 func(val);
117 }
118 })
119 }
120
121 pub fn remove_exact<T>(&mut self, key: &T) -> Option<H>
124 where
125 T: HasRelayIds,
126 {
127 let any_id = key.identities().next()?;
128 if self
129 .by_id(any_id)
130 .filter(|ent| ent.same_relay_ids(key))
131 .is_some()
132 {
133 self.remove_by_id(any_id)
134 } else {
135 None
136 }
137 }
138
139 pub fn remove_by_all_ids<T>(&mut self, key: &T) -> Option<H>
143 where
144 T: HasRelayIds,
145 {
146 let any_id = key.identities().next()?;
147 if self
148 .by_id(any_id)
149 .filter(|ent| ent.has_all_relay_ids_from(key))
150 .is_some()
151 {
152 self.remove_by_id(any_id)
153 } else {
154 None
155 }
156 }
157
158 pub fn all_overlapping<T>(&self, key: &T) -> Vec<&H>
163 where
164 T: HasRelayIds,
165 {
166 use by_address::ByAddress;
167 use std::collections::HashSet;
168
169 let mut items: HashSet<ByAddress<&H>> = HashSet::new();
170
171 for ident in key.identities() {
172 if let Some(found) = self.by_id(ident) {
173 items.insert(ByAddress(found));
174 }
175 }
176
177 items.into_iter().map(|by_addr| by_addr.0).collect()
178 }
179}
180
181impl<H: HasRelayIds> ListByRelayIds<H> {
182 pub fn by_id<'a, T>(&self, key: T) -> ListByRelayIdsIter<H>
184 where
185 T: Into<RelayIdRef<'a>>,
186 {
187 match key.into() {
188 RelayIdRef::Ed25519(ed) => self.by_ed25519(ed),
189 RelayIdRef::Rsa(rsa) => self.by_rsa(rsa),
190 }
191 }
192
193 pub fn by_all_ids<'a>(&'a self, key: &'a impl HasRelayIds) -> impl Iterator<Item = &'a H> + 'a {
197 key.identities()
198 .next()
199 .map_or_else(Default::default, |id| self.by_id(id))
200 .filter(|val| val.has_all_relay_ids_from(key))
201 }
202
203 pub fn all_overlapping<T>(&self, key: &T) -> Vec<&H>
208 where
209 T: HasRelayIds,
210 {
211 use by_address::ByAddress;
212 use std::collections::HashSet;
213
214 let mut items: HashSet<ByAddress<&H>> = HashSet::new();
215
216 for ident in key.identities() {
217 for found in self.by_id(ident) {
218 items.insert(ByAddress(found));
219 }
220 }
221
222 items.into_iter().map(|by_addr| by_addr.0).collect()
223 }
224
225 pub fn all_subset<T>(&self, key: &T) -> Vec<&H>
231 where
232 T: HasRelayIds,
233 {
234 use by_address::ByAddress;
235 use std::collections::HashSet;
236
237 let mut items: HashSet<ByAddress<&H>> = HashSet::new();
238
239 for ident in key.identities() {
240 for found in self.by_id(ident) {
241 if key.has_all_relay_ids_from(found) {
243 items.insert(ByAddress(found));
244 }
245 }
246 }
247
248 items.into_iter().map(|by_addr| by_addr.0).collect()
249 }
250
251 pub fn remove_by_id<'a, T>(&mut self, key: T, filter: impl FnMut(&H) -> bool) -> Vec<H>
253 where
254 T: Into<RelayIdRef<'a>>,
255 {
256 match key.into() {
257 RelayIdRef::Ed25519(ed) => self.remove_by_ed25519(ed, filter),
258 RelayIdRef::Rsa(rsa) => self.remove_by_rsa(rsa, filter),
259 }
260 }
261
262 pub fn remove_exact<T>(&mut self, key: &T) -> Vec<H>
265 where
266 T: HasRelayIds,
267 {
268 let Some(id) = key.identities().next() else {
269 return Vec::new();
270 };
271
272 self.remove_by_id(id, |val| val.same_relay_ids(key))
273 }
274
275 pub fn remove_by_all_ids<T>(&mut self, key: &T) -> Vec<H>
279 where
280 T: HasRelayIds,
281 {
282 let Some(id) = key.identities().next() else {
283 return Vec::new();
284 };
285
286 self.remove_by_id(id, |val| val.has_all_relay_ids_from(key))
287 }
288}
289
290pub use tor_basic_utils::n_key_list::Error as ListByRelayIdsError;
291pub use tor_basic_utils::n_key_set::Error as ByRelayIdsError;
292
#[cfg(test)]
mod test {
    // Lint allowances for test code.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use crate::{RelayIds, RelayIdsBuilder};

    /// Collect `items` into a vector and return it in ascending order.
    fn sorted<T: Ord>(items: impl Iterator<Item = T>) -> Vec<T> {
        let mut out = items.collect::<Vec<_>>();
        out.sort();
        out
    }

    /// Build a `RelayIds` holding whichever of the two identities is given.
    fn ids(
        rsa: impl Into<Option<RsaIdentity>>,
        ed: impl Into<Option<Ed25519Identity>>,
    ) -> RelayIds {
        let mut builder = RelayIdsBuilder::default();
        if let Some(rsa) = rsa.into() {
            builder.rsa_identity(rsa);
        }
        if let Some(ed) = ed.into() {
            builder.ed_identity(ed);
        }
        builder.build().unwrap()
    }

    /// Lookups by one identity, by all identities, and by overlap, on both
    /// the set and the list.
    #[test]
    #[allow(clippy::cognitive_complexity)]
    fn lookup() {
        let rsa1: RsaIdentity = (*b"12345678901234567890").into();
        let rsa2: RsaIdentity = (*b"abcefghijklmnopqrstu").into();
        let rsa3: RsaIdentity = (*b"abcefghijklmnopQRSTU").into();
        let ed1: Ed25519Identity = (*b"12345678901234567890123456789012").into();
        let ed2: Ed25519Identity = (*b"abcefghijklmnopqrstuvwxyzABCDEFG").into();
        let ed3: Ed25519Identity = (*b"abcefghijklmnopqrstuvwxyz1234567").into();

        let keys1 = ids(rsa1, ed1);
        let keys2 = ids(rsa2, ed2);

        let mut set = ByRelayIds::new();
        let mut list = ListByRelayIds::new();
        for k in [&keys1, &keys2] {
            set.insert(k.clone());
            list.insert(k.clone());
        }

        // Stored elements are found through each of their identities, through
        // their full identity set, and as their own only overlap.
        for (rsa, ed, expected) in [(rsa1, ed1, &keys1), (rsa2, ed2, &keys2)] {
            assert_eq!(set.by_id(&rsa), Some(expected));
            assert_eq!(set.by_id(&ed), Some(expected));
            assert_eq!(sorted(list.by_id(&rsa)), [expected]);
            assert_eq!(sorted(list.by_id(&ed)), [expected]);

            assert_eq!(set.by_all_ids(expected), Some(expected));
            assert_eq!(sorted(list.by_all_ids(expected)), [expected]);

            assert_eq!(set.all_overlapping(expected), vec![expected]);
            assert_eq!(list.all_overlapping(expected), vec![expected]);
        }

        // Identities that were never inserted find nothing.
        assert_eq!(set.by_id(&rsa3), None);
        assert_eq!(set.by_id(&ed3), None);
        assert_eq!(list.by_id(&rsa3).len(), 0);
        assert_eq!(list.by_id(&ed3).len(), 0);

        // A key without identities matches nothing.
        assert_eq!(set.by_all_ids(&RelayIds::empty()), None);
        assert!(sorted(list.by_all_ids(&RelayIds::empty())).is_empty());

        // A key with only one identity matches the element that carries it.
        let rsa1_only = ids(rsa1, None);
        assert_eq!(set.by_all_ids(&rsa1_only), Some(&keys1));
        assert_eq!(sorted(list.by_all_ids(&rsa1_only)), [&keys1]);

        // A key mixing identities of two elements matches neither in full,
        // but overlaps with both.
        let mixed = ids(rsa1, ed2);
        assert_eq!(set.by_all_ids(&mixed), None);
        assert!(sorted(list.by_all_ids(&mixed)).is_empty());
        for answer in [set.all_overlapping(&mixed), list.all_overlapping(&mixed)] {
            assert_eq!(answer.len(), 2);
            assert!(answer.contains(&&keys1));
            assert!(answer.contains(&&keys2));
        }

        let rsa2_only = ids(rsa2, None);
        assert_eq!(set.all_overlapping(&rsa2_only), vec![&keys2]);
        assert_eq!(list.all_overlapping(&rsa2_only), vec![&keys2]);

        let rsa3_only = ids(rsa3, None);
        assert!(set.all_overlapping(&rsa3_only).is_empty());
        assert!(list.all_overlapping(&rsa3_only).is_empty());
    }

    /// `remove_exact` needs an identical identity set, while
    /// `remove_by_all_ids` only needs the key's identities to be present.
    #[test]
    #[allow(clippy::cognitive_complexity)]
    fn remove_exact() {
        let rsa1: RsaIdentity = (*b"12345678901234567890").into();
        let rsa2: RsaIdentity = (*b"abcefghijklmnopqrstu").into();
        let ed1: Ed25519Identity = (*b"12345678901234567890123456789012").into();
        let ed2: Ed25519Identity = (*b"abcefghijklmnopqrstuvwxyzABCDEFG").into();

        let keys1 = ids(rsa1, ed1);
        let keys2 = ids(rsa2, ed2);

        let mut set = ByRelayIds::new();
        let mut list = ListByRelayIds::new();
        for k in [&keys1, &keys2] {
            set.insert(k.clone());
            list.insert(k.clone());
        }
        assert_eq!(set.len(), 2);
        assert_eq!(list.len(), 2);

        // An exact key removes exactly its element.
        assert_eq!(set.remove_exact(&keys1), Some(keys1.clone()));
        assert_eq!(set.len(), 1);
        assert_eq!(list.remove_exact(&keys1), vec![keys1.clone()]);
        assert_eq!(list.len(), 1);

        // A key with only some of an element's identities is not exact.
        let search = ids(None, ed2);
        assert_eq!(set.remove_exact(&search), None);
        assert_eq!(set.len(), 1);
        assert_eq!(list.remove_exact(&search), vec![]);
        assert_eq!(list.len(), 1);

        // A key carrying an identity the element lacks removes nothing.
        let no_match = ids(rsa1, ed2);
        assert_eq!(set.remove_by_all_ids(&no_match), None);
        assert_eq!(set.len(), 1);
        assert_eq!(list.remove_by_all_ids(&no_match), vec![]);
        assert_eq!(list.len(), 1);

        // The partial key is enough for `remove_by_all_ids`.
        assert_eq!(set.remove_by_all_ids(&search), Some(keys2.clone()));
        assert!(set.is_empty());
        assert_eq!(list.remove_by_all_ids(&search), vec![keys2.clone()]);
        assert!(list.is_empty());
    }

    /// `all_subset` returns only elements whose identities are all in the key.
    #[test]
    #[allow(clippy::cognitive_complexity)]
    fn all_subset() {
        let rsa1: RsaIdentity = (*b"12345678901234567890").into();
        let rsa2: RsaIdentity = (*b"abcefghijklmnopqrstu").into();
        let ed1: Ed25519Identity = (*b"12345678901234567890123456789012").into();

        // One element with both identities, one with an RSA identity only.
        let keys1 = ids(rsa1, ed1);
        let keys2 = ids(rsa2, None);

        let mut list = ListByRelayIds::new();
        for k in [&keys1, &keys2] {
            list.insert(k.clone());
        }

        assert_eq!(list.all_subset(&keys1), vec![&keys1]);
        assert_eq!(list.all_subset(&keys2), vec![&keys2]);

        // Either half of `keys1` alone is too small to contain `keys1`.
        let rsa1_only = ids(rsa1, None);
        assert!(list.all_subset(&rsa1_only).is_empty());

        let ed1_only = ids(None, ed1);
        assert!(list.all_subset(&ed1_only).is_empty());

        let rsa2_only = ids(rsa2, None);
        assert_eq!(list.all_subset(&rsa2_only), vec![&keys2]);

        // An extra, unrelated identity in the key does not exclude `keys2`.
        let rsa2_and_ed1 = ids(rsa2, ed1);
        assert_eq!(list.all_subset(&rsa2_and_ed1), vec![&keys2]);
    }

    /// Exercise the list with several elements sharing identities.
    #[test]
    #[allow(clippy::cognitive_complexity)]
    fn list_by_relay_ids() {
        /// A labelled value together with the relay identities it reports.
        #[derive(Clone, Debug)]
        struct ErsatzChannel<T> {
            val: T,
            ids: RelayIds,
        }

        impl<T> ErsatzChannel<T> {
            fn new(val: T, ids: RelayIds) -> Self {
                Self { val, ids }
            }
        }

        impl<T> HasRelayIds for ErsatzChannel<T> {
            fn identity(&self, key_type: crate::RelayIdType) -> Option<RelayIdRef<'_>> {
                self.ids.identity(key_type)
            }
        }

        /// Sorted labels of the given channels.
        fn names<'c>(
            chans: impl IntoIterator<Item = &'c ErsatzChannel<&'static str>>,
        ) -> Vec<&'static str> {
            sorted(chans.into_iter().map(|c| c.val))
        }

        let rsa_a: RsaIdentity = (*b"12345678901234567890").into();
        let ed_a: Ed25519Identity = (*b"12345678901234567890123456789012").into();

        let ed_b: Ed25519Identity = (*b"abcefghijklmnopqrstuvwxyzABCDEFG").into();
        let rsa_b: RsaIdentity = (*b"abcefghijklmnopqrstu").into();

        // "channel-invalid" pairs the RSA identity of "a" with the Ed25519
        // identity of "b".
        let channels = [
            ErsatzChannel::new("channel-a-all", ids(rsa_a, ed_a)),
            ErsatzChannel::new("channel-a-rsa-only-1", ids(rsa_a, None)),
            ErsatzChannel::new("channel-a-rsa-only-2", ids(rsa_a, None)),
            ErsatzChannel::new("channel-a-ed-only", ids(None, ed_a)),
            ErsatzChannel::new("channel-b-all", ids(rsa_b, ed_b)),
            ErsatzChannel::new("channel-invalid", ids(rsa_a, ed_b)),
        ];

        let mut list = ListByRelayIds::new();
        for chan in channels {
            list.insert(chan);
        }

        // Lookups by a single identity.
        assert_eq!(
            names(list.by_id(&rsa_a)),
            [
                "channel-a-all",
                "channel-a-rsa-only-1",
                "channel-a-rsa-only-2",
                "channel-invalid",
            ],
        );
        assert_eq!(
            names(list.by_id(&ed_a)),
            ["channel-a-all", "channel-a-ed-only"],
        );
        assert_eq!(names(list.by_id(&rsa_b)), ["channel-b-all"]);
        assert_eq!(
            names(list.by_id(&ed_b)),
            ["channel-b-all", "channel-invalid"],
        );

        // Lookups requiring every identity of the key.
        assert_eq!(names(list.by_all_ids(&ids(rsa_a, ed_a))), ["channel-a-all"]);
        assert_eq!(names(list.by_all_ids(&ids(rsa_b, ed_b))), ["channel-b-all"]);

        // Anything sharing at least one identity with the key.
        assert_eq!(
            names(list.all_overlapping(&ids(rsa_a, ed_a))),
            [
                "channel-a-all",
                "channel-a-ed-only",
                "channel-a-rsa-only-1",
                "channel-a-rsa-only-2",
                "channel-invalid",
            ],
        );

        // Anything whose identities all appear in the key.
        assert_eq!(
            names(list.all_subset(&ids(rsa_a, ed_a))),
            [
                "channel-a-all",
                "channel-a-ed-only",
                "channel-a-rsa-only-1",
                "channel-a-rsa-only-2",
            ],
        );

        // A key without identities matches nothing.
        assert_eq!(list.by_all_ids(&ids(None, None)).count(), 0);
        assert!(list.all_overlapping(&ids(None, None)).is_empty());
        assert!(list.all_subset(&ids(None, None)).is_empty());

        // With a single-identity key, overlap equals a plain lookup.
        assert_eq!(
            names(list.all_overlapping(&ids(rsa_a, None))),
            names(list.by_id(&rsa_a)),
        );
        assert_eq!(
            names(list.all_overlapping(&ids(None, ed_b))),
            names(list.by_id(&ed_b)),
        );

        // `by_id` agrees with the per-key accessors.
        assert_eq!(names(list.by_id(&rsa_a)), names(list.by_rsa(&rsa_a)));
        assert_eq!(names(list.by_id(&ed_a)), names(list.by_ed25519(&ed_a)));

        // Each removal scenario below works on its own copy of the list.
        {
            let mut list = list.clone();
            assert_eq!(
                names(list.remove_exact(&ids(rsa_a, ed_a)).iter()),
                ["channel-a-all"],
            );
            assert_eq!(list.by_all_ids(&ids(rsa_a, ed_a)).count(), 0);
        }

        {
            let mut list = list.clone();
            assert_eq!(
                names(list.remove_exact(&ids(rsa_a, None)).iter()),
                ["channel-a-rsa-only-1", "channel-a-rsa-only-2"],
            );
            assert_eq!(
                names(list.by_all_ids(&ids(rsa_a, None))),
                ["channel-a-all", "channel-invalid"],
            );
        }

        {
            let mut list = list.clone();
            assert_eq!(
                names(list.remove_by_all_ids(&ids(rsa_a, None)).iter()),
                [
                    "channel-a-all",
                    "channel-a-rsa-only-1",
                    "channel-a-rsa-only-2",
                    "channel-invalid",
                ],
            );
            assert_eq!(list.by_all_ids(&ids(rsa_a, None)).count(), 0);
        }
    }
}