#![cfg_attr(
    feature = "memquota",
    doc = "let config = tor_memquota::Config::builder().max(1024*1024*1024).build().unwrap();",
    doc = "let trk = MemoryQuotaTracker::new(&runtime, config).unwrap();"
)]
#![cfg_attr(
    not(feature = "memquota"),
    doc = "let trk = MemoryQuotaTracker::new_noop();"
)]
#![forbid(unsafe_code)]

use crate::internal_prelude::*;

use std::task::{Context, Poll, Poll::*};
use tor_async_utils::peekable_stream::UnobtrusivePeekableStream;
use tor_async_utils::{ErasedSinkTrySendError, SinkCloseChannel, SinkTrySend};
/// Sender for a memory-quota-tracked queue.
///
/// Implements [`Sink`]; each item's memory cost is claimed against the queue's
/// [`Account`] before the item is handed to the underlying channel.
#[derive(Educe)]
#[educe(Debug, Clone(bound = "C::Sender<Entry<T>>: Clone"))]
pub struct Sender<T: Debug + Send + 'static, C: ChannelSpec> {
    /// The underlying channel's sender.
    tx: C::Sender<Entry<T>>,

    /// Our memory quota participation, used to claim the cost of each entry.
    mq: TypedParticipation<Entry<T>>,

    /// Time provider, used to timestamp entries as they are enqueued.
    #[educe(Debug(ignore))]
    runtime: DynTimeProvider,
}

/// Receiver for a memory-quota-tracked queue.
///
/// Implements [`Stream`]; each dequeued item releases its claimed memory cost.
#[derive(Educe)]
#[educe(Debug)]
pub struct Receiver<T: Debug + Send + 'static, C: ChannelSpec> {
    /// Shared state, also registered with the memory tracker as the participant.
    inner: Arc<ReceiverInner<T, C>>,
}

/// Shared state for a [`Receiver`], kept behind a mutex.
#[derive(Educe)]
#[educe(Debug)]
struct ReceiverInner<T: Debug + Send + 'static, C: ChannelSpec> {
    /// Becomes `Err(CollapsedDueToReclaim)` once the tracker has reclaimed this queue.
    state: Mutex<Result<ReceiverState<T, C>, CollapsedDueToReclaim>>,
}

/// Live state of a queue that has not (yet) collapsed.
#[derive(Educe)]
#[educe(Debug)]
struct ReceiverState<T: Debug + Send + 'static, C: ChannelSpec> {
    /// The underlying channel's receiver, wrapped so the tracker can peek at
    /// the oldest queued entry without disturbing the stream.
    rx: StreamUnobtrusivePeeker<C::Receiver<Entry<T>>>,

    /// Our memory quota participation, used to release the cost of dequeued entries.
    mq: TypedParticipation<Entry<T>>,

    /// Callbacks to run when the queue collapses.
    #[educe(Debug(method = "receiver_state_debug_collapse_notify"))]
    collapse_callbacks: Vec<CollapseCallback>,
}

/// A queue entry: the item, plus the time at which it was enqueued.
///
/// The timestamp lets the memory tracker determine the age of the oldest queued data.
#[derive(Debug)]
struct Entry<T> {
    /// The actual item.
    t: T,
    /// When it was enqueued.
    when: CoarseInstant,
}

/// Error returned when sending on a memory-quota-tracked queue fails.
#[derive(Error, Clone, Debug)]
#[non_exhaustive]
pub enum SendError<CE> {
    /// The underlying channel rejected the item.
    #[error("channel send failed")]
    Channel(#[source] CE),

    /// The memory quota was exhausted and the queue has been reclaimed.
    #[error("memory quota exhausted, queue reclaimed")]
    Memquota(#[from] Error),
}
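
// Illustrative sketch (hedged; `tx` and `item` are stand-ins) of handling a failed send:
//
//     match tx.send(item).await {
//         Ok(()) => {}
//         Err(SendError::Channel(_e)) => { /* the underlying channel rejected the item */ }
//         Err(SendError::Memquota(_e)) => { /* quota exhausted; the queue has collapsed */ }
//     }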

/// Callback invoked when a queue collapses.
pub type CollapseCallback = Box<dyn FnOnce(CollapseReason) + Send + Sync + 'static>;

/// Why a queue collapsed.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
#[non_exhaustive]
pub enum CollapseReason {
    /// The receiver was dropped.
    ReceiverDropped,

    /// The memory tracker reclaimed the queue's memory.
    MemoryReclaimed,
}

/// Marker stored in place of the receiver state once the queue has been reclaimed.
#[derive(Debug, Clone, Copy)]
struct CollapsedDueToReclaim;

/// Specification for the underlying channel implementation.
///
/// Implemented by [`MpscSpec`] and [`MpscUnboundedSpec`].  The provided
/// [`new_mq`](ChannelSpec::new_mq) method wraps the raw channel in memory-quota tracking.
pub trait ChannelSpec: Sealed + Sized + 'static {
    /// The sender half of the raw channel.
    type Sender<T: Debug + Send + 'static>: Sink<T, Error = Self::SendError>
        + Debug
        + Unpin
        + Sized;

    /// The receiver half of the raw channel.
    type Receiver<T: Debug + Send + 'static>: Stream<Item = T> + Debug + Unpin + Send + Sized;

    /// The error type returned by the raw channel's sender.
    type SendError: std::error::Error;

    /// Create a new memory-quota-tracked queue, as a participant in `account`.
    #[allow(clippy::type_complexity)]
    fn new_mq<T>(
        self,
        runtime: DynTimeProvider,
        account: &Account,
    ) -> crate::Result<(Sender<T, Self>, Receiver<T, Self>)>
    where
        T: HasMemoryCost + Debug + Send + 'static,
    {
        let (rx, (tx, mq)) = account.register_participant_with(
            runtime.now_coarse(),
            move |mq| {
                let mq = TypedParticipation::new(mq);
                let collapse_callbacks = vec![];
                let (tx, rx) = self.raw_channel::<Entry<T>>();
                let rx = StreamUnobtrusivePeeker::new(rx);
                let state = ReceiverState {
                    rx,
                    mq: mq.clone(),
                    collapse_callbacks,
                };
                let state = Mutex::new(Ok(state));
                let inner = ReceiverInner { state };
                Ok::<_, crate::Error>((inner.into(), (tx, mq)))
            },
        )??;

        let tx = Sender { runtime, tx, mq };
        let rx = Receiver { inner: rx };

        Ok((tx, rx))
    }

    /// Create the raw (untracked) channel.
    fn raw_channel<T: Debug + Send + 'static>(self) -> (Self::Sender<T>, Self::Receiver<T>);

    /// Close the channel's receiver, without dropping it.
    fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>);
}
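
// Illustrative sketch of building a tracked queue from a `ChannelSpec` (hedged;
// `rt`, `trk`, and `MyItem: HasMemoryCost + Debug + Send + 'static` are stand-ins,
// set up much as in the tests at the bottom of this file):
//
//     let runtime = DynTimeProvider::new(rt.clone());
//     let account = trk.new_account(None)?;
//     let (mut tx, mut rx) = MpscSpec { buffer: 8 }.new_mq::<MyItem>(runtime, &account)?;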

/// Specification for a bounded MPSC channel, with the given buffer size.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Constructor)]
#[allow(clippy::exhaustive_structs)]
pub struct MpscSpec {
    /// Buffer size for the underlying channel.
    pub buffer: usize,
}

/// Specification for an unbounded MPSC channel.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Constructor, Default)]
#[allow(clippy::exhaustive_structs)]
pub struct MpscUnboundedSpec;

impl Sealed for MpscSpec {}
impl Sealed for MpscUnboundedSpec {}

impl ChannelSpec for MpscSpec {
    type Sender<T: Debug + Send + 'static> = mpsc::Sender<T>;
    type Receiver<T: Debug + Send + 'static> = mpsc::Receiver<T>;
    type SendError = mpsc::SendError;

    fn raw_channel<T: Debug + Send + 'static>(self) -> (mpsc::Sender<T>, mpsc::Receiver<T>) {
        mpsc_channel_no_memquota(self.buffer)
    }

    fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>) {
        rx.close();
    }
}

impl ChannelSpec for MpscUnboundedSpec {
    type Sender<T: Debug + Send + 'static> = mpsc::UnboundedSender<T>;
    type Receiver<T: Debug + Send + 'static> = mpsc::UnboundedReceiver<T>;
    type SendError = mpsc::SendError;

    fn raw_channel<T: Debug + Send + 'static>(self) -> (Self::Sender<T>, Self::Receiver<T>) {
        mpsc::unbounded()
    }

    fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>) {
        rx.close();
    }
}
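
// Sketch of constructing the two stock specs (hedged: assumes the `Constructor`
// derive provides `MpscSpec::new`):
//
//     let bounded = MpscSpec::new(8);     // same as MpscSpec { buffer: 8 }
//     let unbounded = MpscUnboundedSpec;  // no buffer limit in the channel itself;
//                                         // the memory quota still limits what can accumulate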

impl<T, C> Sink<T> for Sender<T, C>
where
    T: HasMemoryCost + Debug + Send + 'static,
    C: ChannelSpec,
{
    type Error = SendError<C::SendError>;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.get_mut()
            .tx
            .poll_ready_unpin(cx)
            .map_err(SendError::Channel)
    }

    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        let self_ = self.get_mut();
        let item = Entry {
            t: item,
            when: self_.runtime.now_coarse(),
        };
        // Claim the entry's memory cost first; only if that succeeds is the entry
        // handed to the underlying channel.  The trailing `?` reports a failed
        // claim as `SendError::Memquota` (via `From`).
        self_.mq.try_claim(item, |item| {
            self_.tx.start_send_unpin(item).map_err(SendError::Channel)
        })?
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.tx
            .poll_flush_unpin(cx)
            .map(|r| r.map_err(SendError::Channel))
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.tx
            .poll_close_unpin(cx)
            .map(|r| r.map_err(SendError::Channel))
    }
}
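
// Sketch of the sending side (hedged; `tx` is a `Sender` from `new_mq`, `my_item`
// is a stand-in, and `futures::SinkExt` is in scope):
//
//     tx.send(my_item).await?;
//     // `start_send` claimed `my_item`'s memory cost; it is released again when
//     // the receiver dequeues the entry (or the queue collapses).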

impl<T, C> SinkTrySend<T> for Sender<T, C>
where
    T: HasMemoryCost + Debug + Send + 'static,
    C: ChannelSpec,
    C::Sender<Entry<T>>: SinkTrySend<Entry<T>>,
    <C::Sender<Entry<T>> as SinkTrySend<Entry<T>>>::Error: Send + Sync,
{
    type Error = ErasedSinkTrySendError;

    fn try_send_or_return(
        self: Pin<&mut Self>,
        item: T,
    ) -> Result<(), (<Self as SinkTrySend<T>>::Error, T)> {
        let self_ = self.get_mut();
        let item = Entry {
            t: item,
            when: self_.runtime.now_coarse(),
        };

        use ErasedSinkTrySendError as ESTSE;

        // On failure the unsent item is handed back to the caller; memquota errors
        // are reported as `ESTSE::Other`, channel errors via their own conversion.
        self_
            .mq
            .try_claim_or_return(item, |item| {
                Pin::new(&mut self_.tx).try_send_or_return(item)
            })
            .map_err(|(mqe, unsent)| (ESTSE::Other(Arc::new(mqe)), unsent.t))?
            .map_err(|(tse, unsent)| (ESTSE::from(tse), unsent.t))
    }
}
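
// Sketch of a non-blocking send (cf. the `sink_error` test below); errors come back
// as `ErasedSinkTrySendError` together with the unsent item.  The `Full` variant
// name is assumed here; `Disconnected` and `Other` appear in the test.
//
//     match Pin::new(&mut tx).try_send_or_return(my_item) {
//         Ok(()) => {}
//         Err((ErasedSinkTrySendError::Full, _item)) => { /* retry later */ }
//         Err((ErasedSinkTrySendError::Disconnected, _item)) => { /* receiver is gone */ }
//         Err((_other, _item)) => { /* e.g. a memquota error, reported as Other */ }
//     }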

impl<T, C> SinkCloseChannel<T> for Sender<T, C>
where
    T: HasMemoryCost + Debug + Send + 'static,
    C: ChannelSpec,
    C::Sender<Entry<T>>: SinkCloseChannel<Entry<T>>,
{
    fn close_channel(self: Pin<&mut Self>) {
        Pin::new(&mut self.get_mut().tx).close_channel();
    }
}

impl<T, C> Sender<T, C>
where
    T: Debug + Send + 'static,
    C: ChannelSpec,
{
    /// Return the time provider used by this sender to timestamp entries.
    pub fn time_provider(&self) -> &DynTimeProvider {
        &self.runtime
    }
}

impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> Stream for Receiver<T, C> {
    type Item = T;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut state = self.inner.lock();
        let state = match &mut *state {
            Ok(y) => y,
            // A reclaimed queue reports end-of-stream.
            Err(CollapsedDueToReclaim) => return Ready(None),
        };
        let ret = state.rx.poll_next_unpin(cx);
        // Release the memory claimed for each entry as it is dequeued.
        if let Ready(Some(item)) = &ret {
            if let Some(enabled) = EnabledToken::new_if_compiled_in() {
                let cost = item.typed_memory_cost(enabled);
                state.mq.release(&cost);
            }
        }
        ret.map(|r| r.map(|e| e.t))
    }
}
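
// Sketch of the receiving side (hedged; `rx` is the `Receiver` from `new_mq`, with
// `futures::StreamExt` in scope):
//
//     while let Some(item) = rx.next().await {
//         // each dequeued item releases its claimed memory cost
//     }
//     // `None` means the senders are gone *or* the queue was reclaimed (collapsed)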

impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> FusedStream for Receiver<T, C>
where
    C::Receiver<Entry<T>>: FusedStream,
{
    fn is_terminated(&self) -> bool {
        match &*self.inner.lock() {
            Ok(y) => y.rx.is_terminated(),
            Err(CollapsedDueToReclaim) => true,
        }
    }
}

impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> Receiver<T, C> {
    /// Register a callback to be run when the queue collapses.
    ///
    /// If the queue has already collapsed, the callback is invoked immediately.
    pub fn register_collapse_hook(&self, call: CollapseCallback) {
        let mut state = self.inner.lock();
        let state = match &mut *state {
            Ok(y) => y,
            Err(reason) => {
                let reason = (*reason).into();
                // Don't hold the lock while running the caller's callback.
                drop::<MutexGuard<_>>(state);
                call(reason);
                return;
            }
        };
        state.collapse_callbacks.push(call);
    }
}
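
// Sketch of watching for collapse (hedged; `rx` is a `Receiver`):
//
//     rx.register_collapse_hook(Box::new(|reason| match reason {
//         CollapseReason::ReceiverDropped => { /* normal teardown */ }
//         CollapseReason::MemoryReclaimed => { /* the tracker reclaimed this queue */ }
//         _ => {}
//     }));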

impl<T: Debug + Send + 'static, C: ChannelSpec> ReceiverInner<T, C> {
    /// Lock the state, panicking if the mutex has been poisoned.
    fn lock(&self) -> MutexGuard<Result<ReceiverState<T, C>, CollapsedDueToReclaim>> {
        self.state.lock().expect("mq_mpsc lock poisoned")
    }
}

impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> IsParticipant
    for ReceiverInner<T, C>
{
    fn get_oldest(&self, _: EnabledToken) -> Option<CoarseInstant> {
        let mut state = self.lock();
        let state = match &mut *state {
            Ok(y) => y,
            Err(CollapsedDueToReclaim) => return None,
        };
        // Report the enqueue time of the oldest queued entry, if any.
        Pin::new(&mut state.rx)
            .unobtrusive_peek()
            .map(|peeked| peeked.when)
    }

    fn reclaim(self: Arc<Self>, _: EnabledToken) -> mtracker::ReclaimFuture {
        Box::pin(async move {
            let reason = CollapsedDueToReclaim;
            // Swap the live state out for the "collapsed" marker, then drop it
            // (running the collapse callbacks) outside the lock.
            let mut state_guard = self.lock();
            let state = mem::replace(&mut *state_guard, Err(reason));
            drop::<MutexGuard<_>>(state_guard);
            #[allow(clippy::single_match)]
            match state {
                Ok(mut state) => {
                    for call in state.collapse_callbacks.drain(..) {
                        call(reason.into());
                    }
                    drop::<ReceiverState<_, _>>(state);
                }
                Err(CollapsedDueToReclaim) => {}
            };
            mtracker::Reclaimed::Collapsing
        })
    }
}

impl<T: Debug + Send + 'static, C: ChannelSpec> Drop for ReceiverState<T, C> {
    fn drop(&mut self) {
        // Tell the tracker this participant is gone, releasing everything we claimed.
        mem::replace(&mut self.mq, Participation::new_dangling().into())
            .into_raw()
            .destroy_participant();

        for call in self.collapse_callbacks.drain(..) {
            call(CollapseReason::ReceiverDropped);
        }

        // Close the channel and discard any entries still queued, so that senders
        // see the queue as gone and the entries' destructors run now.
        let mut noop_cx = Context::from_waker(noop_waker_ref());

        if let Some(mut rx_inner) =
            StreamUnobtrusivePeeker::as_raw_inner_pin_mut(Pin::new(&mut self.rx))
        {
            C::close_receiver(&mut rx_inner);
        }

        while let Ready(Some(item)) = self.rx.poll_next_unpin(&mut noop_cx) {
            drop::<Entry<T>>(item);
        }
    }
}

/// Debug helper: print only the number of registered collapse callbacks.
fn receiver_state_debug_collapse_notify(
    v: &[CollapseCallback],
    f: &mut fmt::Formatter,
) -> fmt::Result {
    Debug::fmt(&v.len(), f)
}

impl<T: HasMemoryCost> HasMemoryCost for Entry<T> {
    fn memory_cost(&self, enabled: EnabledToken) -> usize {
        // An entry costs whatever the item costs, plus the size of the timestamp.
        let time_size = std::alloc::Layout::new::<CoarseInstant>().size();
        self.t.memory_cost(enabled).saturating_add(time_size)
    }
}
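
// Sketch of a `HasMemoryCost` impl for a queued item type (cf. `Item` and `Gigantic`
// in the tests below); the cost reported here is what each queued entry claims,
// plus the timestamp overhead added above.  `MyItem` is a hypothetical type:
//
//     struct MyItem { payload: Vec<u8> }
//     impl HasMemoryCost for MyItem {
//         fn memory_cost(&self, _: EnabledToken) -> usize {
//             std::mem::size_of::<Self>() + self.payload.capacity()
//         }
//     }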

impl From<CollapsedDueToReclaim> for CollapseReason {
    fn from(CollapsedDueToReclaim: CollapsedDueToReclaim) -> CollapseReason {
        CollapseReason::MemoryReclaimed
    }
}

#[cfg(all(test, feature = "memquota", not(miri)))]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::arithmetic_side_effects)]

    use super::*;
    use crate::mtracker::test::*;
    use tor_rtmock::MockRuntime;
    use tracing::debug;
    use tracing_test::traced_test;

    #[derive(Default, Debug)]
    struct ItemTracker {
        state: Mutex<ItemTrackerState>,
    }
    #[derive(Default, Debug)]
    struct ItemTrackerState {
        existing: usize,
        next_id: usize,
    }

    #[derive(Debug)]
    struct Item {
        id: usize,
        tracker: Arc<ItemTracker>,
    }

    impl ItemTracker {
        fn new_item(self: &Arc<Self>) -> Item {
            let mut state = self.lock();
            let id = state.next_id;
            state.existing += 1;
            state.next_id += 1;
            debug!("new {id}");
            Item {
                tracker: self.clone(),
                id,
            }
        }

        fn new_tracker() -> Arc<Self> {
            Arc::default()
        }

        fn lock(&self) -> MutexGuard<ItemTrackerState> {
            self.state.lock().unwrap()
        }
    }

    impl Drop for Item {
        fn drop(&mut self) {
            debug!("old {}", self.id);
            self.tracker.state.lock().unwrap().existing -= 1;
        }
    }

    impl HasMemoryCost for Item {
        fn memory_cost(&self, _: EnabledToken) -> usize {
            mbytes(1)
        }
    }

    struct Setup {
        dtp: DynTimeProvider,
        trk: Arc<mtracker::MemoryQuotaTracker>,
        acct: Account,
        itrk: Arc<ItemTracker>,
    }

    fn setup(rt: &MockRuntime) -> Setup {
        let dtp = DynTimeProvider::new(rt.clone());
        let trk = mk_tracker(rt);
        let acct = trk.new_account(None).unwrap();
        let itrk = ItemTracker::new_tracker();
        Setup {
            dtp,
            trk,
            acct,
            itrk,
        }
    }

    #[derive(Debug)]
    struct Gigantic;
    impl HasMemoryCost for Gigantic {
        fn memory_cost(&self, _et: EnabledToken) -> usize {
            mbytes(100)
        }
    }

    impl Setup {
        /// Assert that (apart from per-participant cache slop) nothing is still claimed.
        fn check_zero_claimed(&self, n_queues: usize) {
            let used = self.trk.used_current_approx();
            debug!(
                "checking zero balance (with slop {n_queues} * 2 * {}); used={used:?}",
                *mtracker::MAX_CACHE,
            );
            assert!(used.unwrap() <= n_queues * 2 * *mtracker::MAX_CACHE);
        }
    }

    #[traced_test]
    #[test]
    fn lifecycle() {
        MockRuntime::test_with_various(|rt| async move {
            let s = setup(&rt);
            let (mut tx, mut rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();

            tx.send(s.itrk.new_item()).await.unwrap();
            let _: Item = rx.next().await.unwrap();

            // Overrun the quota; the tracker should reclaim (collapse) this queue.
            for _ in 0..20 {
                tx.send(s.itrk.new_item()).await.unwrap();
            }

            debug!("still existing items {}", s.itrk.lock().existing);

            rt.advance_until_stalled().await;

            // After reclaim, every queued item has been dropped ...
            assert!(s.itrk.lock().existing == 0);

            // ... the receiver reports end-of-stream ...
            assert!(rx.next().await.is_none());

            // ... and further sends fail.
            let _: SendError<_> = tx.send(s.itrk.new_item()).await.unwrap_err();
        });
    }

    #[traced_test]
    #[test]
    fn fill_and_empty() {
        MockRuntime::test_with_various(|rt| async move {
            let s = setup(&rt);
            let (mut tx, mut rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();

            const COUNT: usize = 19;

            for _ in 0..COUNT {
                tx.send(s.itrk.new_item()).await.unwrap();
            }

            rt.advance_until_stalled().await;

            for _ in 0..COUNT {
                let _: Item = rx.next().await.unwrap();
            }

            rt.advance_until_stalled().await;

            // Everything we claimed should have been released again.
            s.check_zero_claimed(1);
        });
    }

    #[traced_test]
    #[test]
    fn sink_error() {
        #[derive(Debug, Copy, Clone)]
        struct BustedSink {
            error: BustedError,
        }

        impl<T> Sink<T> for BustedSink {
            type Error = BustedError;

            fn poll_ready(
                self: Pin<&mut Self>,
                _: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                Ready(Err(self.error))
            }
            fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
                panic!("poll_ready always gives error, start_send should not be called");
            }
            fn poll_flush(
                self: Pin<&mut Self>,
                _: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                Ready(Ok(()))
            }
            fn poll_close(
                self: Pin<&mut Self>,
                _: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                Ready(Ok(()))
            }
        }

        impl<T> SinkTrySend<T> for BustedSink {
            type Error = BustedError;

            fn try_send_or_return(self: Pin<&mut Self>, item: T) -> Result<(), (BustedError, T)> {
                Err((self.error, item))
            }
        }

        impl tor_async_utils::SinkTrySendError for BustedError {
            fn is_disconnected(&self) -> bool {
                self.is_disconnected
            }
            fn is_full(&self) -> bool {
                false
            }
        }

        #[derive(Error, Debug, Clone, Copy)]
        #[error("busted, for testing, dc={is_disconnected:?}")]
        struct BustedError {
            is_disconnected: bool,
        }

        struct BustedQueueSpec {
            error: BustedError,
        }
        impl Sealed for BustedQueueSpec {}
        impl ChannelSpec for BustedQueueSpec {
            type Sender<T: Debug + Send + 'static> = BustedSink;
            type Receiver<T: Debug + Send + 'static> = futures::stream::Pending<T>;
            type SendError = BustedError;
            fn raw_channel<T: Debug + Send + 'static>(self) -> (BustedSink, Self::Receiver<T>) {
                (BustedSink { error: self.error }, futures::stream::pending())
            }
            fn close_receiver<T: Debug + Send + 'static>(_rx: &mut Self::Receiver<T>) {}
        }

        use ErasedSinkTrySendError as ESTSE;

        MockRuntime::test_with_various(|rt| async move {
            let error = BustedError {
                is_disconnected: true,
            };

            let s = setup(&rt);
            let (mut tx, _rx) = BustedQueueSpec { error }
                .new_mq(s.dtp.clone(), &s.acct)
                .unwrap();

            let e = tx.send(s.itrk.new_item()).await.unwrap_err();
            assert!(matches!(e, SendError::Channel(BustedError { .. })));

            // The failed send dropped the item.
            assert_eq!(s.itrk.lock().existing, 0);

            fn error_is_other_of<E>(e: ESTSE) -> Result<(), impl Debug>
            where
                E: std::error::Error + 'static,
            {
                match e {
                    ESTSE::Other(e) if e.is::<E>() => Ok(()),
                    other => Err(other),
                }
            }

            let item = s.itrk.new_item();

            // A disconnected channel error maps to `ESTSE::Disconnected`.
            let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
            assert!(matches!(e, ESTSE::Disconnected), "{e:?}");

            // Any other channel error maps to `ESTSE::Other`.
            let error = BustedError {
                is_disconnected: false,
            };
            let (mut tx, _rx) = BustedQueueSpec { error }
                .new_mq(s.dtp.clone(), &s.acct)
                .unwrap();
            let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
            error_is_other_of::<BustedError>(e).unwrap();

            s.check_zero_claimed(1);

            // Trigger memory reclaim by blowing the quota on another queue
            // that shares the same account.
            {
                let (mut tx, _rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();
                tx.send(Gigantic).await.unwrap();
                rt.advance_until_stalled().await;
            }

            // Now the memquota claim itself fails, and that too is `ESTSE::Other`.
            let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
            error_is_other_of::<crate::Error>(e).unwrap();

            drop::<Item>(item);
        });
    }
}