//! Types and code to map circuit IDs to circuits.

// NOTE: This is a work in progress and I bet I'll refactor it a lot;
// it needs to stay opaque!
5

            
6
use crate::{Error, Result};
7
use tor_basic_utils::RngExt;
8
use tor_cell::chancell::CircId;
9

            
10
use crate::tunnel::circuit::halfcirc::HalfCirc;
11
use crate::tunnel::circuit::{celltypes::CreateResponse, CircuitRxSender};
12

            
13
use oneshot_fused_workaround as oneshot;
14

            
15
use rand::distr::Distribution;
16
use rand::Rng;
17
use std::collections::{hash_map::Entry, HashMap};
18
use std::ops::{Deref, DerefMut};
19

            
20
/// Which group of circuit IDs are we allowed to allocate in this map?
21
///
22
/// If we initiated the channel, we use High circuit ids.  If we're the
23
/// responder, we use low circuit ids.
24
#[derive(Copy, Clone)]
25
pub(super) enum CircIdRange {
26
    /// Only use circuit IDs with the MSB cleared.
27
    #[allow(dead_code)] // Relays will need this.
28
    Low,
29
    /// Only use circuit IDs with the MSB set.
30
    High,
31
    // Historical note: There used to be an "All" range of circuit IDs
32
    // available to clients only.  We stopped using "All" when we moved to link
33
    // protocol version 4.
34
}
35

            
36
impl rand::distr::Distribution<CircId> for CircIdRange {
37
    /// Return a random circuit ID in the appropriate range.
38
520
    fn sample<R: Rng + ?Sized>(&self, mut rng: &mut R) -> CircId {
39
520
        let midpoint = 0x8000_0000_u32;
40
520
        let v = match self {
41
            // 0 is an invalid value
42
256
            CircIdRange::Low => rng.gen_range_checked(1..midpoint),
43
264
            CircIdRange::High => rng.gen_range_checked(midpoint..=u32::MAX),
44
        };
45
520
        let v = v.expect("Unexpected empty range passed to gen_range_checked");
46
520
        CircId::new(v).expect("Unexpected zero value")
47
520
    }
48
}
49

            
50
/// An entry in the circuit map.  Right now, we only have "here's the
51
/// way to send cells to a given circuit", but that's likely to
52
/// change.
53
#[derive(Debug)]
54
pub(super) enum CircEnt {
55
    /// A circuit that has not yet received a CREATED cell.
56
    ///
57
    /// For this circuit, the CREATED* cell or DESTROY cell gets sent
58
    /// to the oneshot sender to tell the corresponding
59
    /// PendingClientCirc that the handshake is done.
60
    ///
61
    /// Once that's done, the `CircuitRxSender` mpsc sender will be used to send subsequent
62
    /// cells to the circuit.
63
    Opening(oneshot::Sender<CreateResponse>, CircuitRxSender),
64

            
65
    /// A circuit that is open and can be given relay cells.
66
    Open(CircuitRxSender),
67

            
68
    /// A circuit where we have sent a DESTROY, but the other end might
69
    /// not have gotten a DESTROY yet.
70
    DestroySent(HalfCirc),
71
}
72

            
73
/// An "smart pointer" that wraps an exclusive reference
74
/// of a `CircEnt`.
75
///
76
/// When being dropped, this object updates the open or opening entries
77
/// counter of the `CircMap`.
78
pub(super) struct MutCircEnt<'a> {
79
    /// An exclusive reference to the `CircEnt`.
80
    value: &'a mut CircEnt,
81
    /// An exclusive reference to the open or opening
82
    ///  entries counter.
83
    open_count: &'a mut usize,
84
    /// True if the entry was open or opening when borrowed.
85
    was_open: bool,
86
}
87

            
88
impl<'a> Drop for MutCircEnt<'a> {
89
504
    fn drop(&mut self) {
90
504
        let is_open = !matches!(self.value, CircEnt::DestroySent(_));
91
504
        match (self.was_open, is_open) {
92
            (false, true) => *self.open_count = self.open_count.saturating_add(1),
93
            (true, false) => *self.open_count = self.open_count.saturating_sub(1),
94
504
            (_, _) => (),
95
        };
96
504
    }
97
}
98

            
99
impl<'a> Deref for MutCircEnt<'a> {
100
    type Target = CircEnt;
101
276
    fn deref(&self) -> &Self::Target {
102
276
        self.value
103
276
    }
104
}
105

            
106
impl<'a> DerefMut for MutCircEnt<'a> {
107
224
    fn deref_mut(&mut self) -> &mut Self::Target {
108
224
        self.value
109
224
    }
110
}
111

            
112
/// A map from circuit IDs to circuit entries. Each channel has one.
113
pub(super) struct CircMap {
114
    /// Map from circuit IDs to entries
115
    m: HashMap<CircId, CircEnt>,
116
    /// Rule for allocating new circuit IDs.
117
    range: CircIdRange,
118
    /// Number of open or opening entry in this map.
119
    open_count: usize,
120
}
121

            
122
impl CircMap {
123
    /// Make a new empty CircMap
124
267
    pub(super) fn new(idrange: CircIdRange) -> Self {
125
267
        CircMap {
126
267
            m: HashMap::new(),
127
267
            range: idrange,
128
267
            open_count: 0,
129
267
        }
130
267
    }
131

            
132
    /// Add a new pair of elements (corresponding to a PendingClientCirc)
133
    /// to this map.
134
    ///
135
    /// On success return the allocated circuit ID.
136
520
    pub(super) fn add_ent<R: Rng>(
137
520
        &mut self,
138
520
        rng: &mut R,
139
520
        createdsink: oneshot::Sender<CreateResponse>,
140
520
        sink: CircuitRxSender,
141
520
    ) -> Result<CircId> {
142
        /// How many times do we probe for a random circuit ID before
143
        /// we assume that the range is fully populated?
144
        ///
145
        /// TODO: C tor does 64, but that is probably overkill with 4-byte circuit IDs.
146
        const N_ATTEMPTS: usize = 16;
147
520
        let iter = self.range.sample_iter(rng).take(N_ATTEMPTS);
148
520
        let circ_ent = CircEnt::Opening(createdsink, sink);
149
520
        for id in iter {
150
520
            let ent = self.m.entry(id);
151
520
            if let Entry::Vacant(_) = &ent {
152
520
                ent.or_insert(circ_ent);
153
520
                self.open_count += 1;
154
520
                return Ok(id);
155
            }
156
        }
157
        Err(Error::IdRangeFull)
158
520
    }
159

            
160
    /// Testing only: install an entry in this circuit map without regard
161
    /// for consistency.
162
    #[cfg(test)]
163
48
    pub(super) fn put_unchecked(&mut self, id: CircId, ent: CircEnt) {
164
48
        self.m.insert(id, ent);
165
48
    }
166

            
167
    /// Return the entry for `id` in this map, if any.
168
524
    pub(super) fn get_mut(&mut self, id: CircId) -> Option<MutCircEnt> {
169
524
        let open_count = &mut self.open_count;
170
776
        self.m.get_mut(&id).map(move |ent| MutCircEnt {
171
504
            open_count,
172
504
            was_open: !matches!(ent, CircEnt::DestroySent(_)),
173
504
            value: ent,
174
776
        })
175
524
    }
176

            
177
    /// See whether 'id' is an opening circuit.  If so, mark it "open" and
178
    /// return a oneshot::Sender that is waiting for its create cell.
179
22
    pub(super) fn advance_from_opening(
180
22
        &mut self,
181
22
        id: CircId,
182
22
    ) -> Result<oneshot::Sender<CreateResponse>> {
183
        // TODO: there should be a better way to do
184
        // this. hash_map::Entry seems like it could be better, but
185
        // there seems to be no way to replace the object in-place as
186
        // a consuming function of itself.
187
22
        let ok = matches!(self.m.get(&id), Some(CircEnt::Opening(_, _)));
188
22
        if ok {
189
2
            if let Some(CircEnt::Opening(oneshot, sink)) = self.m.remove(&id) {
190
2
                self.m.insert(id, CircEnt::Open(sink));
191
2
                Ok(oneshot)
192
            } else {
193
                panic!("internal error: inconsistent circuit state");
194
            }
195
        } else {
196
20
            Err(Error::ChanProto(
197
20
                "Unexpected CREATED* cell not on opening circuit".into(),
198
20
            ))
199
        }
200
22
    }
201

            
202
    /// Called when we have sent a DESTROY on a circuit.  Configures
203
    /// a "HalfCirc" object to track how many cells we get on this
204
    /// circuit, and to prevent us from reusing it immediately.
205
42
    pub(super) fn destroy_sent(&mut self, id: CircId, hs: HalfCirc) {
206
42
        if let Some(replaced) = self.m.insert(id, CircEnt::DestroySent(hs)) {
207
8
            if !matches!(replaced, CircEnt::DestroySent(_)) {
208
8
                // replaced an Open/Opening entry with DestroySent
209
8
                self.open_count = self.open_count.saturating_sub(1);
210
8
            }
211
34
        }
212
42
    }
213

            
214
    /// Extract the value from this map with 'id' if any
215
34
    pub(super) fn remove(&mut self, id: CircId) -> Option<CircEnt> {
216
47
        self.m.remove(&id).map(|removed| {
217
26
            if !matches!(removed, CircEnt::DestroySent(_)) {
218
18
                self.open_count = self.open_count.saturating_sub(1);
219
18
            }
220
26
            removed
221
47
        })
222
34
    }
223

            
224
    /// Return the total number of open and opening entries in the map
225
88
    pub(super) fn open_ent_count(&self) -> usize {
226
88
        self.open_count
227
88
    }
228

            
229
    // TODO: Eventually if we want relay support, we'll need to support
230
    // circuit IDs chosen by somebody else. But for now, we don't need those.
231
}
232

            
233
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use crate::fake_mpsc;
    use tor_basic_utils::test_rng::testing_rng;

    #[test]
    fn circmap_basics() {
        let mut map_low = CircMap::new(CircIdRange::Low);
        let mut map_high = CircMap::new(CircIdRange::High);
        let mut ids_low: Vec<CircId> = Vec::new();
        let mut ids_high: Vec<CircId> = Vec::new();
        let mut rng = testing_rng();

        assert!(map_low.get_mut(CircId::new(77).unwrap()).is_none());

        for _ in 0..128 {
            let (csnd, _) = oneshot::channel();
            let (snd, _) = fake_mpsc(8);
            let id_low = map_low.add_ent(&mut rng, csnd, snd).unwrap();
            assert!(u32::from(id_low) > 0);
            assert!(u32::from(id_low) < 0x80000000);
            assert!(!ids_low.contains(&id_low));
            ids_low.push(id_low);

            assert!(matches!(
                *map_low.get_mut(id_low).unwrap(),
                CircEnt::Opening(_, _)
            ));

            let (csnd, _) = oneshot::channel();
            let (snd, _) = fake_mpsc(8);
            let id_high = map_high.add_ent(&mut rng, csnd, snd).unwrap();
            assert!(u32::from(id_high) >= 0x80000000);
            assert!(!ids_high.contains(&id_high));
            ids_high.push(id_high);
        }

        // Test open / opening entry counting
        assert_eq!(128, map_low.open_ent_count());
        assert_eq!(128, map_high.open_ent_count());

        // Test remove
        assert!(map_low.get_mut(ids_low[0]).is_some());
        map_low.remove(ids_low[0]);
        assert!(map_low.get_mut(ids_low[0]).is_none());
        assert_eq!(127, map_low.open_ent_count());

        // Test DestroySent doesn't count.  Pick an ID that was never
        // allocated above, so that this adds a new entry rather than
        // replacing an opening one (which would legitimately lower the count).
        let unused_low = (1_u32..)
            .map(|v| CircId::new(v).unwrap())
            .find(|id| !ids_low.contains(id))
            .unwrap();
        map_low.destroy_sent(unused_low, HalfCirc::new(1));
        assert_eq!(127, map_low.open_ent_count());

        // Test advance_from_opening.

        // Good case.
        assert!(map_high.get_mut(ids_high[0]).is_some());
        assert!(matches!(
            *map_high.get_mut(ids_high[0]).unwrap(),
            CircEnt::Opening(_, _)
        ));
        let adv = map_high.advance_from_opening(ids_high[0]);
        assert!(adv.is_ok());
        assert!(matches!(
            *map_high.get_mut(ids_high[0]).unwrap(),
            CircEnt::Open(_)
        ));

        // Can't double-advance.
        let adv = map_high.advance_from_opening(ids_high[0]);
        assert!(adv.is_err());

        // Can't advance an entry that is not there.  We know "77"
        // can't be in map_high, since we only added high circids to
        // it.
        let adv = map_high.advance_from_opening(CircId::new(77).unwrap());
        assert!(adv.is_err());
    }

    #[test]
    fn circmap_open_count_tracks_state_changes() {
        let mut map = CircMap::new(CircIdRange::High);
        let mut rng = testing_rng();

        let (csnd, _) = oneshot::channel();
        let (snd, _) = fake_mpsc(8);
        let id = map.add_ent(&mut rng, csnd, snd).unwrap();
        assert_eq!(map.open_ent_count(), 1);

        // Opening -> Open keeps the count.
        assert!(map.advance_from_opening(id).is_ok());
        assert_eq!(map.open_ent_count(), 1);

        // Replacing an open entry with DestroySent lowers the count.
        map.destroy_sent(id, HalfCirc::new(1));
        assert_eq!(map.open_ent_count(), 0);

        // A failed advance leaves the entry untouched.
        assert!(map.advance_from_opening(id).is_err());
        assert!(matches!(
            *map.get_mut(id).unwrap(),
            CircEnt::DestroySent(_)
        ));
        assert_eq!(map.open_ent_count(), 0);

        // State changes made through MutCircEnt are counted on drop.
        {
            let (snd, _) = fake_mpsc(8);
            let mut ent = map.get_mut(id).unwrap();
            *ent = CircEnt::Open(snd);
        }
        assert_eq!(map.open_ent_count(), 1);
        {
            let mut ent = map.get_mut(id).unwrap();
            *ent = CircEnt::DestroySent(HalfCirc::new(1));
        }
        assert_eq!(map.open_ent_count(), 0);

        // Removing a DestroySent entry doesn't touch the count.
        assert!(map.remove(id).is_some());
        assert_eq!(map.open_ent_count(), 0);
    }
}