1
//! `HasMemoryCost` and typed memory cost tracking
2

            
3
#![forbid(unsafe_code)] // if you remove this, enable (or write) miri tests (git grep miri)
4

            
5
use crate::internal_prelude::*;
6

            
7
/// Types whose memory usage is known (and stable)
8
///
9
/// ### Important guarantees
10
///
11
/// Implementors of this trait must uphold the guarantees in the API of
12
/// [`memory_cost`](HasMemoryCost::memory_cost).
13
///
14
/// If these guarantees are violated, memory tracking may go wrong,
15
/// with seriously bad implications for the whole program,
16
/// including possible complete denial of service.
17
///
18
/// (Nevertheless, memory safety will not be compromised,
19
/// so trait this is not `unsafe`.)
20
pub trait HasMemoryCost {
21
    /// Returns the memory cost of `self`, in bytes
22
    ///
23
    /// ### Return value must be stable
24
    ///
25
    /// It is vital that the return value does not change, for any particular `self`,
26
    /// unless `self` is mutated through `&mut self` or similar.
27
    /// Otherwise, memory accounting may go awry.
28
    ///
29
    /// If `self` has interior mutability. the changing internal state
30
    /// must not change the memory cost.
31
    ///
32
    /// ### Panics - forbidden
33
    ///
34
    /// This method must not panic.
35
    /// Otherwise, memory accounting may go awry.
36
    fn memory_cost(&self, _: EnabledToken) -> usize;
37
}
38

            
39
/// A [`Participation`] for use only for tracking the memory use of objects of type `T`
40
///
41
/// Wrapping a `Participation` in a `TypedParticipation`
42
/// helps prevent accidentally passing wrongly calculated costs
43
/// to `claim` and `release`.
44
2637
#[derive(Deref, Educe)]
45
#[educe(Clone)]
46
#[educe(Debug(named_field = false))]
47
pub struct TypedParticipation<T> {
48
    /// The actual participation
49
    #[deref]
50
    raw: Participation,
51
    /// Marker
52
    #[educe(Debug(ignore))]
53
    marker: PhantomData<fn(T)>,
54
}
55

            
56
/// Memory cost obtained from a `T`
57
#[derive(Educe, derive_more::Display)]
58
#[educe(Copy, Clone)]
59
#[educe(Debug(named_field = false))]
60
#[display("{raw}")]
61
pub struct TypedMemoryCost<T> {
62
    /// The actual cost in bytes
63
    raw: usize,
64
    /// Marker
65
    #[educe(Debug(ignore))]
66
    marker: PhantomData<fn(T)>,
67
}
68

            
69
/// Types that can return a memory cost known to be the cost of some value of type `T`
70
///
71
/// [`TypedParticipation::claim`] and
72
/// [`release`](TypedParticipation::release)
73
/// take arguments implementing this trait.
74
///
75
/// Implemented by:
76
///
77
///   * `T: HasMemoryCost` (the usual case)
78
///   * `HasTypedMemoryCost<T>` (memory cost, calculated earlier, from a `T`)
79
///
80
/// ### Guarantees
81
///
82
/// This trait has the same guarantees as `HasMemoryCost`.
83
/// Normally, it will not be necessary to add an implementation.
84
// We could seal this trait, but we would need to use a special variant of Sealed,
85
// since we wouldn't want to `impl<T: HasMemoryCost> Sealed for T`
86
// for a normal Sealed trait also used elsewhere.
87
// The bug of implementing this trait for other types seems unlikely,
88
// and we don't think there's a significant API stability hazard.
89
pub trait HasTypedMemoryCost<T>: Sized {
90
    /// The cost, as a `TypedMemoryCost<T>` rather than a raw `usize`
91
    fn typed_memory_cost(&self, _: EnabledToken) -> TypedMemoryCost<T>;
92
}
93

            
94
impl<T: HasMemoryCost> HasTypedMemoryCost<T> for T {
95
17416
    fn typed_memory_cost(&self, enabled: EnabledToken) -> TypedMemoryCost<T> {
96
17416
        TypedMemoryCost::from_raw(self.memory_cost(enabled))
97
17416
    }
98
}
99
impl<T> HasTypedMemoryCost<T> for TypedMemoryCost<T> {
100
17426
    fn typed_memory_cost(&self, _: EnabledToken) -> TypedMemoryCost<T> {
101
17426
        *self
102
17426
    }
103
}
104

            
105
impl<T> TypedParticipation<T> {
106
    /// Wrap a [`Participation`], ensuring that future calls claim and release only `T`
107
4230
    pub fn new(raw: Participation) -> Self {
108
4230
        TypedParticipation {
109
4230
            raw,
110
4230
            marker: PhantomData,
111
4230
        }
112
4230
    }
113

            
114
    /// Record increase in memory use, of a `T: HasMemoryCost` or a `TypedMemoryCost<T>`
115
8904
    pub fn claim(&mut self, t: &impl HasTypedMemoryCost<T>) -> Result<(), Error> {
116
8904
        let Some(enabled) = EnabledToken::new_if_compiled_in() else {
117
            return Ok(());
118
        };
119
8904
        self.raw.claim(t.typed_memory_cost(enabled).raw)
120
8904
    }
121
    /// Record decrease in memory use, of a `T: HasMemoryCost` or a `TypedMemoryCost<T>`
122
8534
    pub fn release(&mut self, t: &impl HasTypedMemoryCost<T>) {
123
8534
        let Some(enabled) = EnabledToken::new_if_compiled_in() else {
124
            return;
125
        };
126
8534
        self.raw.release(t.typed_memory_cost(enabled).raw);
127
8534
    }
128

            
129
    /// Claiming wrapper for a closure
130
    ///
131
    /// Claims the memory, iff `call` succeeds.
132
    ///
133
    /// Specifically:
134
    /// Claims memory for `item`.   If that fails, returns the error.
135
    /// If the claim succeeded, calls `call`.
136
    /// If it fails or panics, the memory is released, undoing the claim,
137
    /// and the error is returned (or the panic propagated).
138
    ///
139
    /// In these error cases, `item` will typically be dropped by `call`,
140
    /// it is not convenient for `call` to do otherwise.
141
    /// If that's wanted, use [`try_claim_or_return`](TypedParticipation::try_claim_or_return).
142
8768
    pub fn try_claim<C, F, E, R>(&mut self, item: C, call: F) -> Result<Result<R, E>, Error>
143
8768
    where
144
8768
        C: HasTypedMemoryCost<T>,
145
8768
        F: FnOnce(C) -> Result<R, E>,
146
8768
    {
147
8768
        self.try_claim_or_return(item, call).map_err(|(e, _item)| e)
148
8768
    }
149

            
150
    /// Claiming wrapper for a closure
151
    ///
152
    /// Claims the memory, iff `call` succeeds.
153
    ///
154
    /// Like [`try_claim`](TypedParticipation::try_claim),
155
    /// but returns the item if memory claim fails.
156
    /// Typically, a failing `call` will need to return the item in `E`.
157
8896
    pub fn try_claim_or_return<C, F, E, R>(
158
8896
        &mut self,
159
8896
        item: C,
160
8896
        call: F,
161
8896
    ) -> Result<Result<R, E>, (Error, C)>
162
8896
    where
163
8896
        C: HasTypedMemoryCost<T>,
164
8896
        F: FnOnce(C) -> Result<R, E>,
165
8896
    {
166
8896
        let Some(enabled) = EnabledToken::new_if_compiled_in() else {
167
            return Ok(call(item));
168
        };
169

            
170
8896
        let cost = item.typed_memory_cost(enabled);
171
8896
        match self.claim(&cost) {
172
8888
            Ok(()) => {}
173
8
            Err(e) => return Err((e, item)),
174
        }
175
        // Unwind safety:
176
        //  - "`F` may not be safely transferred across an unwind boundary"
177
        //    but we don't; it is moved into the closure and
178
        //   it can't obwerve its own panic
179
        //  - "`C` may not be safely transferred across an unwind boundary"
180
        //   Once again, item is moved into call, and never seen again.
181
8888
        match catch_unwind(AssertUnwindSafe(move || call(item))) {
182
4
            Err(panic_payload) => {
183
4
                self.release(&cost);
184
4
                std::panic::resume_unwind(panic_payload)
185
            }
186
14
            Ok(Err(caller_error)) => {
187
14
                self.release(&cost);
188
14
                Ok(Err(caller_error))
189
            }
190
8870
            Ok(Ok(y)) => Ok(Ok(y)),
191
        }
192
8892
    }
193

            
194
    /// Mutably access the inner `Participation`
195
    ///
196
    /// This bypasses the type check.
197
    /// It is up to you to make sure that the `claim` and `release` calls
198
    /// are only made with properly calculated costs.
199
    pub fn as_raw(&mut self) -> &mut Participation {
200
        &mut self.raw
201
    }
202

            
203
    /// Unwrap, and obtain the inner `Participation`
204
2113
    pub fn into_raw(self) -> Participation {
205
2113
        self.raw
206
2113
    }
207
}
208

            
209
impl<T> From<Participation> for TypedParticipation<T> {
    /// Same as [`TypedParticipation::new`]
    fn from(raw: Participation) -> Self {
        Self::new(raw)
    }
}
214

            
215
impl<T> TypedMemoryCost<T> {
216
    /// Convert a raw number of bytes into a type-tagged memory cost
217
17416
    pub fn from_raw(raw: usize) -> Self {
218
17416
        TypedMemoryCost {
219
17416
            raw,
220
17416
            marker: PhantomData,
221
17416
        }
222
17416
    }
223

            
224
    /// Convert a type-tagged memory cost into a raw number of bytes
225
    pub fn into_raw(self) -> usize {
226
        self.raw
227
    }
228
}
229

            
230
#[cfg(all(test, feature = "memquota", not(miri) /* coarsetime */))]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    #![allow(clippy::arithmetic_side_effects)] // don't mind potential panicking ops in tests

    use super::*;
    use crate::mtracker::test::*;
    use crate::mtracker::*;
    use tor_rtmock::MockRuntime;

    // These are just type wrappers, so correctness needs little testing;
    // the point here is to show that the API is actually usable.

    /// Participant that never has anything old, and must never be asked to reclaim
    #[derive(Debug)]
    struct DummyParticipant;
    impl IsParticipant for DummyParticipant {
        fn get_oldest(&self, _: EnabledToken) -> Option<CoarseInstant> {
            None
        }
        fn reclaim(self: Arc<Self>, _: EnabledToken) -> ReclaimFuture {
            panic!()
        }
    }

    /// Value whose single claim comes within 1 MiB of the test limit
    struct Costed;
    impl HasMemoryCost for Costed {
        fn memory_cost(&self, _: EnabledToken) -> usize {
            // One allocation nearly exhausts the limit.
            //
            // So if claim didn't really claim, release would underflow;
            // and if release claimed instead, we'd trigger reclaim and crash.
            TEST_DEFAULT_LIMIT - mbytes(1)
        }
    }

    #[test]
    fn api() {
        MockRuntime::test_with_various(|rt| async move {
            let tracker = mk_tracker(&rt);
            let account = tracker.new_account(None).unwrap();
            let participant = Arc::new(DummyParticipant);
            let untyped = account
                .register_participant(Arc::downgrade(&participant) as _)
                .unwrap();
            let mut typed = TypedParticipation::<Costed>::from(untyped);

            // Claim and release directly from a `T: HasMemoryCost`
            typed.claim(&Costed).unwrap();
            typed.release(&Costed);

            // Claim and release via a precomputed `TypedMemoryCost<Costed>`
            let precomputed = Costed.typed_memory_cost(EnabledToken::new());
            typed.claim(&precomputed).unwrap();
            typed.release(&precomputed);

            // try_claim: the claim is undone because `call` returns an error
            let call_failed = typed.try_claim(Costed, |_: Costed| Err::<Void, _>(()));
            call_failed.unwrap().unwrap_err();

            // try_claim: the claim is undone because `call` panics
            let call_panicked = catch_unwind(AssertUnwindSafe(|| {
                let didnt_panic =
                    typed.try_claim(Costed, |_: Costed| -> Result<Void, Void> { panic!() });
                panic!("{:?}", didnt_panic);
            }));
            assert!(call_panicked.is_err());

            // try_claim: `call` succeeds, so the claim stays until we release it
            let claimed = typed
                .try_claim(Costed, |c: Costed| Ok::<Costed, Void>(c))
                .unwrap()
                .void_unwrap();
            // Check that we did claim at least something!
            assert!(tracker.used_current_approx().unwrap() > 0);

            typed.release(&claimed);

            // With the tracker gone, the claim fails, so `call` must never run
            drop(account);
            drop(participant);
            drop(tracker);
            typed
                .try_claim(Costed, |_| -> Result<Void, Void> { panic!() })
                .unwrap_err();

            rt.advance_until_stalled().await;
        });
    }
}