1
//! Types and code to map circuit IDs to circuits.
2

            
3
// NOTE: This is a work in progress and I bet I'll refactor it a lot;
4
// it needs to stay opaque!
5

            
6
use crate::client::circuit::padding::{PaddingController, QueuedCellPaddingInfo};
7
use crate::{Error, Result};
8
use tor_basic_utils::RngExt;
9
use tor_cell::chancell::CircId;
10

            
11
use crate::client::circuit::halfcirc::HalfCirc;
12
use crate::client::circuit::{CircuitRxSender, celltypes::CreateResponse};
13

            
14
use oneshot_fused_workaround as oneshot;
15

            
16
use rand::Rng;
17
use rand::distr::Distribution;
18
use std::collections::{HashMap, hash_map::Entry};
19
use std::ops::{Deref, DerefMut};
20

            
21
/// Which group of circuit IDs are we allowed to allocate in this map?
22
///
23
/// If we initiated the channel, we use High circuit ids.  If we're the
24
/// responder, we use low circuit ids.
25
#[derive(Copy, Clone)]
26
pub(super) enum CircIdRange {
27
    /// Only use circuit IDs with the MSB cleared.
28
    #[allow(dead_code)] // Relays will need this.
29
    Low,
30
    /// Only use circuit IDs with the MSB set.
31
    High,
32
    // Historical note: There used to be an "All" range of circuit IDs
33
    // available to clients only.  We stopped using "All" when we moved to link
34
    // protocol version 4.
35
}
36

            
37
impl rand::distr::Distribution<CircId> for CircIdRange {
38
    /// Return a random circuit ID in the appropriate range.
39
524
    fn sample<R: Rng + ?Sized>(&self, mut rng: &mut R) -> CircId {
40
524
        let midpoint = 0x8000_0000_u32;
41
524
        let v = match self {
42
            // 0 is an invalid value
43
256
            CircIdRange::Low => rng.gen_range_checked(1..midpoint),
44
268
            CircIdRange::High => rng.gen_range_checked(midpoint..=u32::MAX),
45
        };
46
524
        let v = v.expect("Unexpected empty range passed to gen_range_checked");
47
524
        CircId::new(v).expect("Unexpected zero value")
48
524
    }
49
}
50

            
51
/// An entry in the circuit map.  Right now, we only have "here's the
52
/// way to send cells to a given circuit", but that's likely to
53
/// change.
54
#[derive(Debug)]
55
pub(super) enum CircEnt {
56
    /// A circuit that has not yet received a CREATED cell.
57
    ///
58
    /// For this circuit, the CREATED* cell or DESTROY cell gets sent
59
    /// to the oneshot sender to tell the corresponding
60
    /// PendingClientCirc that the handshake is done.
61
    ///
62
    /// Once that's done, the `CircuitRxSender` mpsc sender will be used to send subsequent
63
    /// cells to the circuit.
64
    Opening {
65
        /// The oneshot sender on which to report a create response
66
        create_response_sender: oneshot::Sender<CreateResponse>,
67
        /// A sink which should receive all the relay cells for this circuit
68
        /// from this channel
69
        cell_sender: CircuitRxSender,
70
        //// A padding controller we should use when reporting flushed cells.
71
        padding_ctrl: PaddingController,
72
    },
73

            
74
    /// A circuit that is open and can be given relay cells.
75
    Open {
76
        /// A sink which should receive all the relay cells for this circuit
77
        /// from this channel
78
        cell_sender: CircuitRxSender,
79
        //// A padding controller we should use when reporting flushed cells.
80
        padding_ctrl: PaddingController,
81
    },
82

            
83
    /// A circuit where we have sent a DESTROY, but the other end might
84
    /// not have gotten a DESTROY yet.
85
    DestroySent(HalfCirc),
86
}
87

            
88
/// An "smart pointer" that wraps an exclusive reference
89
/// of a `CircEnt`.
90
///
91
/// When being dropped, this object updates the open or opening entries
92
/// counter of the `CircMap`.
93
pub(super) struct MutCircEnt<'a> {
94
    /// An exclusive reference to the `CircEnt`.
95
    value: &'a mut CircEnt,
96
    /// An exclusive reference to the open or opening
97
    ///  entries counter.
98
    open_count: &'a mut usize,
99
    /// True if the entry was open or opening when borrowed.
100
    was_open: bool,
101
}
102

            
103
impl<'a> Drop for MutCircEnt<'a> {
104
624
    fn drop(&mut self) {
105
624
        let is_open = !matches!(self.value, CircEnt::DestroySent(_));
106
624
        match (self.was_open, is_open) {
107
            (false, true) => *self.open_count = self.open_count.saturating_add(1),
108
            (true, false) => *self.open_count = self.open_count.saturating_sub(1),
109
624
            (_, _) => (),
110
        };
111
624
    }
112
}
113

            
114
impl<'a> Deref for MutCircEnt<'a> {
115
    type Target = CircEnt;
116
284
    fn deref(&self) -> &Self::Target {
117
284
        self.value
118
284
    }
119
}
120

            
121
impl<'a> DerefMut for MutCircEnt<'a> {
122
336
    fn deref_mut(&mut self) -> &mut Self::Target {
123
336
        self.value
124
336
    }
125
}
126

            
127
/// A map from circuit IDs to circuit entries. Each channel has one.
128
pub(super) struct CircMap {
129
    /// Map from circuit IDs to entries
130
    m: HashMap<CircId, CircEnt>,
131
    /// Rule for allocating new circuit IDs.
132
    range: CircIdRange,
133
    /// Number of open or opening entry in this map.
134
    open_count: usize,
135
}
136

            
137
impl CircMap {
138
    /// Make a new empty CircMap
139
520
    pub(super) fn new(idrange: CircIdRange) -> Self {
140
520
        CircMap {
141
520
            m: HashMap::new(),
142
520
            range: idrange,
143
520
            open_count: 0,
144
520
        }
145
520
    }
146

            
147
    /// Add a new set of elements (corresponding to a PendingClientCirc)
148
    /// to this map.
149
    ///
150
    /// On success return the allocated circuit ID.
151
524
    pub(super) fn add_ent<R: Rng>(
152
524
        &mut self,
153
524
        rng: &mut R,
154
524
        createdsink: oneshot::Sender<CreateResponse>,
155
524
        sink: CircuitRxSender,
156
524
        padding_ctrl: PaddingController,
157
524
    ) -> Result<CircId> {
158
        /// How many times do we probe for a random circuit ID before
159
        /// we assume that the range is fully populated?
160
        ///
161
        /// TODO: C tor does 64, but that is probably overkill with 4-byte circuit IDs.
162
        const N_ATTEMPTS: usize = 16;
163
524
        let iter = self.range.sample_iter(rng).take(N_ATTEMPTS);
164
524
        let circ_ent = CircEnt::Opening {
165
524
            create_response_sender: createdsink,
166
524
            cell_sender: sink,
167
524
            padding_ctrl,
168
524
        };
169
524
        for id in iter {
170
524
            let ent = self.m.entry(id);
171
524
            if let Entry::Vacant(_) = &ent {
172
524
                ent.or_insert(circ_ent);
173
524
                self.open_count += 1;
174
524
                return Ok(id);
175
            }
176
        }
177
        Err(Error::IdRangeFull)
178
524
    }
179

            
180
    /// Testing only: install an entry in this circuit map without regard
181
    /// for consistency.
182
    #[cfg(test)]
183
72
    pub(super) fn put_unchecked(&mut self, id: CircId, ent: CircEnt) {
184
72
        self.m.insert(id, ent);
185
72
    }
186

            
187
    /// Return the entry for `id` in this map, if any.
188
652
    pub(super) fn get_mut(&mut self, id: CircId) -> Option<MutCircEnt> {
189
652
        let open_count = &mut self.open_count;
190
652
        self.m.get_mut(&id).map(move |ent| MutCircEnt {
191
624
            open_count,
192
624
            was_open: !matches!(ent, CircEnt::DestroySent(_)),
193
624
            value: ent,
194
624
        })
195
652
    }
196

            
197
    /// Inform the relevant circuit's padding subsystem that a given cell has been flushed.
198
4342
    pub(super) fn note_cell_flushed(&mut self, id: CircId, info: QueuedCellPaddingInfo) {
199
4342
        let padding_ctrl = match self.m.get(&id) {
200
            Some(CircEnt::Opening { padding_ctrl, .. }) => padding_ctrl,
201
            Some(CircEnt::Open { padding_ctrl, .. }) => padding_ctrl,
202
4342
            Some(CircEnt::DestroySent(..)) | None => return,
203
        };
204
        padding_ctrl.flushed_relay_cell(info);
205
4342
    }
206

            
207
    /// See whether 'id' is an opening circuit.  If so, mark it "open" and
208
    /// return a oneshot::Sender that is waiting for its create cell.
209
30
    pub(super) fn advance_from_opening(
210
30
        &mut self,
211
30
        id: CircId,
212
30
    ) -> Result<oneshot::Sender<CreateResponse>> {
213
        // TODO: there should be a better way to do
214
        // this. hash_map::Entry seems like it could be better, but
215
        // there seems to be no way to replace the object in-place as
216
        // a consuming function of itself.
217
30
        let ok = matches!(self.m.get(&id), Some(CircEnt::Opening { .. }));
218
30
        if ok {
219
            if let Some(CircEnt::Opening {
220
2
                create_response_sender: oneshot,
221
2
                cell_sender: sink,
222
2
                padding_ctrl,
223
2
            }) = self.m.remove(&id)
224
            {
225
2
                self.m.insert(
226
2
                    id,
227
2
                    CircEnt::Open {
228
2
                        cell_sender: sink,
229
2
                        padding_ctrl,
230
2
                    },
231
                );
232
2
                Ok(oneshot)
233
            } else {
234
                panic!("internal error: inconsistent circuit state");
235
            }
236
        } else {
237
28
            Err(Error::ChanProto(
238
28
                "Unexpected CREATED* cell not on opening circuit".into(),
239
28
            ))
240
        }
241
30
    }
242

            
243
    /// Called when we have sent a DESTROY on a circuit.  Configures
244
    /// a "HalfCirc" object to track how many cells we get on this
245
    /// circuit, and to prevent us from reusing it immediately.
246
106
    pub(super) fn destroy_sent(&mut self, id: CircId, hs: HalfCirc) {
247
106
        if let Some(replaced) = self.m.insert(id, CircEnt::DestroySent(hs)) {
248
12
            if !matches!(replaced, CircEnt::DestroySent(_)) {
249
12
                // replaced an Open/Opening entry with DestroySent
250
12
                self.open_count = self.open_count.saturating_sub(1);
251
12
            }
252
94
        }
253
106
    }
254

            
255
    /// Extract the value from this map with 'id' if any
256
50
    pub(super) fn remove(&mut self, id: CircId) -> Option<CircEnt> {
257
69
        self.m.remove(&id).map(|removed| {
258
38
            if !matches!(removed, CircEnt::DestroySent(_)) {
259
26
                self.open_count = self.open_count.saturating_sub(1);
260
26
            }
261
38
            removed
262
38
        })
263
50
    }
264

            
265
    /// Return the total number of open and opening entries in the map
266
172
    pub(super) fn open_ent_count(&self) -> usize {
267
172
        self.open_count
268
172
    }
269

            
270
    // TODO: Eventually if we want relay support, we'll need to support
271
    // circuit IDs chosen by somebody else. But for now, we don't need those.
272
}
273

            
274
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use crate::{client::circuit::padding::new_padding, fake_mpsc};
    use tor_basic_utils::test_rng::testing_rng;
    use tor_rtcompat::DynTimeProvider;

    #[test]
    fn circmap_basics() {
        let mut map_low = CircMap::new(CircIdRange::Low);
        let mut map_high = CircMap::new(CircIdRange::High);
        let mut ids_low: Vec<CircId> = Vec::new();
        let mut ids_high: Vec<CircId> = Vec::new();
        let mut rng = testing_rng();
        tor_rtcompat::test_with_one_runtime!(|runtime| async {
            let (padding_ctrl, _padding_stream) = new_padding(DynTimeProvider::new(runtime));

            assert!(map_low.get_mut(CircId::new(77).unwrap()).is_none());

            for _ in 0..128 {
                let (csnd, _) = oneshot::channel();
                let (snd, _) = fake_mpsc(8);
                let id_low = map_low
                    .add_ent(&mut rng, csnd, snd, padding_ctrl.clone())
                    .unwrap();
                assert!(u32::from(id_low) > 0);
                assert!(u32::from(id_low) < 0x80000000);
                assert!(!ids_low.contains(&id_low));
                ids_low.push(id_low);

                assert!(matches!(
                    *map_low.get_mut(id_low).unwrap(),
                    CircEnt::Opening { .. }
                ));

                let (csnd, _) = oneshot::channel();
                let (snd, _) = fake_mpsc(8);
                let id_high = map_high
                    .add_ent(&mut rng, csnd, snd, padding_ctrl.clone())
                    .unwrap();
                assert!(u32::from(id_high) >= 0x80000000);
                assert!(!ids_high.contains(&id_high));
                ids_high.push(id_high);
            }

            // Test open / opening entry counting
            assert_eq!(128, map_low.open_ent_count());
            assert_eq!(128, map_high.open_ent_count());

            // Test remove
            assert!(map_low.get_mut(ids_low[0]).is_some());
            map_low.remove(ids_low[0]);
            assert!(map_low.get_mut(ids_low[0]).is_none());
            assert_eq!(127, map_low.open_ent_count());

            // Test DestroySent doesn't count
            let destroyed = CircId::new(256).unwrap();
            map_low.destroy_sent(destroyed, HalfCirc::new(1));
            assert_eq!(127, map_low.open_ent_count());

            // Replacing an opening entry with DestroySent stops counting it.
            map_low.destroy_sent(ids_low[1], HalfCirc::new(1));
            assert_eq!(126, map_low.open_ent_count());

            // Removing a DestroySent entry doesn't change the count.
            assert!(matches!(
                map_low.remove(ids_low[1]),
                Some(CircEnt::DestroySent(_))
            ));
            assert_eq!(126, map_low.open_ent_count());

            // State changes made through `get_mut` are reflected in the count
            // once the `MutCircEnt` guard is dropped.
            *map_low.get_mut(ids_low[2]).unwrap() = CircEnt::DestroySent(HalfCirc::new(1));
            assert_eq!(125, map_low.open_ent_count());
            let (snd, _) = fake_mpsc(8);
            *map_low.get_mut(ids_low[2]).unwrap() = CircEnt::Open {
                cell_sender: snd,
                padding_ctrl: padding_ctrl.clone(),
            };
            assert_eq!(126, map_low.open_ent_count());

            // A CREATED* cell on a DestroySent circuit is rejected, and the
            // entry is left as it was.
            assert!(map_low.advance_from_opening(destroyed).is_err());
            assert!(matches!(
                *map_low.get_mut(destroyed).unwrap(),
                CircEnt::DestroySent(_)
            ));
            assert_eq!(126, map_low.open_ent_count());

            // Test advance_from_opening.

            // Good case.
            assert!(map_high.get_mut(ids_high[0]).is_some());
            assert!(matches!(
                *map_high.get_mut(ids_high[0]).unwrap(),
                CircEnt::Opening { .. }
            ));
            let adv = map_high.advance_from_opening(ids_high[0]);
            assert!(adv.is_ok());
            assert!(matches!(
                *map_high.get_mut(ids_high[0]).unwrap(),
                CircEnt::Open { .. }
            ));
            // Open entries are still counted.
            assert_eq!(128, map_high.open_ent_count());

            // Can't double-advance, and the failed attempt leaves the entry open.
            let adv = map_high.advance_from_opening(ids_high[0]);
            assert!(adv.is_err());
            assert!(matches!(
                *map_high.get_mut(ids_high[0]).unwrap(),
                CircEnt::Open { .. }
            ));
            assert_eq!(128, map_high.open_ent_count());

            // Can't advance an entry that is not there.  We know "77"
            // can't be in map_high, since we only added high circids to
            // it.
            let adv = map_high.advance_from_opening(CircId::new(77).unwrap());
            assert!(adv.is_err());
            assert!(map_high.get_mut(CircId::new(77).unwrap()).is_none());
        });
    }
}