//! Queues that participate in the memory quota system
//!
//! Wraps a communication channel, such as [`futures::channel::mpsc`],
//! tracks the memory use of the queue,
//! and participates in the memory quota system.
//!
//! Each item in the queue must know its memory cost,
//! and provide it via [`HasMemoryCost`].
//!
//! New queues are created by calling the [`new_mq`](ChannelSpec::new_mq) method
//! on a [`ChannelSpec`],
//! for example [`MpscSpec`] or [`MpscUnboundedSpec`].
//!
//! The ends implement [`Stream`] and [`Sink`].
//! If the underlying channel's sender is `Clone`,
//! for example with an MPSC queue, the returned sender is also `Clone`.
//!
//! Note that the [`Sender`] and [`Receiver`] only hold weak references to the `Account`.
//! Ie, the queue is not the accountholder.
//! The caller should keep a separate copy of the account.
//!
//! # Example
//!
//! ```
//! use tor_memquota::{MemoryQuotaTracker, HasMemoryCost, EnabledToken};
//! use tor_rtcompat::{DynTimeProvider, PreferredRuntime};
//! use tor_memquota::mq_queue::{MpscSpec, ChannelSpec as _};
//! # fn m() -> tor_memquota::Result<()> {
//!
//! #[derive(Debug)]
//! struct Message(String);
//! impl HasMemoryCost for Message {
//!     fn memory_cost(&self, _: EnabledToken) -> usize { self.0.len() }
//! }
//!
//! let runtime = PreferredRuntime::create().unwrap();
//! let time_prov = DynTimeProvider::new(runtime.clone());
#![cfg_attr(
    feature = "memquota",
    doc = "let config = tor_memquota::Config::builder().max(1024*1024*1024).build().unwrap();",
    doc = "let trk = MemoryQuotaTracker::new(&runtime, config).unwrap();"
)]
#![cfg_attr(
    not(feature = "memquota"),
    doc = "let trk = MemoryQuotaTracker::new_noop();"
)]
//! let account = trk.new_account(None).unwrap();
//!
//! let (tx, rx) = MpscSpec { buffer: 10 }.new_mq::<Message>(time_prov, &account)?;
//! #
//! # Ok(())
//! # }
//! # m().unwrap();
//! ```
//!
//! # Caveat
//!
//! The memory use tracking is based on external observations,
//! i.e., items inserted and removed.
//!
//! How well this reflects the actual memory use of the channel
//! depends on the channel's implementation.
//!
//! For example, if the channel uses a single contiguous buffer
//! containing the unboxed items, and that buffer doesn't shrink,
//! then the memory tracking may considerably underestimate
//! the actual memory use.
//! (This is significantly mitigated if the bulk of the memory use
//! for each item is separately boxed.)
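//!
//! For instance, a message type can keep the bulk of its data in a separate
//! heap allocation and count that allocation explicitly.
//! A minimal sketch (the `BoxedMessage` type is illustrative, not part of this crate):
//!
//! ```
//! use tor_memquota::{HasMemoryCost, EnabledToken};
//!
//! #[derive(Debug)]
//! struct BoxedMessage {
//!     header: u32,
//!     payload: Box<[u8]>, // bulk of the memory, separately boxed
//! }
//!
//! impl HasMemoryCost for BoxedMessage {
//!     fn memory_cost(&self, _: EnabledToken) -> usize {
//!         // the fixed-size part, plus the separately boxed payload
//!         std::mem::size_of::<BoxedMessage>().saturating_add(self.payload.len())
//!     }
//! }
//! ```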

#![forbid(unsafe_code)] // if you remove this, enable (or write) miri tests (git grep miri)

use tor_async_utils::peekable_stream::UnobtrusivePeekableStream;

use crate::internal_prelude::*;

use std::task::{Context, Poll, Poll::*};
use tor_async_utils::{ErasedSinkTrySendError, SinkCloseChannel, SinkTrySend};

//---------- Sender ----------

/// Sender for a channel that participates in the memory quota system
///
/// Returned by [`ChannelSpec::new_mq`], a method on `C`.
/// See the [module-level docs](crate::mq_queue).
#[derive(Educe)]
#[educe(Debug, Clone(bound = "C::Sender<Entry<T>>: Clone"))]
pub struct Sender<T: Debug + Send + 'static, C: ChannelSpec> {
    /// The inner sink
    tx: C::Sender<Entry<T>>,

    /// Our clone of the `Participation`, for memory accounting
    mq: TypedParticipation<Entry<T>>,

    /// Time provider for getting the data age
    #[educe(Debug(ignore))] // CoarseTimeProvider isn't Debug
    runtime: DynTimeProvider,
}

//---------- Receiver ----------

/// Receiver for a channel that participates in the memory quota system
///
/// Returned by [`ChannelSpec::new_mq`], a method on `C`.
/// See the [module-level docs](crate::mq_queue).
#[derive(Educe)] // not Clone, see below
#[educe(Debug)]
pub struct Receiver<T: Debug + Send + 'static, C: ChannelSpec> {
    /// Payload
    //
    // We don't make this an "exposed" `Arc`,
    // because that would allow the caller to clone it -
    // but we don't promise we're a multi-consumer queue even if `C::Receiver` is.
    //
    // Despite the in-principle Clone-ability of our `Receiver`,
    // we're not a working multi-consumer queue, even if the underlying channel is,
    // because StreamUnobtrusivePeeker isn't multi-consumer.
    //
    // Providing the multi-consumer feature would perhaps involve StreamUnobtrusivePeeker
    // handling multiple wakers, and then `impl Clone for Receiver where C::Receiver: Clone`.
    // (and writing a bunch of tests).
    //
    // This would all be useless without also `impl ChannelSpec`
    // for a multi-consumer queue.
    inner: Arc<ReceiverInner<T, C>>,
}

/// Payload of `Receiver`, that's within the `Arc`, but contains the `Mutex`.
///
/// This is a separate type because
/// it's what we need to implement [`IsParticipant`] for.
#[derive(Educe)]
#[educe(Debug)]
struct ReceiverInner<T: Debug + Send + 'static, C: ChannelSpec> {
    /// Mutable state
    ///
    /// If we have collapsed due to memory reclaim, state is replaced by an `Err`.
    /// In that case the caller mostly can't send on the Sender either,
    /// because we'll have torn down the Participant,
    /// so claims (beyond the cache in the `Sender`'s `Participation`) will fail.
    state: Mutex<Result<ReceiverState<T, C>, CollapsedDueToReclaim>>,
}

/// Mutable state of a `Receiver`
///
/// Normally the mutex is only locked by the receiving task.
/// On memory pressure, the mutex is acquired by the memory system,
/// which has a clone of the `Arc<ReceiverInner>`.
///
/// Within `Arc<Mutex<Result<_, _>>>`.
#[derive(Educe)]
#[educe(Debug)]
struct ReceiverState<T: Debug + Send + 'static, C: ChannelSpec> {
    /// The inner stream, but with an unobtrusive peek for getting the oldest data age
    rx: StreamUnobtrusivePeeker<C::Receiver<Entry<T>>>,

    /// The `Participation`, which we use for memory accounting
    ///
    /// ### Performance and locality
    ///
    /// We have separate [`Participation`]s for rx and tx.
    /// The tx is constantly claiming and the rx releasing;
    /// at least once every `MAX_CACHE`, they must balance out
    /// via the (fairly globally shared) `MemoryQuotaTracker`.
    ///
    /// If this turns out to be a problem,
    /// we could arrange to share a `Participation`.
    mq: TypedParticipation<Entry<T>>,

    /// Hooks passed to [`Receiver::register_collapse_hook`]
    ///
    /// When the receiver is dropped, or memory is reclaimed, we call all of these.
    #[educe(Debug(method = "receiver_state_debug_collapse_notify"))]
    collapse_callbacks: Vec<CollapseCallback>,
}

//---------- other types ----------

/// Entry in the inner queue
#[derive(Debug)]
struct Entry<T> {
    /// The actual entry
    t: T,
    /// The data age - when it was inserted into the queue
    when: CoarseInstant,
}

/// Error returned when trying to write to a [`Sender`]
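///
/// A minimal sketch of handling both variants
/// (hypothetical; `res` stands for the result of a `send` on a [`Sender`]):
///
/// ```ignore
/// match res {
///     Ok(()) => { /* sent (and its memory claimed) */ }
///     Err(SendError::Memquota(_)) => { /* the queue was reclaimed; tear down */ }
///     Err(SendError::Channel(_)) => { /* the underlying channel failed */ }
/// }
/// ```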
#[derive(Error, Clone, Debug)]
#[non_exhaustive]
pub enum SendError<CE> {
    /// The underlying channel rejected the message
    // Can't be `#[from]` because rustc can't see that C::SendError isn't SendError<C>
    #[error("channel send failed")]
    Channel(#[source] CE),

    /// The memory quota system prevented the send
    ///
    /// NB: when the channel is torn down due to memory pressure,
    /// the inner receiver is also torn down.
    /// This means that this variant is not always reported:
    /// sending on the sender in this situation
    /// may give [`SendError::Channel`] instead.
    #[error("memory quota exhausted, queue reclaimed")]
    Memquota(#[from] Error),
}

/// Callback passed to `Receiver::register_collapse_hook`
pub type CollapseCallback = Box<dyn FnOnce(CollapseReason) + Send + Sync + 'static>;

/// Argument to `CollapseCallback`: why are we collapsing?
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
#[non_exhaustive]
pub enum CollapseReason {
    /// The `Receiver` was dropped
    ReceiverDropped,

    /// The memory quota system asked us to reclaim memory
    MemoryReclaimed,
}

/// Marker, appears in state as `Err` to mean "we have collapsed"
#[derive(Debug, Clone, Copy)]
struct CollapsedDueToReclaim;

//==================== Channel ====================

/// Specification for a communication channel
///
/// Implemented for [`MpscSpec`] and [`MpscUnboundedSpec`].
//
// # Correctness (uncomment this if this trait is made unsealed)
//
// It is a requirement that this object really is some kind of channel.
// Specifically:
//
//  * Things that get put into the `Sender` must eventually emerge from the `Receiver`.
//  * Nothing may emerge from the `Receiver` that wasn't put into the `Sender`.
//  * If the `Sender` and `Receiver` are dropped, the items must also get dropped.
//
// If these requirements are violated, it could result in corruption of the memory accounts
//
// Ideally, if the `Receiver` is dropped, most of the items are dropped soon.
//
pub trait ChannelSpec: Sealed /* see Correctness, above */ + Sized + 'static {
    /// The sending [`Sink`] for items of type `T`.
    //
    // Right now we insist that everything is Unpin.
    // futures::channel::mpsc's types all are.
    // If we wanted to support !Unpin channels, that would be possible,
    // but we would have some work to do.
    //
    // We also insist that everything is Debug.  That means `T: Debug`,
    // as well as the channels.  We could avoid that, but it would involve
    // skipping debug of important fields, or pervasive complex trait bounds
    // (Eg `#[educe(Debug(bound = "C::Receiver<Entry<T>>: Debug"))]` or worse.)
    //
    // This is a GAT because we need to instantiate it with T=Entry<_>.
    type Sender<T: Debug + Send + 'static>: Sink<T, Error = Self::SendError>
        + Debug + Unpin + Sized;

    /// The receiving [`Stream`] for items of type `T`.
    type Receiver<T: Debug + Send + 'static>: Stream<Item = T> + Debug + Unpin + Send + Sized;

    /// The error type `<Sender<_> as Sink>::Error`.
    ///
    /// (For this trait to be implemented, it is not allowed to depend on `T`.)
    type SendError: std::error::Error;

    /// Create a new channel, based on the spec `self`, that participates in the memory quota
    ///
    /// See the [module-level docs](crate::mq_queue) for an example.
    //
    // This method is supposed to be called by the user, not overridden.
    #[allow(clippy::type_complexity)] // the Result; not sensibly reducible or aliasable
    fn new_mq<T>(self, runtime: DynTimeProvider, account: &Account) -> crate::Result<(
        Sender<T, Self>,
        Receiver<T, Self>,
    )>
    where
        T: HasMemoryCost + Debug + Send + 'static,
    {
        let (rx, (tx, mq)) = account.register_participant_with(
            runtime.now_coarse(),
            move |mq| {
                let mq = TypedParticipation::new(mq);
                let collapse_callbacks = vec![];
                let (tx, rx) = self.raw_channel::<Entry<T>>();
                let rx = StreamUnobtrusivePeeker::new(rx);
                let state = ReceiverState { rx, mq: mq.clone(), collapse_callbacks };
                let state = Mutex::new(Ok(state));
                let inner = ReceiverInner { state };
                Ok::<_, crate::Error>((inner.into(), (tx, mq)))
            },
        )??;

        let runtime = runtime.clone();

        let tx = Sender { runtime, tx, mq };
        let rx = Receiver { inner: rx };

        Ok((tx, rx))
    }

    /// Create a new raw channel as specified by `self`
    //
    // This is called by `mq_queue`.
    fn raw_channel<T: Debug + Send + 'static>(self) -> (Self::Sender<T>, Self::Receiver<T>);

    /// Close the receiver, preventing further sends
    ///
    /// This should ensure that only a smallish bounded number of further items
    /// can be sent, before errors start being returned.
    fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>);
}

//---------- impls of Channel ----------

/// Specification for a (bounded) MPSC channel
///
/// Corresponds to the constructor [`futures::channel::mpsc::channel`].
///
/// Call [`new_mq`](ChannelSpec::new_mq) on a value of this type.
///
/// (The [`new`](MpscSpec::new) method is provided for convenience;
/// you may also construct the value directly.)
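///
/// A minimal sketch of the two equivalent ways to construct one:
///
/// ```
/// use tor_memquota::mq_queue::MpscSpec;
///
/// assert_eq!(MpscSpec::new(10), MpscSpec { buffer: 10 });
/// ```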
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Constructor)]
#[allow(clippy::exhaustive_structs)] // This is precisely the arguments to mpsc::channel
pub struct MpscSpec {
    /// Buffer size; see [`futures::channel::mpsc::channel`].
    pub buffer: usize,
}

/// Specification for an unbounded MPSC channel
///
/// Corresponds to the constructor [`futures::channel::mpsc::unbounded`].
///
/// Call [`new_mq`](ChannelSpec::new_mq) on a value of this unit type.
///
/// (The [`new`](MpscUnboundedSpec::new) method is provided for orthogonality.)
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Constructor, Default)]
#[allow(clippy::exhaustive_structs)] // This is precisely the arguments to mpsc::unbounded
pub struct MpscUnboundedSpec;

impl Sealed for MpscSpec {}
impl Sealed for MpscUnboundedSpec {}

impl ChannelSpec for MpscSpec {
    type Sender<T: Debug + Send + 'static> = mpsc::Sender<T>;
    type Receiver<T: Debug + Send + 'static> = mpsc::Receiver<T>;
    type SendError = mpsc::SendError;

    fn raw_channel<T: Debug + Send + 'static>(self) -> (mpsc::Sender<T>, mpsc::Receiver<T>) {
        mpsc_channel_no_memquota(self.buffer)
    }

    fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>) {
        rx.close();
    }
}

impl ChannelSpec for MpscUnboundedSpec {
    type Sender<T: Debug + Send + 'static> = mpsc::UnboundedSender<T>;
    type Receiver<T: Debug + Send + 'static> = mpsc::UnboundedReceiver<T>;
    type SendError = mpsc::SendError;

    fn raw_channel<T: Debug + Send + 'static>(self) -> (Self::Sender<T>, Self::Receiver<T>) {
        mpsc::unbounded()
    }

    fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>) {
        rx.close();
    }
}

//==================== implementations ====================

//---------- Sender ----------

impl<T, C> Sink<T> for Sender<T, C>
where
    T: HasMemoryCost + Debug + Send + 'static,
    C: ChannelSpec,
{
    type Error = SendError<C::SendError>;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.get_mut()
            .tx
            .poll_ready_unpin(cx)
            .map_err(SendError::Channel)
    }

    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        let self_ = self.get_mut();
        let item = Entry {
            t: item,
            when: self_.runtime.now_coarse(),
        };
        self_.mq.try_claim(item, |item| {
            self_.tx.start_send_unpin(item).map_err(SendError::Channel)
        })?
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.tx
            .poll_flush_unpin(cx)
            .map(|r| r.map_err(SendError::Channel))
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.tx
            .poll_close_unpin(cx)
            .map(|r| r.map_err(SendError::Channel))
    }
}

impl<T, C> SinkTrySend<T> for Sender<T, C>
where
    T: HasMemoryCost + Debug + Send + 'static,
    C: ChannelSpec,
    C::Sender<Entry<T>>: SinkTrySend<Entry<T>>,
    <C::Sender<Entry<T>> as SinkTrySend<Entry<T>>>::Error: Send + Sync,
{
    type Error = ErasedSinkTrySendError;
    fn try_send_or_return(
        self: Pin<&mut Self>,
        item: T,
    ) -> Result<(), (<Self as SinkTrySend<T>>::Error, T)> {
        let self_ = self.get_mut();
        let item = Entry {
            t: item,
            when: self_.runtime.now_coarse(),
        };

        use ErasedSinkTrySendError as ESTSE;

        self_
            .mq
            .try_claim_or_return(item, |item| {
                Pin::new(&mut self_.tx).try_send_or_return(item)
            })
            .map_err(|(mqe, unsent)| (ESTSE::Other(Arc::new(mqe)), unsent.t))?
            .map_err(|(tse, unsent)| (ESTSE::from(tse), unsent.t))
    }
}

impl<T, C> SinkCloseChannel<T> for Sender<T, C>
where
    T: HasMemoryCost + Debug + Send, //Debug + 'static,
    C: ChannelSpec,
    C::Sender<Entry<T>>: SinkCloseChannel<Entry<T>>,
{
    fn close_channel(self: Pin<&mut Self>) {
        Pin::new(&mut self.get_mut().tx).close_channel();
    }
}

impl<T, C> Sender<T, C>
where
    T: Debug + Send + 'static,
    C: ChannelSpec,
{
    /// Obtain a reference to the `Sender`'s [`DynTimeProvider`]
    ///
    /// (This can sometimes be used to avoid having to keep
    /// a separate clone of the time provider.)
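    ///
    /// A minimal sketch (hypothetical; assuming `tx` is a `Sender` from this module):
    ///
    /// ```ignore
    /// use tor_rtcompat::CoarseTimeProvider as _;
    /// let now = tx.time_provider().now_coarse();
    /// ```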
    pub fn time_provider(&self) -> &DynTimeProvider {
        &self.runtime
    }
}

//---------- Receiver ----------

impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> Stream for Receiver<T, C> {
    type Item = T;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut state = self.inner.lock();
        let state = match &mut *state {
            Ok(y) => y,
            Err(CollapsedDueToReclaim) => return Ready(None),
        };
        let ret = state.rx.poll_next_unpin(cx);
        if let Ready(Some(item)) = &ret {
            if let Some(enabled) = EnabledToken::new_if_compiled_in() {
                let cost = item.typed_memory_cost(enabled);
                state.mq.release(&cost);
            }
        }
        ret.map(|r| r.map(|e| e.t))
    }
}

impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> FusedStream for Receiver<T, C>
where
    C::Receiver<Entry<T>>: FusedStream,
{
    fn is_terminated(&self) -> bool {
        match &*self.inner.lock() {
            Ok(y) => y.rx.is_terminated(),
            Err(CollapsedDueToReclaim) => true,
        }
    }
}

// TODO: When we have a trait for peekable streams, Receiver should implement it

impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> Receiver<T, C> {
    /// Register a callback, called when we tear the channel down
    ///
    /// This will be called when the `Receiver` is dropped,
    /// or if we tear down because the memory system asks us to reclaim.
    ///
    /// `call` might be called at any time, from any thread, but
    /// it won't be holding any locks relating to memory quota or the queue.
    ///
    /// If `self` is *already* in the process of being torn down,
    /// `call` might be called immediately, reentrantly!
    //
520
    // This callback is nicer than us handing out an mpsc rx
521
    // which user must read and convert items from.
522
    //
523
    // This method is on Receiver because that has the State,
524
    // but could be called during setup to hook both sender's and
525
    // receiver's shutdown mechanisms.
526
    pub fn register_collapse_hook(&self, call: CollapseCallback) {
527
        let mut state = self.inner.lock();
528
        let state = match &mut *state {
529
            Ok(y) => y,
530
            Err(reason) => {
531
                let reason = (*reason).into();
532
                drop::<MutexGuard<_>>(state);
533
                call(reason);
534
                return;
535
            }
536
        };
537
        state.collapse_callbacks.push(call);
538
    }
539
}
540

            
541
impl<T: Debug + Send + 'static, C: ChannelSpec> ReceiverInner<T, C> {
542
    /// Convenience function to take the lock
543
13906
    fn lock(&self) -> MutexGuard<Result<ReceiverState<T, C>, CollapsedDueToReclaim>> {
544
13906
        self.state.lock().expect("mq_mpsc lock poisoned")
545
13906
    }
546
}

impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> IsParticipant
    for ReceiverInner<T, C>
{
    fn get_oldest(&self, _: EnabledToken) -> Option<CoarseInstant> {
        let mut state = self.lock();
        let state = match &mut *state {
            Ok(y) => y,
            Err(CollapsedDueToReclaim) => return None,
        };
        let peeked = Pin::new(&mut state.rx)
            .unobtrusive_peek()
            .map(|peeked| peeked.when);
        peeked
    }

    fn reclaim(self: Arc<Self>, _: EnabledToken) -> mtracker::ReclaimFuture {
        Box::pin(async move {
            let reason = CollapsedDueToReclaim;
            let mut state_guard = self.lock();
            let state = mem::replace(&mut *state_guard, Err(reason));
            drop::<MutexGuard<_>>(state_guard);
            #[allow(clippy::single_match)] // pattern is intentional.
            match state {
                Ok(mut state) => {
                    for call in state.collapse_callbacks.drain(..) {
                        call(reason.into());
                    }
                    drop::<ReceiverState<_, _>>(state); // will drain queue, too
                }
                Err(CollapsedDueToReclaim) => {}
            };
            mtracker::Reclaimed::Collapsing
        })
    }
}

impl<T: Debug + Send + 'static, C: ChannelSpec> Drop for ReceiverState<T, C> {
    fn drop(&mut self) {
        // If there's still a mutex around us, we're in its drop, so no locking is needed

        // `destroy_participant` prevents the sender from making further non-cached claims
        mem::replace(&mut self.mq, Participation::new_dangling().into())
            .into_raw()
            .destroy_participant();

        for call in self.collapse_callbacks.drain(..) {
            call(CollapseReason::ReceiverDropped);
        }

        // try to free whatever is in the queue, in case the stream doesn't do that itself
        // No-one can poll us any more, so we are no longer interested in wakeups
        let mut noop_cx = Context::from_waker(noop_waker_ref());

        // prevent further sends, so that our drain doesn't race indefinitely with the sender
        if let Some(mut rx_inner) =
            StreamUnobtrusivePeeker::as_raw_inner_pin_mut(Pin::new(&mut self.rx))
        {
            C::close_receiver(&mut rx_inner);
        }

        while let Ready(Some(item)) = self.rx.poll_next_unpin(&mut noop_cx) {
            drop::<Entry<T>>(item);
        }
    }
}

/// Method for educe's Debug impl for `ReceiverState.collapse_callbacks`
fn receiver_state_debug_collapse_notify(
    v: &[CollapseCallback],
    f: &mut fmt::Formatter,
) -> fmt::Result {
    Debug::fmt(&v.len(), f)
}

//---------- misc ----------

impl<T: HasMemoryCost> HasMemoryCost for Entry<T> {
    fn memory_cost(&self, enabled: EnabledToken) -> usize {
        let time_size = std::alloc::Layout::new::<CoarseInstant>().size();
        self.t.memory_cost(enabled).saturating_add(time_size)
    }
}

impl From<CollapsedDueToReclaim> for CollapseReason {
    fn from(CollapsedDueToReclaim: CollapsedDueToReclaim) -> CollapseReason {
        CollapseReason::MemoryReclaimed
    }
}

#[cfg(all(test, feature = "memquota", not(miri) /* coarsetime */))]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    #![allow(clippy::arithmetic_side_effects)] // don't mind potential panicking ops in tests

    use super::*;
    use crate::mtracker::test::*;
    use tor_rtmock::MockRuntime;
    use tracing::debug;
    use tracing_test::traced_test;

    #[derive(Default, Debug)]
    struct ItemTracker {
        state: Mutex<ItemTrackerState>,
    }
    #[derive(Default, Debug)]
    struct ItemTrackerState {
        existing: usize,
        next_id: usize,
    }

    #[derive(Debug)]
    struct Item {
        id: usize,
        tracker: Arc<ItemTracker>,
    }

    impl ItemTracker {
        fn new_item(self: &Arc<Self>) -> Item {
            let mut state = self.lock();
            let id = state.next_id;
            state.existing += 1;
            state.next_id += 1;
            debug!("new {id}");
            Item {
                tracker: self.clone(),
                id,
            }
        }

        fn new_tracker() -> Arc<Self> {
            Arc::default()
        }

        fn lock(&self) -> MutexGuard<ItemTrackerState> {
            self.state.lock().unwrap()
        }
    }

    impl Drop for Item {
        fn drop(&mut self) {
            debug!("old {}", self.id);
            self.tracker.state.lock().unwrap().existing -= 1;
        }
    }

    impl HasMemoryCost for Item {
        fn memory_cost(&self, _: EnabledToken) -> usize {
            mbytes(1)
        }
    }

    struct Setup {
        dtp: DynTimeProvider,
        trk: Arc<mtracker::MemoryQuotaTracker>,
        acct: Account,
        itrk: Arc<ItemTracker>,
    }

    fn setup(rt: &MockRuntime) -> Setup {
        let dtp = DynTimeProvider::new(rt.clone());
        let trk = mk_tracker(rt);
        let acct = trk.new_account(None).unwrap();
        let itrk = ItemTracker::new_tracker();
        Setup {
            dtp,
            trk,
            acct,
            itrk,
        }
    }

    #[derive(Debug)]
    struct Gigantic;
    impl HasMemoryCost for Gigantic {
        fn memory_cost(&self, _et: EnabledToken) -> usize {
            mbytes(100)
        }
    }

    impl Setup {
        /// Check that claims and releases have balanced out
        ///
        /// `n_queues` is the number of queues that exist.
        /// This is used to provide some slop, since each queue has two [`Participation`]s
        /// each of which can have some cached claim.
        fn check_zero_claimed(&self, n_queues: usize) {
            let used = self.trk.used_current_approx();
            debug!(
                "checking zero balance (with slop {n_queues} * 2 * {}; used={used:?})",
                *mtracker::MAX_CACHE,
            );
            assert!(used.unwrap() <= n_queues * 2 * *mtracker::MAX_CACHE);
        }
    }

    #[traced_test]
    #[test]
    fn lifecycle() {
        MockRuntime::test_with_various(|rt| async move {
            let s = setup(&rt);
            let (mut tx, mut rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();

            tx.send(s.itrk.new_item()).await.unwrap();
            let _: Item = rx.next().await.unwrap();

            for _ in 0..20 {
                tx.send(s.itrk.new_item()).await.unwrap();
            }

            // reclaim task hasn't had a chance to run
            debug!("still existing items {}", s.itrk.lock().existing);

            rt.advance_until_stalled().await;

            // reclaim task should have torn everything down
            assert!(s.itrk.lock().existing == 0);

            assert!(rx.next().await.is_none());

            // Empirically, this is a "disconnected" error from the inner mpsc,
            // but let's not assert that.
            let _: SendError<_> = tx.send(s.itrk.new_item()).await.unwrap_err();
        });
    }

    #[traced_test]
    #[test]
    fn fill_and_empty() {
        MockRuntime::test_with_various(|rt| async move {
            let s = setup(&rt);
            let (mut tx, mut rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();

            const COUNT: usize = 19;

            for _ in 0..COUNT {
                tx.send(s.itrk.new_item()).await.unwrap();
            }

            rt.advance_until_stalled().await;

            for _ in 0..COUNT {
                let _: Item = rx.next().await.unwrap();
            }

            rt.advance_until_stalled().await;

            // no memory should be claimed
            s.check_zero_claimed(1);
        });
    }

    #[traced_test]
812
    #[test]
813
    fn sink_error() {
814
        #[derive(Debug, Copy, Clone)]
815
        struct BustedSink {
816
            error: BustedError,
817
        }
818

            
819
        impl<T> Sink<T> for BustedSink {
820
            type Error = BustedError;
821

            
822
            fn poll_ready(
823
                self: Pin<&mut Self>,
824
                _: &mut Context<'_>,
825
            ) -> Poll<Result<(), Self::Error>> {
826
                Ready(Err(self.error))
827
            }
828
            fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
829
                panic!("poll_ready always gives error, start_send should not be called");
830
            }
831
            fn poll_flush(
832
                self: Pin<&mut Self>,
833
                _: &mut Context<'_>,
834
            ) -> Poll<Result<(), Self::Error>> {
835
                Ready(Ok(()))
836
            }
837
            fn poll_close(
838
                self: Pin<&mut Self>,
839
                _: &mut Context<'_>,
840
            ) -> Poll<Result<(), Self::Error>> {
841
                Ready(Ok(()))
842
            }
843
        }
844

            
845
        impl<T> SinkTrySend<T> for BustedSink {
846
            type Error = BustedError;
847

            
848
            fn try_send_or_return(self: Pin<&mut Self>, item: T) -> Result<(), (BustedError, T)> {
849
                Err((self.error, item))
850
            }
851
        }
852

            
853
        impl tor_async_utils::SinkTrySendError for BustedError {
854
            fn is_disconnected(&self) -> bool {
855
                self.is_disconnected
856
            }
857
            fn is_full(&self) -> bool {
858
                false
859
            }
860
        }
861

            
862
        #[derive(Error, Debug, Clone, Copy)]
863
        #[error("busted, for testing, dc={is_disconnected:?}")]
864
        struct BustedError {
865
            is_disconnected: bool,
866
        }
867

            
868
        struct BustedQueueSpec {
869
            error: BustedError,
870
        }
871
        impl Sealed for BustedQueueSpec {}
872
        impl ChannelSpec for BustedQueueSpec {
873
            type Sender<T: Debug + Send + 'static> = BustedSink;
874
            type Receiver<T: Debug + Send + 'static> = futures::stream::Pending<T>;
875
            type SendError = BustedError;
876
            fn raw_channel<T: Debug + Send + 'static>(self) -> (BustedSink, Self::Receiver<T>) {
877
                (BustedSink { error: self.error }, futures::stream::pending())
878
            }
879
            fn close_receiver<T: Debug + Send + 'static>(_rx: &mut Self::Receiver<T>) {}
880
        }
881

            
882
        use ErasedSinkTrySendError as ESTSE;
883

            
884
        MockRuntime::test_with_various(|rt| async move {
885
            let error = BustedError {
886
                is_disconnected: true,
887
            };
888

            
889
            let s = setup(&rt);
890
            let (mut tx, _rx) = BustedQueueSpec { error }
891
                .new_mq(s.dtp.clone(), &s.acct)
892
                .unwrap();
893

            
894
            let e = tx.send(s.itrk.new_item()).await.unwrap_err();
895
            assert!(matches!(e, SendError::Channel(BustedError { .. })));
896

            
897
            // item should have been destroyed
898
            assert_eq!(s.itrk.lock().existing, 0);
899

            
900
            // ---- Test try_send error handling ----
901

            
902
            fn error_is_other_of<E>(e: ESTSE) -> Result<(), impl Debug>
903
            where
904
                E: std::error::Error + 'static,
905
            {
906
                match e {
907
                    ESTSE::Other(e) if e.is::<E>() => Ok(()),
908
                    other => Err(other),
909
                }
910
            }
911

            
912
            let item = s.itrk.new_item();
913

            
914
            // Test try_send failure due to BustedError, is_disconnected: true
915

            
916
            let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
917
            assert!(matches!(e, ESTSE::Disconnected), "{e:?}");
918

            
919
            // Test try_send failure due to BustedError, is_disconnected: false (ie, Other)
920

            
921
            let error = BustedError {
922
                is_disconnected: false,
923
            };
924
            let (mut tx, _rx) = BustedQueueSpec { error }
925
                .new_mq(s.dtp.clone(), &s.acct)
926
                .unwrap();
927
            let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
928
            error_is_other_of::<BustedError>(e).unwrap();
929

            
930
            // no memory should be claimed
931
            s.check_zero_claimed(1);
932

            
933
            // Test try_send failure due to memory quota collapse
934

            
935
            // cause reclaim
936
            {
937
                let (mut tx, _rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();
938
                tx.send(Gigantic).await.unwrap();
939
                rt.advance_until_stalled().await;
940
            }
941

            
942
            let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
943
            error_is_other_of::<crate::Error>(e).unwrap();
944

            
945
            drop::<Item>(item);
946
        });
947
    }
948
}