tor_memquota/
mq_queue.rs

1//! Queues that participate in the memory quota system
2//!
3//! Wraps a communication channel, such as [`futures::channel::mpsc`],
4//! tracks the memory use of the queue,
5//! and participates in the memory quota system.
6//!
7//! Each item in the queue must know its memory cost,
8//! and provide it via [`HasMemoryCost`].
9//!
10//! New queues are created by calling the [`new_mq`](ChannelSpec::new_mq) method
11//! on a [`ChannelSpec`],
12//! for example [`MpscSpec`] or [`MpscUnboundedSpec`].
13//!
14//! The ends implement [`Stream`] and [`Sink`].
15//! If the underlying channel's sender is `Clone`,
16//! for example with an MPSC queue, the returned sender is also `Clone`.
17//!
18//! Note that the [`Sender`] and [`Receiver`] only hold weak references to the `Account`.
//! I.e., the queue is not the account holder.
20//! The caller should keep a separate copy of the account.
21//!
22//! # Example
23//!
24//! ```
25//! use tor_memquota::{MemoryQuotaTracker, HasMemoryCost, EnabledToken};
26//! use tor_rtcompat::{DynTimeProvider, PreferredRuntime};
27//! use tor_memquota::mq_queue::{MpscSpec, ChannelSpec as _};
28//! # fn m() -> tor_memquota::Result<()> {
29//!
30//! #[derive(Debug)]
31//! struct Message(String);
32//! impl HasMemoryCost for Message {
33//!     fn memory_cost(&self, _: EnabledToken) -> usize { self.0.len() }
34//! }
35//!
36//! let runtime = PreferredRuntime::create().unwrap();
37//! let time_prov = DynTimeProvider::new(runtime.clone());
38#![cfg_attr(
39    feature = "memquota",
40    doc = "let config  = tor_memquota::Config::builder().max(1024*1024*1024).build().unwrap();",
41    doc = "let trk = MemoryQuotaTracker::new(&runtime, config).unwrap();"
42)]
43#![cfg_attr(
44    not(feature = "memquota"),
45    doc = "let trk = MemoryQuotaTracker::new_noop();"
46)]
47//! let account = trk.new_account(None).unwrap();
48//!
49//! let (tx, rx) = MpscSpec { buffer: 10 }.new_mq::<Message>(time_prov, &account)?;
50//! #
51//! # Ok(())
52//! # }
53//! # m().unwrap();
54//! ```
55//!
56//! # Caveat
57//!
58//! The memory use tracking is based on external observations,
59//! i.e., items inserted and removed.
60//!
61//! How well this reflects the actual memory use of the channel
62//! depends on the channel's implementation.
63//!
//! For example, if the channel uses a single contiguous buffer
//! containing the unboxed items, and that buffer doesn't shrink,
//! then the tracked memory use can be an underestimate:
//! removing an item releases its claim, but not the buffer space it occupied.
//! (This is significantly mitigated if the bulk of the memory use
//! for each item is separately boxed.)
69
70#![forbid(unsafe_code)] // if you remove this, enable (or write) miri tests (git grep miri)
71
72use tor_async_utils::peekable_stream::UnobtrusivePeekableStream;
73
74use crate::internal_prelude::*;
75
76use std::task::{Context, Poll, Poll::*};
77use tor_async_utils::{ErasedSinkTrySendError, SinkCloseChannel, SinkTrySend};
78
79//---------- Sender ----------
80
/// Sender for a channel that participates in the memory quota system
///
/// Returned by [`ChannelSpec::new_mq`], a method on `C`.
/// See the [module-level docs](crate::mq_queue).
///
/// Each item sent is wrapped in an `Entry` stamped with the current coarse time,
/// and is handed to the inner sink via `mq`'s `try_claim`,
/// so that its memory cost is accounted for.
//
// `Clone` is only available when the underlying channel's sender is `Clone`
// (see the `bound` below); a clone also clones every field, including `mq`.
#[derive(Educe)]
#[educe(Debug, Clone(bound = "C::Sender<Entry<T>>: Clone"))]
pub struct Sender<T: Debug + Send + 'static, C: ChannelSpec> {
    /// The inner sink
    tx: C::Sender<Entry<T>>,

    /// Our clone of the `Participation`, for memory accounting
    mq: TypedParticipation<Entry<T>>,

    /// Time provider for getting the data age
    #[educe(Debug(ignore))] // CoarseTimeProvider isn't Debug
    runtime: DynTimeProvider,
}
98
99//---------- Receiver ----------
100
/// Receiver for a channel that participates in the memory quota system
///
/// Returned by [`ChannelSpec::new_mq`], a method on `C`.
/// See the [module-level docs](crate::mq_queue).
///
/// As items are received from this [`Stream`], the memory claimed for them is released
/// (when memory quota tracking is compiled in).
/// If the memory quota system reclaims the queue,
/// the stream ends: it yields `None` from then on.
#[derive(Educe)] // not Clone, see below
#[educe(Debug)]
pub struct Receiver<T: Debug + Send + 'static, C: ChannelSpec> {
    /// Payload
    //
    // We don't make this an "exposed" `Arc`,
    // because that would allow the caller to clone it -
    // but we don't promise we're a multi-consumer queue even if `C::Receiver` is.
    //
    // Despite the in-principle Clone-ability of our `Receiver`,
    // we're not a working multi-consumer queue, even if the underlying channel is,
    // because StreamUnobtrusivePeeker isn't multi-consumer.
    //
    // Providing the multi-consumer feature would perhaps involve StreamUnobtrusivePeeker
    // handling multiple wakers, and then `impl Clone for Receiver where C::Receiver: Clone`.
    // (and writing a bunch of tests).
    //
    // This would all be useless without also `impl ChannelSpec`
    // for a multi-consumer queue.
    inner: Arc<ReceiverInner<T, C>>,
}
126
/// Payload of `Receiver`, that's within the `Arc`, but contains the `Mutex`.
///
/// This is a separate type because
/// it's what we need to implement [`IsParticipant`] for.
///
/// `new_mq` constructs this inside the closure it passes to
/// `Account::register_participant_with`; presumably that is how the memory
/// tracker obtains the reference it uses to call the `IsParticipant` methods.
#[derive(Educe)]
#[educe(Debug)]
struct ReceiverInner<T: Debug + Send + 'static, C: ChannelSpec> {
    /// Mutable state
    ///
    /// If we have collapsed due to memory reclaim, state is replaced by an `Err`.
    /// In that case the caller mostly can't send on the Sender either,
    /// because we'll have torn down the Participant,
    /// so claims (beyond the cache in the `Sender`'s `Participation`) will fail.
    state: Mutex<Result<ReceiverState<T, C>, CollapsedDueToReclaim>>,
}
142
/// Mutable state of a `Receiver`
///
/// Normally the mutex is only locked by the receiving task.
/// On memory pressure, mutex is acquired by the memory system,
/// which has a clone of the `Arc<ReceiverInner>`.
///
/// Lives within `Arc<ReceiverInner>`, i.e. within
/// `Arc<Mutex<Result<ReceiverState, CollapsedDueToReclaim>>>`.
///
/// Dropping it (see its `Drop` impl) destroys the memory-quota participant,
/// notifies any remaining collapse hooks, and drains the inner queue.
#[derive(Educe)]
#[educe(Debug)]
struct ReceiverState<T: Debug + Send + 'static, C: ChannelSpec> {
    /// The inner stream, but with an unobtrusive peek for getting the oldest data age
    rx: StreamUnobtrusivePeeker<C::Receiver<Entry<T>>>,

    /// The `Participation`, which we use for memory accounting
    ///
    /// ### Performance and locality
    ///
    /// We have separate [`Participation`]s for rx and tx.
    /// The tx is constantly claiming and the rx releasing;
    /// at least once per `MAX_CACHE` worth of claims, they must balance out
    /// via the (fairly globally shared) `MemoryQuotaTracker`.
    ///
    /// If this turns out to be a problem,
    /// we could arrange to share a `Participation`.
    mq: TypedParticipation<Entry<T>>,

    /// Hooks passed to [`Receiver::register_collapse_hook`]
    ///
    /// When the receiver is dropped, or memory is reclaimed, we call all of these.
    /// (`Debug` shows only how many hooks there are.)
    #[educe(Debug(method = "receiver_state_debug_collapse_notify"))]
    collapse_callbacks: Vec<CollapseCallback>,
}
175
176//---------- other types ----------
177
/// Entry in the inner queue
///
/// Wraps each item together with the time it was enqueued.
/// Its memory cost is the item's own cost plus the size of the timestamp.
#[derive(Debug)]
struct Entry<T> {
    /// The actual entry
    t: T,
    /// The data age - when it was inserted into the queue
    when: CoarseInstant,
}
186
/// Error returned when trying to write to a [`Sender`]
///
/// `CE` is the underlying channel's send error type (`ChannelSpec::SendError`).
#[derive(Error, Clone, Debug)]
#[non_exhaustive]
pub enum SendError<CE> {
    /// The underlying channel rejected the message
    //
    // Can't be `#[from]`: a `From<CE>` impl would overlap with the
    // `From<Error>` impl generated for `Memquota` below,
    // since rustc can't rule out `CE` being `Error`.
    #[error("channel send failed")]
    Channel(#[source] CE),

    /// The memory quota system prevented the send
    ///
    /// NB: when the channel is torn down due to memory pressure,
    /// the inner receiver is also torn down.
    /// This means that this variant is not always reported:
    /// sending on the sender in this situation
    /// may give [`SendError::Channel`] instead.
    #[error("memory quota exhausted, queue reclaimed")]
    Memquota(#[from] Error),
}
206
207/// Callback passed to `Receiver::register_collapse_hook`
208pub type CollapseCallback = Box<dyn FnOnce(CollapseReason) + Send + Sync + 'static>;
209
/// Why a queue is collapsing; this is what each `CollapseCallback` receives
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
#[non_exhaustive]
pub enum CollapseReason {
    /// The `Receiver` was dropped
    ReceiverDropped,

    /// The memory quota system asked us to reclaim memory
    MemoryReclaimed,
}
220
/// Marker stored as the `Err` in a receiver's state once it has collapsed due to reclaim
#[derive(Clone, Copy, Debug)]
struct CollapsedDueToReclaim;
224
225//==================== Channel ====================
226
227/// Specification for a communication channel
228///
229/// Implemented for [`MpscSpec`] and [`MpscUnboundedSpec`].
230//
231// # Correctness (uncomment this if this trait is made unsealed)
232//
233// It is a requirement that this object really is some kind of channel.
234// Specifically:
235//
236//  * Things that get put into the `Sender` must eventually emerge from the `Receiver`.
237//  * Nothing may emerge from the `Receiver` that wasn't put into the `Sender`.
238//  * If the `Sender` and `Receiver` are dropped, the items must also get dropped.
239//
240// If these requirements are violated, it could result in corruption of the memory accounts
241//
242// Ideally, if the `Receiver` is dropped, most of the items are dropped soon.
243//
244pub trait ChannelSpec: Sealed /* see Correctness, above */ + Sized + 'static {
245    /// The sending [`Sink`] for items of type `T`.
246    //
247    // Right now we insist that everything is Unpin.
248    // futures::channel::mpsc's types all are.
249    // If we wanted to support !Unpin channels, that would be possible,
250    // but we would have some work to do.
251    //
252    // We also insist that everything is Debug.  That means `T: Debug`,
253    // as well as the channels.  We could avoid that, but it would involve
254    // skipping debug of important fields, or pervasive complex trait bounds
255    // (Eg `#[educe(Debug(bound = "C::Receiver<Entry<T>>: Debug"))]` or worse.)
256    //
257    // This is a GAT because we need to instantiate it with T=Entry<_>.
258    type Sender<T: Debug + Send + 'static>: Sink<T, Error = Self::SendError>
259        + Debug + Unpin + Sized;
260
261    /// The receiving [`Stream`] for items of type `T`.
262    type Receiver<T: Debug + Send + 'static>: Stream<Item = T> + Debug + Unpin + Send + Sized;
263
264    /// The error type `<Receiver<_> as Stream>::Error`.
265    ///
266    /// (For this trait to be implemented, it is not allowed to depend on `T`.)
267    type SendError: std::error::Error;
268
269    /// Create a new channel, based on the spec `self`, that participates in the memory quota
270    ///
271    /// See the [module-level docs](crate::mq_queue) for an example.
272    //
273    // This method is supposed to be called by the user, not overridden.
274    #[allow(clippy::type_complexity)] // the Result; not sensibly reducible or aliasable
275    fn new_mq<T>(self, runtime: DynTimeProvider, account: &Account) -> crate::Result<(
276        Sender<T, Self>,
277        Receiver<T, Self>,
278    )>
279    where
280        T: HasMemoryCost + Debug + Send + 'static,
281    {
282        let (rx, (tx, mq)) = account.register_participant_with(
283            runtime.now_coarse(),
284            move |mq| {
285                let mq = TypedParticipation::new(mq);
286                let collapse_callbacks = vec![];
287                let (tx, rx) = self.raw_channel::<Entry<T>>();
288                let rx = StreamUnobtrusivePeeker::new(rx);
289                let state = ReceiverState { rx, mq: mq.clone(), collapse_callbacks };
290                let state = Mutex::new(Ok(state));
291                let inner = ReceiverInner { state };
292                Ok::<_, crate::Error>((inner.into(), (tx, mq)))
293            },
294        )??;
295
296        let runtime = runtime.clone();
297
298        let tx = Sender { runtime, tx, mq };
299        let rx = Receiver { inner: rx };
300
301        Ok((tx, rx))
302    }
303
304    /// Create a new raw channel as specified by `self`
305    //
306    // This is called by `mq_queue`.
307    fn raw_channel<T: Debug + Send + 'static>(self) -> (Self::Sender<T>, Self::Receiver<T>);
308
309    /// Close the receiver, preventing further sends
310    ///
311    /// This should ensure that only a smallish bounded number of further items
312    /// can be sent, before errors start being returned.
313    fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>);
314}
315
316//---------- impls of Channel ----------
317
318/// Specification for a (bounded) MPSC channel
319///
320/// Corresponds to the constructor [`futures::channel::mpsc::channel`].
321///
322/// Call [`new_mq`](ChannelSpec::new_mq) on a value of this type.
323///
324/// (The [`new`](MpscUnboundedSpec::new) method is provided for convenience;
325/// you may also construct the value directly.)
326#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Constructor)]
327#[allow(clippy::exhaustive_structs)] // This is precisely the arguments to mpsc::channel
328pub struct MpscSpec {
329    /// Buffer size; see [`futures::channel::mpsc::channel`].
330    pub buffer: usize,
331}
332
333/// Specification for an unbounded MPSC channel
334///
335/// Corresponds to the constructor [`futures::channel::mpsc::unbounded`].
336///
337/// Call [`new_mq`](ChannelSpec::new_mq) on a value of this unit type.
338///
339/// (The [`new`](MpscUnboundedSpec::new) method is provided for orthogonality.)
340#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Constructor, Default)]
341#[allow(clippy::exhaustive_structs)] // This is precisely the arguments to mpsc::unbounded
342pub struct MpscUnboundedSpec;
343
344impl Sealed for MpscSpec {}
345impl Sealed for MpscUnboundedSpec {}
346
347impl ChannelSpec for MpscSpec {
348    type Sender<T: Debug + Send + 'static> = mpsc::Sender<T>;
349    type Receiver<T: Debug + Send + 'static> = mpsc::Receiver<T>;
350    type SendError = mpsc::SendError;
351
352    fn raw_channel<T: Debug + Send + 'static>(self) -> (mpsc::Sender<T>, mpsc::Receiver<T>) {
353        mpsc_channel_no_memquota(self.buffer)
354    }
355
356    fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>) {
357        rx.close();
358    }
359}
360
361impl ChannelSpec for MpscUnboundedSpec {
362    type Sender<T: Debug + Send + 'static> = mpsc::UnboundedSender<T>;
363    type Receiver<T: Debug + Send + 'static> = mpsc::UnboundedReceiver<T>;
364    type SendError = mpsc::SendError;
365
366    fn raw_channel<T: Debug + Send + 'static>(self) -> (Self::Sender<T>, Self::Receiver<T>) {
367        mpsc::unbounded()
368    }
369
370    fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>) {
371        rx.close();
372    }
373}
374
375//==================== implementations ====================
376
377//---------- Sender ----------
378
379impl<T, C> Sink<T> for Sender<T, C>
380where
381    T: HasMemoryCost + Debug + Send + 'static,
382    C: ChannelSpec,
383{
384    type Error = SendError<C::SendError>;
385
386    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
387        self.get_mut()
388            .tx
389            .poll_ready_unpin(cx)
390            .map_err(SendError::Channel)
391    }
392
393    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
394        let self_ = self.get_mut();
395        let item = Entry {
396            t: item,
397            when: self_.runtime.now_coarse(),
398        };
399        self_.mq.try_claim(item, |item| {
400            self_.tx.start_send_unpin(item).map_err(SendError::Channel)
401        })?
402    }
403
404    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
405        self.tx
406            .poll_flush_unpin(cx)
407            .map(|r| r.map_err(SendError::Channel))
408    }
409
410    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
411        self.tx
412            .poll_close_unpin(cx)
413            .map(|r| r.map_err(SendError::Channel))
414    }
415}
416
417impl<T, C> SinkTrySend<T> for Sender<T, C>
418where
419    T: HasMemoryCost + Debug + Send + 'static,
420    C: ChannelSpec,
421    C::Sender<Entry<T>>: SinkTrySend<Entry<T>>,
422    <C::Sender<Entry<T>> as SinkTrySend<Entry<T>>>::Error: Send + Sync,
423{
424    type Error = ErasedSinkTrySendError;
425    fn try_send_or_return(
426        self: Pin<&mut Self>,
427        item: T,
428    ) -> Result<(), (<Self as SinkTrySend<T>>::Error, T)> {
429        let self_ = self.get_mut();
430        let item = Entry {
431            t: item,
432            when: self_.runtime.now_coarse(),
433        };
434
435        use ErasedSinkTrySendError as ESTSE;
436
437        self_
438            .mq
439            .try_claim_or_return(item, |item| {
440                Pin::new(&mut self_.tx).try_send_or_return(item)
441            })
442            .map_err(|(mqe, unsent)| (ESTSE::Other(Arc::new(mqe)), unsent.t))?
443            .map_err(|(tse, unsent)| (ESTSE::from(tse), unsent.t))
444    }
445}
446
447impl<T, C> SinkCloseChannel<T> for Sender<T, C>
448where
449    T: HasMemoryCost + Debug + Send, //Debug + 'static,
450    C: ChannelSpec,
451    C::Sender<Entry<T>>: SinkCloseChannel<Entry<T>>,
452{
453    fn close_channel(self: Pin<&mut Self>) {
454        Pin::new(&mut self.get_mut().tx).close_channel();
455    }
456}
457
458impl<T, C> Sender<T, C>
459where
460    T: Debug + Send + 'static,
461    C: ChannelSpec,
462{
463    /// Obtain a reference to the `Sender`'s [`DynTimeProvider`]
464    ///
465    /// (This can sometimes be used to avoid having to keep
466    /// a separate clone of the time provider.)
467    pub fn time_provider(&self) -> &DynTimeProvider {
468        &self.runtime
469    }
470}
471
472//---------- Receiver ----------
473
474impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> Stream for Receiver<T, C> {
475    type Item = T;
476
477    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
478        let mut state = self.inner.lock();
479        let state = match &mut *state {
480            Ok(y) => y,
481            Err(CollapsedDueToReclaim) => return Ready(None),
482        };
483        let ret = state.rx.poll_next_unpin(cx);
484        if let Ready(Some(item)) = &ret {
485            if let Some(enabled) = EnabledToken::new_if_compiled_in() {
486                let cost = item.typed_memory_cost(enabled);
487                state.mq.release(&cost);
488            }
489        }
490        ret.map(|r| r.map(|e| e.t))
491    }
492}
493
494impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> FusedStream for Receiver<T, C>
495where
496    C::Receiver<Entry<T>>: FusedStream,
497{
498    fn is_terminated(&self) -> bool {
499        match &*self.inner.lock() {
500            Ok(y) => y.rx.is_terminated(),
501            Err(CollapsedDueToReclaim) => true,
502        }
503    }
504}
505
506// TODO: When we have a trait for peekable streams, Receiver should implement it
507
508impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> Receiver<T, C> {
509    /// Register a callback, called when we tear the channel down
510    ///
511    /// This will be called when the `Receiver` is dropped,
512    /// or if we tear down because the memory system asks us to reclaim.
513    ///
514    /// `call` might be called at any time, from any thread, but
515    /// it won't be holding any locks relating to memory quota or the queue.
516    ///
517    /// If `self` is *already* in the process of being torn down,
518    /// `call` might be called immediately, reentrantly!
519    //
520    // This callback is nicer than us handing out an mpsc rx
521    // which user must read and convert items from.
522    //
523    // This method is on Receiver because that has the State,
524    // but could be called during setup to hook both sender's and
525    // receiver's shutdown mechanisms.
526    pub fn register_collapse_hook(&self, call: CollapseCallback) {
527        let mut state = self.inner.lock();
528        let state = match &mut *state {
529            Ok(y) => y,
530            Err(reason) => {
531                let reason = (*reason).into();
532                drop::<MutexGuard<_>>(state);
533                call(reason);
534                return;
535            }
536        };
537        state.collapse_callbacks.push(call);
538    }
539}
540
541impl<T: Debug + Send + 'static, C: ChannelSpec> ReceiverInner<T, C> {
542    /// Convenience function to take the lock
543    fn lock(&self) -> MutexGuard<Result<ReceiverState<T, C>, CollapsedDueToReclaim>> {
544        self.state.lock().expect("mq_mpsc lock poisoned")
545    }
546}
547
/// Lets the memory quota tracker see the age of our oldest data, and reclaim the queue
impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> IsParticipant
    for ReceiverInner<T, C>
{
    /// Return when the oldest item still queued was inserted
    ///
    /// Returns `None` if we have collapsed due to reclaim,
    /// or if the unobtrusive peek has no item to show us.
    fn get_oldest(&self, _: EnabledToken) -> Option<CoarseInstant> {
        let mut state = self.lock();
        let state = match &mut *state {
            Ok(y) => y,
            Err(CollapsedDueToReclaim) => return None,
        };
        // Peek rather than poll, so that the item remains for the receiver.
        // NOTE(review): binding to `peeked` rather than returning the expression
        // directly is presumably to end the borrow of the guarded state
        // before the guard is dropped; keep it this way.
        let peeked = Pin::new(&mut state.rx)
            .unobtrusive_peek()
            .map(|peeked| peeked.when);
        peeked
    }

    /// Tear the queue down in response to memory pressure
    ///
    /// Swaps the state for `Err(CollapsedDueToReclaim)`, so that the `Receiver`
    /// subsequently sees the stream as ended.  Then, with the lock released,
    /// runs the registered collapse hooks with `CollapseReason::MemoryReclaimed`
    /// and drops the old state, which drains the queue.
    ///
    /// If we had already collapsed, nothing further happens.
    /// Always reports `Reclaimed::Collapsing`.
    fn reclaim(self: Arc<Self>, _: EnabledToken) -> mtracker::ReclaimFuture {
        Box::pin(async move {
            let reason = CollapsedDueToReclaim;
            let mut state_guard = self.lock();
            let state = mem::replace(&mut *state_guard, Err(reason));
            // Release the lock before calling out, so that hooks never run
            // with our lock held (as `register_collapse_hook` promises).
            drop::<MutexGuard<_>>(state_guard);
            #[allow(clippy::single_match)] // pattern is intentional.
            match state {
                Ok(mut state) => {
                    for call in state.collapse_callbacks.drain(..) {
                        call(reason.into());
                    }
                    drop::<ReceiverState<_, _>>(state); // will drain queue, too
                }
                Err(CollapsedDueToReclaim) => {}
            };
            mtracker::Reclaimed::Collapsing
        })
    }
}
583
/// Tears down a receiver's state: destroys the participant, runs any hooks
/// still registered, and drains whatever remains in the inner queue
///
/// This happens when the last `Arc<ReceiverInner>` is dropped (dropping the
/// `Mutex` and the state within it), or when `reclaim` has taken the state
/// out of the mutex and drops it.
impl<T: Debug + Send + 'static, C: ChannelSpec> Drop for ReceiverState<T, C> {
    fn drop(&mut self) {
        // If there's a mutex, we're in its drop
        // (so nobody can be holding its lock while we call out below).

        // `destroy_participant` prevents the sender from making further non-cached claims
        // We only have `&mut self`, so we swap in a dangling placeholder
        // in order to take ownership of the real `Participation`.
        mem::replace(&mut self.mq, Participation::new_dangling().into())
            .into_raw()
            .destroy_participant();

        // Any hooks still registered are told the receiver was dropped.
        // (`reclaim` drains and calls the hooks itself before dropping us,
        // so in that case this list is already empty.)
        for call in self.collapse_callbacks.drain(..) {
            call(CollapseReason::ReceiverDropped);
        }

        // try to free whatever is in the queue, in case the stream doesn't do that itself
        // No-one can poll us any more, so we are no longer interested in wakeups
        let mut noop_cx = Context::from_waker(noop_waker_ref());

        // prevent further sends, so that our drain doesn't race indefinitely with the sender
        if let Some(mut rx_inner) =
            StreamUnobtrusivePeeker::as_raw_inner_pin_mut(Pin::new(&mut self.rx))
        {
            C::close_receiver(&mut rx_inner);
        }

        // Drop each remaining entry; stop at the first result that isn't
        // `Ready(Some(_))` (end of stream, or nothing available right now).
        while let Ready(Some(item)) = self.rx.poll_next_unpin(&mut noop_cx) {
            drop::<Entry<T>>(item);
        }
    }
}
613
614/// Method for educe's Debug impl for `ReceiverState.collapse_callbacks`
615fn receiver_state_debug_collapse_notify(
616    v: &[CollapseCallback],
617    f: &mut fmt::Formatter,
618) -> fmt::Result {
619    Debug::fmt(&v.len(), f)
620}
621
622//---------- misc ----------
623
624impl<T: HasMemoryCost> HasMemoryCost for Entry<T> {
625    fn memory_cost(&self, enabled: EnabledToken) -> usize {
626        let time_size = std::alloc::Layout::new::<CoarseInstant>().size();
627        self.t.memory_cost(enabled).saturating_add(time_size)
628    }
629}
630
631impl From<CollapsedDueToReclaim> for CollapseReason {
632    fn from(CollapsedDueToReclaim: CollapsedDueToReclaim) -> CollapseReason {
633        CollapseReason::MemoryReclaimed
634    }
635}
636
// Unit tests for the memory-quota-tracked queue.
//
// These only run with the `memquota` feature (so that claims are really tracked),
// and not under miri (the comment on the `cfg` suggests this is because of `coarsetime`).
#[cfg(all(test, feature = "memquota", not(miri) /* coarsetime */))]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    #![allow(clippy::arithmetic_side_effects)] // don't mind potential panicking ops in tests

    use super::*;
    use crate::mtracker::test::*;
    use tor_rtmock::MockRuntime;
    use tracing::debug;
    use tracing_test::traced_test;

    /// Counts live [`Item`]s, so tests can check whether queued items were dropped
    #[derive(Default, Debug)]
    struct ItemTracker {
        state: Mutex<ItemTrackerState>,
    }
    /// Counters protected by [`ItemTracker`]'s mutex
    #[derive(Default, Debug)]
    struct ItemTrackerState {
        /// Number of `Item`s currently alive
        existing: usize,
        /// Id to give to the next `Item`
        next_id: usize,
    }

    /// Test payload whose creation and drop are recorded in an [`ItemTracker`]
    #[derive(Debug)]
    struct Item {
        id: usize,
        tracker: Arc<ItemTracker>,
    }

    impl ItemTracker {
        /// Create a new `Item`, recording that it exists
        fn new_item(self: &Arc<Self>) -> Item {
            let mut state = self.lock();
            let id = state.next_id;
            state.existing += 1;
            state.next_id += 1;
            debug!("new {id}");
            Item {
                tracker: self.clone(),
                id,
            }
        }

        /// Make a fresh tracker, with no items
        fn new_tracker() -> Arc<Self> {
            Arc::default()
        }

        /// Lock the counters (panicking if poisoned; this is test code)
        fn lock(&self) -> MutexGuard<ItemTrackerState> {
            self.state.lock().unwrap()
        }
    }

    impl Drop for Item {
        fn drop(&mut self) {
            // Record that this item no longer exists
            debug!("old {}", self.id);
            self.tracker.state.lock().unwrap().existing -= 1;
        }
    }

    /// Every `Item` reports a cost of `mbytes(1)`
    impl HasMemoryCost for Item {
        fn memory_cost(&self, _: EnabledToken) -> usize {
            mbytes(1)
        }
    }

    /// Common fixtures: time provider, memory tracker, account, and item tracker
    struct Setup {
        dtp: DynTimeProvider,
        trk: Arc<mtracker::MemoryQuotaTracker>,
        acct: Account,
        itrk: Arc<ItemTracker>,
    }

    /// Build a [`Setup`] using the mock runtime `rt`
    fn setup(rt: &MockRuntime) -> Setup {
        let dtp = DynTimeProvider::new(rt.clone());
        let trk = mk_tracker(rt);
        let acct = trk.new_account(None).unwrap();
        let itrk = ItemTracker::new_tracker();
        Setup {
            dtp,
            trk,
            acct,
            itrk,
        }
    }

    /// A very costly item (`mbytes(100)`), used below to cause reclaim
    #[derive(Debug)]
    struct Gigantic;
    impl HasMemoryCost for Gigantic {
        fn memory_cost(&self, _et: EnabledToken) -> usize {
            mbytes(100)
        }
    }

    impl Setup {
        /// Check that claims and releases have balanced out
        ///
        /// `n_queues` is the number of queues that exist.
        /// This is used to provide some slop, since each queue has two [`Participation`]s
        /// each of which can have some cached claim.
        fn check_zero_claimed(&self, n_queues: usize) {
            let used = self.trk.used_current_approx();
            debug!(
                "checking zero balance (with slop {n_queues} * 2 * {}; used={used:?}",
                *mtracker::MAX_CACHE,
            );
            assert!(used.unwrap() <= n_queues * 2 * *mtracker::MAX_CACHE);
        }
    }

    /// Fill a queue until the tracker reclaims it, then check everything was torn down
    #[traced_test]
    #[test]
    fn lifecycle() {
        MockRuntime::test_with_various(|rt| async move {
            let s = setup(&rt);
            let (mut tx, mut rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();

            tx.send(s.itrk.new_item()).await.unwrap();
            let _: Item = rx.next().await.unwrap();

            for _ in 0..20 {
                tx.send(s.itrk.new_item()).await.unwrap();
            }

            // reclaim task hasn't had a chance to run
            debug!("still existing items {}", s.itrk.lock().existing);

            rt.advance_until_stalled().await;

            // reclaim task should have torn everything down
            assert!(s.itrk.lock().existing == 0);

            assert!(rx.next().await.is_none());

            // Empirically, this is a "disconnected" error from the inner mpsc,
            // but let's not assert that.
            let _: SendError<_> = tx.send(s.itrk.new_item()).await.unwrap_err();
        });
    }

    /// Send a batch of items, receive them all, and check that the claims balance out
    #[traced_test]
    #[test]
    fn fill_and_empty() {
        MockRuntime::test_with_various(|rt| async move {
            let s = setup(&rt);
            let (mut tx, mut rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();

            const COUNT: usize = 19;

            for _ in 0..COUNT {
                tx.send(s.itrk.new_item()).await.unwrap();
            }

            rt.advance_until_stalled().await;

            for _ in 0..COUNT {
                let _: Item = rx.next().await.unwrap();
            }

            rt.advance_until_stalled().await;

            // no memory should be claimed
            s.check_zero_claimed(1);
        });
    }

    /// Errors from the inner sink, and from memory quota collapse, are reported properly
    #[traced_test]
    #[test]
    fn sink_error() {
        // A sink whose `poll_ready` (and `try_send_or_return`) always fail with `error`
        #[derive(Debug, Copy, Clone)]
        struct BustedSink {
            error: BustedError,
        }

        impl<T> Sink<T> for BustedSink {
            type Error = BustedError;

            fn poll_ready(
                self: Pin<&mut Self>,
                _: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                Ready(Err(self.error))
            }
            fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
                panic!("poll_ready always gives error, start_send should not be called");
            }
            fn poll_flush(
                self: Pin<&mut Self>,
                _: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                Ready(Ok(()))
            }
            fn poll_close(
                self: Pin<&mut Self>,
                _: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                Ready(Ok(()))
            }
        }

        impl<T> SinkTrySend<T> for BustedSink {
            type Error = BustedError;

            fn try_send_or_return(self: Pin<&mut Self>, item: T) -> Result<(), (BustedError, T)> {
                Err((self.error, item))
            }
        }

        impl tor_async_utils::SinkTrySendError for BustedError {
            fn is_disconnected(&self) -> bool {
                self.is_disconnected
            }
            fn is_full(&self) -> bool {
                false
            }
        }

        // The error `BustedSink` reports; `is_disconnected` selects how it is classified
        #[derive(Error, Debug, Clone, Copy)]
        #[error("busted, for testing, dc={is_disconnected:?}")]
        struct BustedError {
            is_disconnected: bool,
        }

        // Channel spec whose sender is a `BustedSink` and whose receiver never yields
        struct BustedQueueSpec {
            error: BustedError,
        }
        impl Sealed for BustedQueueSpec {}
        impl ChannelSpec for BustedQueueSpec {
            type Sender<T: Debug + Send + 'static> = BustedSink;
            type Receiver<T: Debug + Send + 'static> = futures::stream::Pending<T>;
            type SendError = BustedError;
            fn raw_channel<T: Debug + Send + 'static>(self) -> (BustedSink, Self::Receiver<T>) {
                (BustedSink { error: self.error }, futures::stream::pending())
            }
            fn close_receiver<T: Debug + Send + 'static>(_rx: &mut Self::Receiver<T>) {}
        }

        use ErasedSinkTrySendError as ESTSE;

        MockRuntime::test_with_various(|rt| async move {
            let error = BustedError {
                is_disconnected: true,
            };

            let s = setup(&rt);
            let (mut tx, _rx) = BustedQueueSpec { error }
                .new_mq(s.dtp.clone(), &s.acct)
                .unwrap();

            let e = tx.send(s.itrk.new_item()).await.unwrap_err();
            assert!(matches!(e, SendError::Channel(BustedError { .. })));

            // item should have been destroyed
            assert_eq!(s.itrk.lock().existing, 0);

            // ---- Test try_send error handling ----

            // Ok if `e` is `ESTSE::Other` wrapping an `E`; otherwise returns `e` as the error
            fn error_is_other_of<E>(e: ESTSE) -> Result<(), impl Debug>
            where
                E: std::error::Error + 'static,
            {
                match e {
                    ESTSE::Other(e) if e.is::<E>() => Ok(()),
                    other => Err(other),
                }
            }

            let item = s.itrk.new_item();

            // Test try_send failure due to BustedError, is_disconnected: true

            let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
            assert!(matches!(e, ESTSE::Disconnected), "{e:?}");

            // Test try_send failure due to BustedError, is_disconnected: false (ie, Other)

            let error = BustedError {
                is_disconnected: false,
            };
            let (mut tx, _rx) = BustedQueueSpec { error }
                .new_mq(s.dtp.clone(), &s.acct)
                .unwrap();
            let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
            error_is_other_of::<BustedError>(e).unwrap();

            // no memory should be claimed
            s.check_zero_claimed(1);

            // Test try_send failure due to memory quota collapse

            // cause reclaim
            // (The assertion below relies on this also collapsing the busted queue
            // `tx` on the same account, so that its next claim fails.)
            {
                let (mut tx, _rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();
                tx.send(Gigantic).await.unwrap();
                rt.advance_until_stalled().await;
            }

            let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
            error_is_other_of::<crate::Error>(e).unwrap();

            drop::<Item>(item);
        });
    }
}