1use crate::client::circuit::padding::{PaddingController, QueuedCellPaddingInfo};
7use crate::{Error, Result};
8use tor_basic_utils::RngExt;
9use tor_cell::chancell::CircId;
10
11use crate::client::circuit::halfcirc::HalfCirc;
12use crate::client::circuit::{CircuitRxSender, celltypes::CreateResponse};
13
14use oneshot_fused_workaround as oneshot;
15
16use rand::Rng;
17use rand::distr::Distribution;
18use std::collections::{HashMap, hash_map::Entry};
19use std::ops::{Deref, DerefMut};
20
21#[derive(Copy, Clone)]
26pub(super) enum CircIdRange {
27 #[allow(dead_code)] Low,
30 High,
32 }
36
37impl rand::distr::Distribution<CircId> for CircIdRange {
38 fn sample<R: Rng + ?Sized>(&self, mut rng: &mut R) -> CircId {
40 let midpoint = 0x8000_0000_u32;
41 let v = match self {
42 CircIdRange::Low => rng.gen_range_checked(1..midpoint),
44 CircIdRange::High => rng.gen_range_checked(midpoint..=u32::MAX),
45 };
46 let v = v.expect("Unexpected empty range passed to gen_range_checked");
47 CircId::new(v).expect("Unexpected zero value")
48 }
49}
50
51#[derive(Debug)]
55pub(super) enum CircEnt {
56 Opening {
65 create_response_sender: oneshot::Sender<CreateResponse>,
67 cell_sender: CircuitRxSender,
70 padding_ctrl: PaddingController,
72 },
73
74 Open {
76 cell_sender: CircuitRxSender,
79 padding_ctrl: PaddingController,
81 },
82
83 DestroySent(HalfCirc),
86}
87
88pub(super) struct MutCircEnt<'a> {
94 value: &'a mut CircEnt,
96 open_count: &'a mut usize,
99 was_open: bool,
101}
102
103impl<'a> Drop for MutCircEnt<'a> {
104 fn drop(&mut self) {
105 let is_open = !matches!(self.value, CircEnt::DestroySent(_));
106 match (self.was_open, is_open) {
107 (false, true) => *self.open_count = self.open_count.saturating_add(1),
108 (true, false) => *self.open_count = self.open_count.saturating_sub(1),
109 (_, _) => (),
110 };
111 }
112}
113
114impl<'a> Deref for MutCircEnt<'a> {
115 type Target = CircEnt;
116 fn deref(&self) -> &Self::Target {
117 self.value
118 }
119}
120
121impl<'a> DerefMut for MutCircEnt<'a> {
122 fn deref_mut(&mut self) -> &mut Self::Target {
123 self.value
124 }
125}
126
127pub(super) struct CircMap {
129 m: HashMap<CircId, CircEnt>,
131 range: CircIdRange,
133 open_count: usize,
135}
136
137impl CircMap {
138 pub(super) fn new(idrange: CircIdRange) -> Self {
140 CircMap {
141 m: HashMap::new(),
142 range: idrange,
143 open_count: 0,
144 }
145 }
146
147 pub(super) fn add_ent<R: Rng>(
152 &mut self,
153 rng: &mut R,
154 createdsink: oneshot::Sender<CreateResponse>,
155 sink: CircuitRxSender,
156 padding_ctrl: PaddingController,
157 ) -> Result<CircId> {
158 const N_ATTEMPTS: usize = 16;
163 let iter = self.range.sample_iter(rng).take(N_ATTEMPTS);
164 let circ_ent = CircEnt::Opening {
165 create_response_sender: createdsink,
166 cell_sender: sink,
167 padding_ctrl,
168 };
169 for id in iter {
170 let ent = self.m.entry(id);
171 if let Entry::Vacant(_) = &ent {
172 ent.or_insert(circ_ent);
173 self.open_count += 1;
174 return Ok(id);
175 }
176 }
177 Err(Error::IdRangeFull)
178 }
179
180 #[cfg(test)]
183 pub(super) fn put_unchecked(&mut self, id: CircId, ent: CircEnt) {
184 self.m.insert(id, ent);
185 }
186
187 pub(super) fn get_mut(&mut self, id: CircId) -> Option<MutCircEnt> {
189 let open_count = &mut self.open_count;
190 self.m.get_mut(&id).map(move |ent| MutCircEnt {
191 open_count,
192 was_open: !matches!(ent, CircEnt::DestroySent(_)),
193 value: ent,
194 })
195 }
196
197 pub(super) fn note_cell_flushed(&mut self, id: CircId, info: QueuedCellPaddingInfo) {
199 let padding_ctrl = match self.m.get(&id) {
200 Some(CircEnt::Opening { padding_ctrl, .. }) => padding_ctrl,
201 Some(CircEnt::Open { padding_ctrl, .. }) => padding_ctrl,
202 Some(CircEnt::DestroySent(..)) | None => return,
203 };
204 padding_ctrl.flushed_relay_cell(info);
205 }
206
207 pub(super) fn advance_from_opening(
210 &mut self,
211 id: CircId,
212 ) -> Result<oneshot::Sender<CreateResponse>> {
213 let ok = matches!(self.m.get(&id), Some(CircEnt::Opening { .. }));
218 if ok {
219 if let Some(CircEnt::Opening {
220 create_response_sender: oneshot,
221 cell_sender: sink,
222 padding_ctrl,
223 }) = self.m.remove(&id)
224 {
225 self.m.insert(
226 id,
227 CircEnt::Open {
228 cell_sender: sink,
229 padding_ctrl,
230 },
231 );
232 Ok(oneshot)
233 } else {
234 panic!("internal error: inconsistent circuit state");
235 }
236 } else {
237 Err(Error::ChanProto(
238 "Unexpected CREATED* cell not on opening circuit".into(),
239 ))
240 }
241 }
242
243 pub(super) fn destroy_sent(&mut self, id: CircId, hs: HalfCirc) {
247 if let Some(replaced) = self.m.insert(id, CircEnt::DestroySent(hs)) {
248 if !matches!(replaced, CircEnt::DestroySent(_)) {
249 self.open_count = self.open_count.saturating_sub(1);
251 }
252 }
253 }
254
255 pub(super) fn remove(&mut self, id: CircId) -> Option<CircEnt> {
257 self.m.remove(&id).map(|removed| {
258 if !matches!(removed, CircEnt::DestroySent(_)) {
259 self.open_count = self.open_count.saturating_sub(1);
260 }
261 removed
262 })
263 }
264
265 pub(super) fn open_ent_count(&self) -> usize {
267 self.open_count
268 }
269
270 }
273
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use crate::{client::circuit::padding::new_padding, fake_mpsc};
    use tor_basic_utils::test_rng::testing_rng;
    use tor_rtcompat::DynTimeProvider;

    #[test]
    fn circmap_basics() {
        // Number of circuits allocated in each map.
        const N_CIRCS: usize = 128;

        let mut low_map = CircMap::new(CircIdRange::Low);
        let mut high_map = CircMap::new(CircIdRange::High);
        let mut low_ids: Vec<CircId> = Vec::new();
        let mut high_ids: Vec<CircId> = Vec::new();
        let mut rng = testing_rng();
        tor_rtcompat::test_with_one_runtime!(|runtime| async {
            let (padding_ctrl, _padding_stream) = new_padding(DynTimeProvider::new(runtime));
            let unknown_id = CircId::new(77).unwrap();

            // A fresh map has no entries.
            assert!(low_map.get_mut(unknown_id).is_none());

            for _ in 0..N_CIRCS {
                // One allocation from the lower half of the ID space...
                let (create_tx, _) = oneshot::channel();
                let (cell_tx, _) = fake_mpsc(8);
                let low_id = low_map
                    .add_ent(&mut rng, create_tx, cell_tx, padding_ctrl.clone())
                    .unwrap();
                assert!((1..0x8000_0000_u32).contains(&u32::from(low_id)));
                assert!(!low_ids.contains(&low_id));
                low_ids.push(low_id);
                assert!(matches!(
                    *low_map.get_mut(low_id).unwrap(),
                    CircEnt::Opening { .. }
                ));

                // ...and one from the upper half.
                let (create_tx, _) = oneshot::channel();
                let (cell_tx, _) = fake_mpsc(8);
                let high_id = high_map
                    .add_ent(&mut rng, create_tx, cell_tx, padding_ctrl.clone())
                    .unwrap();
                assert!(u32::from(high_id) >= 0x8000_0000);
                assert!(!high_ids.contains(&high_id));
                high_ids.push(high_id);
            }

            // Every freshly allocated circuit counts as open.
            assert_eq!(N_CIRCS, low_map.open_ent_count());
            assert_eq!(N_CIRCS, high_map.open_ent_count());

            // Removing an open entry deletes it and lowers the count.
            let first_low = low_ids[0];
            assert!(low_map.get_mut(first_low).is_some());
            low_map.remove(first_low);
            assert!(low_map.get_mut(first_low).is_none());
            assert_eq!(N_CIRCS - 1, low_map.open_ent_count());

            // Recording a DESTROY on a (presumably) unallocated ID leaves the
            // count untouched.
            low_map.destroy_sent(CircId::new(256).unwrap(), HalfCirc::new(1));
            assert_eq!(N_CIRCS - 1, low_map.open_ent_count());

            // Opening -> Open succeeds exactly once.
            let first_high = high_ids[0];
            assert!(high_map.get_mut(first_high).is_some());
            assert!(matches!(
                *high_map.get_mut(first_high).unwrap(),
                CircEnt::Opening { .. }
            ));
            assert!(high_map.advance_from_opening(first_high).is_ok());
            assert!(matches!(
                *high_map.get_mut(first_high).unwrap(),
                CircEnt::Open { .. }
            ));
            assert!(high_map.advance_from_opening(first_high).is_err());

            // An ID that was never allocated cannot be advanced.
            assert!(high_map.advance_from_opening(unknown_id).is_err());
        });
    }
}