//! Queues that participate in the memory quota system
//!
//! Wraps a communication channel, such as [`futures::channel::mpsc`],
//! tracks the memory use of the queue,
//! and participates in the memory quota system.
//!
//! Each item in the queue must know its memory cost,
//! and provide it via [`HasMemoryCost`].
//!
//! New queues are created by calling the [`new_mq`](ChannelSpec::new_mq) method
//! on a [`ChannelSpec`],
//! for example [`MpscSpec`] or [`MpscUnboundedSpec`].
//!
//! The ends implement [`Stream`] and [`Sink`].
//! If the underlying channel's sender is `Clone`,
//! for example with an MPSC queue, the returned sender is also `Clone`.
//!
//! Note that the [`Sender`] and [`Receiver`] only hold weak references to the `Account`.
//! Ie, the queue is not the accountholder.
//! The caller should keep a separate copy of the account.
//!
//! # Example
//!
//! ```
//! use tor_memquota::{MemoryQuotaTracker, HasMemoryCost, EnabledToken};
//! use tor_rtcompat::{DynTimeProvider, PreferredRuntime};
//! use tor_memquota::mq_queue::{MpscSpec, ChannelSpec as _};
//! # fn m() -> tor_memquota::Result<()> {
//!
//! #[derive(Debug)]
//! struct Message(String);
//! impl HasMemoryCost for Message {
//!     fn memory_cost(&self, _: EnabledToken) -> usize { self.0.len() }
//! }
//!
//! let runtime = PreferredRuntime::create().unwrap();
//! let time_prov = DynTimeProvider::new(runtime.clone());
#![cfg_attr(
    feature = "memquota",
    doc = "let config = tor_memquota::Config::builder().max(1024*1024*1024).build().unwrap();",
    doc = "let trk = MemoryQuotaTracker::new(&runtime, config).unwrap();"
)]
#![cfg_attr(
    not(feature = "memquota"),
    doc = "let trk = MemoryQuotaTracker::new_noop();"
)]
//! let account = trk.new_account(None).unwrap();
//!
//! let (tx, rx) = MpscSpec { buffer: 10 }.new_mq::<Message>(time_prov, &account)?;
//! #
//! # Ok(())
//! # }
//! # m().unwrap();
//! ```
//!
//! # Caveat
//!
//! The memory use tracking is based on external observations,
//! i.e., items inserted and removed.
//!
//! How well this reflects the actual memory use of the channel
//! depends on the channel's implementation.
//!
//! For example, if the channel uses a single contiguous buffer
//! containing the unboxed items, and that buffer doesn't shrink,
//! then the memory tracking may underestimate the actual memory use.
//! (This is significantly mitigated if the bulk of the memory use
//! for each item is separately boxed.)
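//!
//! As a hedged illustration of that mitigation, an item type along these lines
//! keeps its bulk data separately boxed, so the cost reported via
//! [`HasMemoryCost`] covers most of the real memory even if the channel's own
//! buffer is untracked (a sketch, not an API of this crate):
//!
//! ```
//! use tor_memquota::{HasMemoryCost, EnabledToken};
//!
//! struct BulkMessage {
//!     header: [u8; 16],
//!     body: Box<[u8]>, // the bulk lives in its own heap allocation
//! }
//!
//! impl HasMemoryCost for BulkMessage {
//!     fn memory_cost(&self, _: EnabledToken) -> usize {
//!         // in-queue size of the struct itself, plus the separately boxed body
//!         std::mem::size_of::<Self>() + self.body.len()
//!     }
//! }
//! ```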

#![forbid(unsafe_code)] // if you remove this, enable (or write) miri tests (git grep miri)

use tor_async_utils::peekable_stream::UnobtrusivePeekableStream;

use crate::internal_prelude::*;

use std::task::{Context, Poll, Poll::*};
use tor_async_utils::{ErasedSinkTrySendError, SinkCloseChannel, SinkTrySend};

//---------- Sender ----------

/// Sender for a channel that participates in the memory quota system
///
/// Returned by [`ChannelSpec::new_mq`], a method on `C`.
/// See the [module-level docs](crate::mq_queue).
#[derive(Educe)]
#[educe(Debug, Clone(bound = "C::Sender<Entry<T>>: Clone"))]
pub struct Sender<T: Debug + Send + 'static, C: ChannelSpec> {
    /// The inner sink
    tx: C::Sender<Entry<T>>,

    /// Our clone of the `Participation`, for memory accounting
    mq: TypedParticipation<Entry<T>>,

    /// Time provider for getting the data age
    #[educe(Debug(ignore))] // CoarseTimeProvider isn't Debug
    runtime: DynTimeProvider,
}

//---------- Receiver ----------

/// Receiver for a channel that participates in the memory quota system
///
/// Returned by [`ChannelSpec::new_mq`], a method on `C`.
/// See the [module-level docs](crate::mq_queue).
#[derive(Educe)] // not Clone, see below
#[educe(Debug)]
pub struct Receiver<T: Debug + Send + 'static, C: ChannelSpec> {
    /// Payload
    //
    // We don't make this an "exposed" `Arc`,
    // because that would allow the caller to clone it -
    // but we don't promise we're a multi-consumer queue even if `C::Receiver` is.
    //
    // Despite the in-principle Clone-ability of our `Receiver`,
    // we're not a working multi-consumer queue, even if the underlying channel is,
    // because StreamUnobtrusivePeeker isn't multi-consumer.
    //
    // Providing the multi-consumer feature would perhaps involve StreamUnobtrusivePeeker
    // handling multiple wakers, and then `impl Clone for Receiver where C::Receiver: Clone`
    // (and writing a bunch of tests).
    //
    // This would all be useless without also `impl ChannelSpec`
    // for a multi-consumer queue.
    inner: Arc<ReceiverInner<T, C>>,
}

/// Payload of `Receiver`, that's within the `Arc`, but contains the `Mutex`.
///
/// This is a separate type because
/// it's what we need to implement [`IsParticipant`] for.
#[derive(Educe)]
#[educe(Debug)]
struct ReceiverInner<T: Debug + Send + 'static, C: ChannelSpec> {
    /// Mutable state
    ///
    /// If we have collapsed due to memory reclaim, the state is replaced by an `Err`.
    /// In that case the caller mostly can't send on the `Sender` either,
    /// because we'll have torn down the `Participant`,
    /// so claims (beyond the cache in the `Sender`'s `Participation`) will fail.
    state: Mutex<Result<ReceiverState<T, C>, CollapsedDueToReclaim>>,
}

/// Mutable state of a `Receiver`
///
/// Normally the mutex is only locked by the receiving task.
/// On memory pressure, the mutex is acquired by the memory system,
/// which has a clone of the `Arc<ReceiverInner>`.
///
/// Lives within `Arc<Mutex<Result<Self, CollapsedDueToReclaim>>>`.
#[derive(Educe)]
#[educe(Debug)]
struct ReceiverState<T: Debug + Send + 'static, C: ChannelSpec> {
    /// The inner stream, but with an unobtrusive peek for getting the oldest data age
    rx: StreamUnobtrusivePeeker<C::Receiver<Entry<T>>>,

    /// The `Participation`, which we use for memory accounting
    ///
    /// ### Performance and locality
    ///
    /// We have separate [`Participation`]s for rx and tx.
    /// The tx is constantly claiming and the rx releasing;
    /// at least each `MAX_CACHE`, they must balance out
    /// via the (fairly globally shared) `MemoryQuotaTracker`.
    ///
    /// If this turns out to be a problem,
    /// we could arrange to share a `Participation`.
    mq: TypedParticipation<Entry<T>>,

    /// Hooks passed to [`Receiver::register_collapse_hook`]
    ///
    /// When the receiver is dropped, or memory is reclaimed, we call all of these.
    #[educe(Debug(method = "receiver_state_debug_collapse_notify"))]
    collapse_callbacks: Vec<CollapseCallback>,
}

//---------- other types ----------

/// Entry in the inner queue
#[derive(Debug)]
struct Entry<T> {
    /// The actual entry
    t: T,
    /// The data age - when it was inserted into the queue
    when: CoarseInstant,
}

/// Error returned when trying to write to a [`Sender`]
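///
/// # Handling this error
///
/// A hedged sketch of distinguishing the two variants when a send fails
/// (`tx`, `msg`, and the two handlers are assumed from the caller's context):
///
/// ```ignore
/// match tx.send(msg).await {
///     Ok(()) => {}
///     // The memory quota system reclaimed (collapsed) the queue.
///     Err(SendError::Memquota(e)) => handle_reclaim(e),
///     // The underlying channel failed, eg because the receiver was dropped.
///     // (As noted below, a memquota collapse can also surface this way.)
///     Err(SendError::Channel(e)) => handle_channel_error(e),
/// }
/// ```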
188
#[derive(Error, Clone, Debug)]
189
#[non_exhaustive]
190
pub enum SendError<CE> {
191
    /// The underlying channel rejected the message
192
    // Can't be `#[from]` because rustc can't see that C::SendError isn't SendError<C>
193
    #[error("channel send failed")]
194
    Channel(#[source] CE),
195

            
196
    /// The memory quota system prevented the send
197
    ///
198
    /// NB: when the channel is torn down due to memory pressure,
199
    /// the inner receiver is also torn down.
200
    /// This means that this variant is not always reported:
201
    /// sending on the sender in this situation
202
    /// may give [`SendError::Channel`] instead.
203
    #[error("memory quota exhausted, queue reclaimed")]
204
    Memquota(#[from] Error),
205
}

/// Callback passed to `Receiver::register_collapse_hook`
pub type CollapseCallback = Box<dyn FnOnce(CollapseReason) + Send + Sync + 'static>;

/// Argument to `CollapseCallback`: why are we collapsing?
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
#[non_exhaustive]
pub enum CollapseReason {
    /// The `Receiver` was dropped
    ReceiverDropped,

    /// The memory quota system asked us to reclaim memory
    MemoryReclaimed,
}

/// Marker, appears in state as `Err` to mean "we have collapsed"
#[derive(Debug, Clone, Copy)]
struct CollapsedDueToReclaim;

//==================== Channel ====================

/// Specification for a communication channel
///
/// Implemented for [`MpscSpec`] and [`MpscUnboundedSpec`].
//
// # Correctness (uncomment this if this trait is made unsealed)
//
// It is a requirement that this object really is some kind of channel.
// Specifically:
//
//  * Things that get put into the `Sender` must eventually emerge from the `Receiver`.
//  * Nothing may emerge from the `Receiver` that wasn't put into the `Sender`.
//  * If the `Sender` and `Receiver` are dropped, the items must also get dropped.
//
// If these requirements are violated, it could result in corruption of the memory accounts.
//
// Ideally, if the `Receiver` is dropped, most of the items are dropped soon.
//
pub trait ChannelSpec: Sealed /* see Correctness, above */ + Sized + 'static {
    /// The sending [`Sink`] for items of type `T`.
    //
    // Right now we insist that everything is Unpin.
    // futures::channel::mpsc's types all are.
    // If we wanted to support !Unpin channels, that would be possible,
    // but we would have some work to do.
    //
    // We also insist that everything is Debug.  That means `T: Debug`,
    // as well as the channels.  We could avoid that, but it would involve
    // skipping debug of important fields, or pervasive complex trait bounds
    // (Eg `#[educe(Debug(bound = "C::Receiver<Entry<T>>: Debug"))]` or worse.)
    //
    // This is a GAT because we need to instantiate it with T=Entry<_>.
    type Sender<T: Debug + Send + 'static>: Sink<T, Error = Self::SendError>
        + Debug + Unpin + Sized;

    /// The receiving [`Stream`] for items of type `T`.
    type Receiver<T: Debug + Send + 'static>: Stream<Item = T> + Debug + Unpin + Send + Sized;

    /// The error type `<Self::Sender<_> as Sink<_>>::Error`.
    ///
    /// (For this trait to be implemented, it is not allowed to depend on `T`.)
    type SendError: std::error::Error;

    /// Create a new channel, based on the spec `self`, that participates in the memory quota
    ///
    /// See the [module-level docs](crate::mq_queue) for an example.
    //
    // This method is supposed to be called by the user, not overridden.
    #[allow(clippy::type_complexity)] // the Result; not sensibly reducible or aliasable
    fn new_mq<T>(self, runtime: DynTimeProvider, account: &Account) -> crate::Result<(
        Sender<T, Self>,
        Receiver<T, Self>,
    )>
    where
        T: HasMemoryCost + Debug + Send + 'static,
    {
        let (rx, (tx, mq)) = account.register_participant_with(
            runtime.now_coarse(),
            move |mq| {
                let mq = TypedParticipation::new(mq);
                let collapse_callbacks = vec![];
                let (tx, rx) = self.raw_channel::<Entry<T>>();
                let rx = StreamUnobtrusivePeeker::new(rx);
                let state = ReceiverState { rx, mq: mq.clone(), collapse_callbacks };
                let state = Mutex::new(Ok(state));
                let inner = ReceiverInner { state };
                Ok::<_, crate::Error>((inner.into(), (tx, mq)))
            },
        )??;

        let runtime = runtime.clone();

        let tx = Sender { runtime, tx, mq };
        let rx = Receiver { inner: rx };

        Ok((tx, rx))
    }

    /// Create a new raw channel as specified by `self`
    //
    // This is called by `mq_queue`.
    fn raw_channel<T: Debug + Send + 'static>(self) -> (Self::Sender<T>, Self::Receiver<T>);

    /// Close the receiver, preventing further sends
    ///
    /// This should ensure that only a smallish bounded number of further items
    /// can be sent, before errors start being returned.
    fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>);
}

//---------- impls of Channel ----------

/// Specification for a (bounded) MPSC channel
///
/// Corresponds to the constructor [`futures::channel::mpsc::channel`].
///
/// Call [`new_mq`](ChannelSpec::new_mq) on a value of this type.
///
/// (The [`new`](MpscSpec::new) method is provided for convenience;
/// you may also construct the value directly.)
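///
/// For example, the following are equivalent
/// (a small illustrative sketch; `new` comes from the `Constructor` derive):
///
/// ```
/// use tor_memquota::mq_queue::MpscSpec;
///
/// assert_eq!(MpscSpec::new(10), MpscSpec { buffer: 10 });
/// ```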
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Constructor)]
#[allow(clippy::exhaustive_structs)] // This is precisely the arguments to mpsc::channel
pub struct MpscSpec {
    /// Buffer size; see [`futures::channel::mpsc::channel`].
    pub buffer: usize,
}

/// Specification for an unbounded MPSC channel
///
/// Corresponds to the constructor [`futures::channel::mpsc::unbounded`].
///
/// Call [`new_mq`](ChannelSpec::new_mq) on a value of this unit type.
///
/// (The [`new`](MpscUnboundedSpec::new) method is provided for orthogonality.)
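///
/// For example, mirroring the module-level example
/// (a sketch which assumes `Message`, `time_prov`, and `account` are set up as there):
///
/// ```ignore
/// let (tx, rx) = MpscUnboundedSpec.new_mq::<Message>(time_prov, &account)?;
/// ```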
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Constructor, Default)]
#[allow(clippy::exhaustive_structs)] // This is precisely the arguments to mpsc::unbounded
pub struct MpscUnboundedSpec;

impl Sealed for MpscSpec {}
impl Sealed for MpscUnboundedSpec {}

impl ChannelSpec for MpscSpec {
    type Sender<T: Debug + Send + 'static> = mpsc::Sender<T>;
    type Receiver<T: Debug + Send + 'static> = mpsc::Receiver<T>;
    type SendError = mpsc::SendError;

    fn raw_channel<T: Debug + Send + 'static>(self) -> (mpsc::Sender<T>, mpsc::Receiver<T>) {
        mpsc_channel_no_memquota(self.buffer)
    }

    fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>) {
        rx.close();
    }
}

impl ChannelSpec for MpscUnboundedSpec {
    type Sender<T: Debug + Send + 'static> = mpsc::UnboundedSender<T>;
    type Receiver<T: Debug + Send + 'static> = mpsc::UnboundedReceiver<T>;
    type SendError = mpsc::SendError;

    fn raw_channel<T: Debug + Send + 'static>(self) -> (Self::Sender<T>, Self::Receiver<T>) {
        mpsc::unbounded()
    }

    fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>) {
        rx.close();
    }
}

//==================== implementations ====================

//---------- Sender ----------

impl<T, C> Sink<T> for Sender<T, C>
where
    T: HasMemoryCost + Debug + Send + 'static,
    C: ChannelSpec,
{
    type Error = SendError<C::SendError>;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.get_mut()
            .tx
            .poll_ready_unpin(cx)
            .map_err(SendError::Channel)
    }

    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        let self_ = self.get_mut();
        let item = Entry {
            t: item,
            when: self_.runtime.now_coarse(),
        };
        self_.mq.try_claim(item, |item| {
            self_.tx.start_send_unpin(item).map_err(SendError::Channel)
        })?
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.tx
            .poll_flush_unpin(cx)
            .map(|r| r.map_err(SendError::Channel))
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.tx
            .poll_close_unpin(cx)
            .map(|r| r.map_err(SendError::Channel))
    }
}

impl<T, C> SinkTrySend<T> for Sender<T, C>
where
    T: HasMemoryCost + Debug + Send + 'static,
    C: ChannelSpec,
    C::Sender<Entry<T>>: SinkTrySend<Entry<T>>,
    <C::Sender<Entry<T>> as SinkTrySend<Entry<T>>>::Error: Send + Sync,
{
    type Error = ErasedSinkTrySendError;
    fn try_send_or_return(
        self: Pin<&mut Self>,
        item: T,
    ) -> Result<(), (<Self as SinkTrySend<T>>::Error, T)> {
        let self_ = self.get_mut();
        let item = Entry {
            t: item,
            when: self_.runtime.now_coarse(),
        };

        use ErasedSinkTrySendError as ESTSE;

        self_
            .mq
            .try_claim_or_return(item, |item| {
                Pin::new(&mut self_.tx).try_send_or_return(item)
            })
            .map_err(|(mqe, unsent)| (ESTSE::Other(Arc::new(mqe)), unsent.t))?
            .map_err(|(tse, unsent)| (ESTSE::from(tse), unsent.t))
    }
}

impl<T, C> SinkCloseChannel<T> for Sender<T, C>
where
    T: HasMemoryCost + Debug + Send, //Debug + 'static,
    C: ChannelSpec,
    C::Sender<Entry<T>>: SinkCloseChannel<Entry<T>>,
{
    fn close_channel(self: Pin<&mut Self>) {
        Pin::new(&mut self.get_mut().tx).close_channel();
    }
}

impl<T, C> Sender<T, C>
where
    T: Debug + Send + 'static,
    C: ChannelSpec,
{
    /// Obtain a reference to the `Sender`'s [`DynTimeProvider`]
    ///
    /// (This can sometimes be used to avoid having to keep
    /// a separate clone of the time provider.)
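    ///
    /// A usage sketch (assuming `tx` is a `Sender`; `now_coarse` is from
    /// the `CoarseTimeProvider` trait, which `DynTimeProvider` implements):
    ///
    /// ```ignore
    /// let now = tx.time_provider().now_coarse();
    /// ```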
    pub fn time_provider(&self) -> &DynTimeProvider {
        &self.runtime
    }
}

//---------- Receiver ----------

impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> Stream for Receiver<T, C> {
    type Item = T;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut state = self.inner.lock();
        let state = match &mut *state {
            Ok(y) => y,
            Err(CollapsedDueToReclaim) => return Ready(None),
        };
        let ret = state.rx.poll_next_unpin(cx);
        if let Ready(Some(item)) = &ret {
            // The item has left the queue: release its cost to the memory accounting
            if let Some(enabled) = EnabledToken::new_if_compiled_in() {
                let cost = item.typed_memory_cost(enabled);
                state.mq.release(&cost);
            }
        }
        ret.map(|r| r.map(|e| e.t))
    }
}

impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> FusedStream for Receiver<T, C>
where
    C::Receiver<Entry<T>>: FusedStream,
{
    fn is_terminated(&self) -> bool {
        match &*self.inner.lock() {
            Ok(y) => y.rx.is_terminated(),
            Err(CollapsedDueToReclaim) => true,
        }
    }
}

// TODO: When we have a trait for peekable streams, Receiver should implement it

impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> Receiver<T, C> {
    /// Register a callback, called when we tear the channel down
    ///
    /// This will be called when the `Receiver` is dropped,
    /// or if we tear down because the memory system asks us to reclaim.
    ///
    /// `call` might be called at any time, from any thread, but
    /// it won't be holding any locks relating to memory quota or the queue.
    ///
    /// If `self` is *already* in the process of being torn down,
    /// `call` might be called immediately, reentrantly!
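    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `rx` is a `Receiver` obtained from
    /// [`ChannelSpec::new_mq`]):
    ///
    /// ```ignore
    /// rx.register_collapse_hook(Box::new(|reason| {
    ///     // Runs on drop or on memory reclaim; may be called from any thread.
    ///     eprintln!("queue torn down: {reason:?}");
    /// }));
    /// ```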
    //
    // This callback is nicer than us handing out an mpsc rx
    // which the user must read and convert items from.
    //
    // This method is on Receiver because that has the State,
    // but could be called during setup to hook both sender's and
    // receiver's shutdown mechanisms.
    pub fn register_collapse_hook(&self, call: CollapseCallback) {
        let mut state = self.inner.lock();
        let state = match &mut *state {
            Ok(y) => y,
            Err(reason) => {
                let reason = (*reason).into();
                drop::<MutexGuard<_>>(state);
                call(reason);
                return;
            }
        };
        state.collapse_callbacks.push(call);
    }
}

impl<T: Debug + Send + 'static, C: ChannelSpec> ReceiverInner<T, C> {
    /// Convenience function to take the lock
    fn lock(&self) -> MutexGuard<Result<ReceiverState<T, C>, CollapsedDueToReclaim>> {
        self.state.lock().expect("mq_mpsc lock poisoned")
    }
}

impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> IsParticipant
    for ReceiverInner<T, C>
{
    fn get_oldest(&self, _: EnabledToken) -> Option<CoarseInstant> {
        let mut state = self.lock();
        let state = match &mut *state {
            Ok(y) => y,
            Err(CollapsedDueToReclaim) => return None,
        };
        Pin::new(&mut state.rx)
            .unobtrusive_peek()
            .map(|peeked| peeked.when)
    }

    fn reclaim(self: Arc<Self>, _: EnabledToken) -> mtracker::ReclaimFuture {
        Box::pin(async move {
            let reason = CollapsedDueToReclaim;
            let mut state_guard = self.lock();
            let state = mem::replace(&mut *state_guard, Err(reason));
            drop::<MutexGuard<_>>(state_guard);
            #[allow(clippy::single_match)] // pattern is intentional.
            match state {
                Ok(mut state) => {
                    for call in state.collapse_callbacks.drain(..) {
                        call(reason.into());
                    }
                    drop::<ReceiverState<_, _>>(state); // will drain queue, too
                }
                Err(CollapsedDueToReclaim) => {}
            };
            mtracker::Reclaimed::Collapsing
        })
    }
}

impl<T: Debug + Send + 'static, C: ChannelSpec> Drop for ReceiverState<T, C> {
    fn drop(&mut self) {
        // If there's a mutex, we're in its drop

        // `destroy_participant` prevents the sender from making further non-cached claims
        mem::replace(&mut self.mq, Participation::new_dangling().into())
            .into_raw()
            .destroy_participant();

        for call in self.collapse_callbacks.drain(..) {
            call(CollapseReason::ReceiverDropped);
        }

        // try to free whatever is in the queue, in case the stream doesn't do that itself
        // No-one can poll us any more, so we are no longer interested in wakeups
        let mut noop_cx = Context::from_waker(noop_waker_ref());

        // prevent further sends, so that our drain doesn't race indefinitely with the sender
        if let Some(mut rx_inner) =
            StreamUnobtrusivePeeker::as_raw_inner_pin_mut(Pin::new(&mut self.rx))
        {
            C::close_receiver(&mut rx_inner);
        }

        while let Ready(Some(item)) = self.rx.poll_next_unpin(&mut noop_cx) {
            drop::<Entry<T>>(item);
        }
    }
}

/// Method for educe's Debug impl for `ReceiverState.collapse_callbacks`
fn receiver_state_debug_collapse_notify(
    v: &[CollapseCallback],
    f: &mut fmt::Formatter,
) -> fmt::Result {
    Debug::fmt(&v.len(), f)
}

//---------- misc ----------

impl<T: HasMemoryCost> HasMemoryCost for Entry<T> {
    fn memory_cost(&self, enabled: EnabledToken) -> usize {
        let time_size = std::alloc::Layout::new::<CoarseInstant>().size();
        self.t.memory_cost(enabled).saturating_add(time_size)
    }
}

impl From<CollapsedDueToReclaim> for CollapseReason {
    fn from(CollapsedDueToReclaim: CollapsedDueToReclaim) -> CollapseReason {
        CollapseReason::MemoryReclaimed
    }
}

#[cfg(all(test, feature = "memquota", not(miri) /* coarsetime */))]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    #![allow(clippy::arithmetic_side_effects)] // don't mind potential panicking ops in tests

    use super::*;
    use crate::mtracker::test::*;
    use tor_rtmock::MockRuntime;
    use tracing::debug;
    use tracing_test::traced_test;

    #[derive(Default, Debug)]
    struct ItemTracker {
        state: Mutex<ItemTrackerState>,
    }
    #[derive(Default, Debug)]
    struct ItemTrackerState {
        existing: usize,
        next_id: usize,
    }

    #[derive(Debug)]
    struct Item {
        id: usize,
        tracker: Arc<ItemTracker>,
    }

    impl ItemTracker {
        fn new_item(self: &Arc<Self>) -> Item {
            let mut state = self.lock();
            let id = state.next_id;
            state.existing += 1;
            state.next_id += 1;
            debug!("new {id}");
            Item {
                tracker: self.clone(),
                id,
            }
        }

        fn new_tracker() -> Arc<Self> {
            Arc::default()
        }

        fn lock(&self) -> MutexGuard<ItemTrackerState> {
            self.state.lock().unwrap()
        }
    }

    impl Drop for Item {
        fn drop(&mut self) {
            debug!("old {}", self.id);
            self.tracker.state.lock().unwrap().existing -= 1;
        }
    }

    impl HasMemoryCost for Item {
        fn memory_cost(&self, _: EnabledToken) -> usize {
            mbytes(1)
        }
    }

    struct Setup {
        dtp: DynTimeProvider,
        trk: Arc<mtracker::MemoryQuotaTracker>,
        acct: Account,
        itrk: Arc<ItemTracker>,
    }

    fn setup(rt: &MockRuntime) -> Setup {
        let dtp = DynTimeProvider::new(rt.clone());
        let trk = mk_tracker(rt);
        let acct = trk.new_account(None).unwrap();
        let itrk = ItemTracker::new_tracker();
        Setup {
            dtp,
            trk,
            acct,
            itrk,
        }
    }

    #[derive(Debug)]
    struct Gigantic;
    impl HasMemoryCost for Gigantic {
        fn memory_cost(&self, _et: EnabledToken) -> usize {
            mbytes(100)
        }
    }

    impl Setup {
        /// Check that claims and releases have balanced out
        ///
        /// `n_queues` is the number of queues that exist.
        /// This is used to provide some slop, since each queue has two [`Participation`]s,
        /// each of which can have some cached claim.
        fn check_zero_claimed(&self, n_queues: usize) {
            let used = self.trk.used_current_approx();
            debug!(
                "checking zero balance (with slop {n_queues} * 2 * {}); used={used:?}",
                *mtracker::MAX_CACHE,
            );
            assert!(used.unwrap() <= n_queues * 2 * *mtracker::MAX_CACHE);
        }
    }

    #[traced_test]
    #[test]
    fn lifecycle() {
        MockRuntime::test_with_various(|rt| async move {
            let s = setup(&rt);
            let (mut tx, mut rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();

            tx.send(s.itrk.new_item()).await.unwrap();
            let _: Item = rx.next().await.unwrap();

            for _ in 0..20 {
                tx.send(s.itrk.new_item()).await.unwrap();
            }

            // reclaim task hasn't had a chance to run
            debug!("still existing items {}", s.itrk.lock().existing);

            rt.advance_until_stalled().await;

            // reclaim task should have torn everything down
            assert!(s.itrk.lock().existing == 0);

            assert!(rx.next().await.is_none());

            // Empirically, this is a "disconnected" error from the inner mpsc,
            // but let's not assert that.
            let _: SendError<_> = tx.send(s.itrk.new_item()).await.unwrap_err();
        });
    }

    #[traced_test]
    #[test]
    fn fill_and_empty() {
        MockRuntime::test_with_various(|rt| async move {
            let s = setup(&rt);
            let (mut tx, mut rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();

            const COUNT: usize = 19;

            for _ in 0..COUNT {
                tx.send(s.itrk.new_item()).await.unwrap();
            }

            rt.advance_until_stalled().await;

            for _ in 0..COUNT {
                let _: Item = rx.next().await.unwrap();
            }

            rt.advance_until_stalled().await;

            // no memory should be claimed
            s.check_zero_claimed(1);
        });
    }

    #[traced_test]
    #[test]
    fn sink_error() {
        #[derive(Debug, Copy, Clone)]
        struct BustedSink {
            error: BustedError,
        }

        impl<T> Sink<T> for BustedSink {
            type Error = BustedError;

            fn poll_ready(
                self: Pin<&mut Self>,
                _: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                Ready(Err(self.error))
            }
            fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
                panic!("poll_ready always gives error, start_send should not be called");
            }
            fn poll_flush(
                self: Pin<&mut Self>,
                _: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                Ready(Ok(()))
            }
            fn poll_close(
                self: Pin<&mut Self>,
                _: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                Ready(Ok(()))
            }
        }

        impl<T> SinkTrySend<T> for BustedSink {
            type Error = BustedError;

            fn try_send_or_return(self: Pin<&mut Self>, item: T) -> Result<(), (BustedError, T)> {
                Err((self.error, item))
            }
        }

        impl tor_async_utils::SinkTrySendError for BustedError {
            fn is_disconnected(&self) -> bool {
                self.is_disconnected
            }
            fn is_full(&self) -> bool {
                false
            }
        }

        #[derive(Error, Debug, Clone, Copy)]
        #[error("busted, for testing, dc={is_disconnected:?}")]
        struct BustedError {
            is_disconnected: bool,
        }

        struct BustedQueueSpec {
            error: BustedError,
        }
        impl Sealed for BustedQueueSpec {}
        impl ChannelSpec for BustedQueueSpec {
            type Sender<T: Debug + Send + 'static> = BustedSink;
            type Receiver<T: Debug + Send + 'static> = futures::stream::Pending<T>;
            type SendError = BustedError;
            fn raw_channel<T: Debug + Send + 'static>(self) -> (BustedSink, Self::Receiver<T>) {
                (BustedSink { error: self.error }, futures::stream::pending())
            }
            fn close_receiver<T: Debug + Send + 'static>(_rx: &mut Self::Receiver<T>) {}
        }

        use ErasedSinkTrySendError as ESTSE;

        MockRuntime::test_with_various(|rt| async move {
            let error = BustedError {
                is_disconnected: true,
            };

            let s = setup(&rt);
            let (mut tx, _rx) = BustedQueueSpec { error }
                .new_mq(s.dtp.clone(), &s.acct)
                .unwrap();

            let e = tx.send(s.itrk.new_item()).await.unwrap_err();
            assert!(matches!(e, SendError::Channel(BustedError { .. })));

            // item should have been destroyed
            assert_eq!(s.itrk.lock().existing, 0);

            // ---- Test try_send error handling ----

            fn error_is_other_of<E>(e: ESTSE) -> Result<(), impl Debug>
            where
                E: std::error::Error + 'static,
            {
                match e {
                    ESTSE::Other(e) if e.is::<E>() => Ok(()),
                    other => Err(other),
                }
            }

            let item = s.itrk.new_item();

            // Test try_send failure due to BustedError, is_disconnected: true

            let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
            assert!(matches!(e, ESTSE::Disconnected), "{e:?}");

            // Test try_send failure due to BustedError, is_disconnected: false (ie, Other)

            let error = BustedError {
                is_disconnected: false,
            };
            let (mut tx, _rx) = BustedQueueSpec { error }
                .new_mq(s.dtp.clone(), &s.acct)
                .unwrap();
            let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
            error_is_other_of::<BustedError>(e).unwrap();

            // no memory should be claimed
            s.check_zero_claimed(1);

            // Test try_send failure due to memory quota collapse

            // cause reclaim
            {
                let (mut tx, _rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();
                tx.send(Gigantic).await.unwrap();
                rt.advance_until_stalled().await;
            }

            let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
            error_is_other_of::<crate::Error>(e).unwrap();

            drop::<Item>(item);
        });
    }
}