1use std::future::{Future, IntoFuture};
13use std::ops::Drop;
14use std::pin::Pin;
15use std::sync::{Arc, Mutex, OnceLock, Weak};
16use std::task::{ready, Context, Poll, Waker};
17
18use slotmap_careful::DenseSlotMap;
19
20slotmap_careful::new_key_type! { struct WakerKey; }
21
22#[derive(Debug)]
24pub(crate) struct Sender<T> {
25 shared: Weak<Shared<T>>,
27}
28
29#[derive(Clone, Debug)]
47pub(crate) struct Receiver<T> {
48 shared: Arc<Shared<T>>,
50}
51
52#[derive(Debug)]
75struct Shared<T> {
76 msg: OnceLock<Result<T, SenderDropped>>,
78 wakers: Mutex<Result<DenseSlotMap<WakerKey, Waker>, WakersAlreadyWoken>>,
84}
85
86#[derive(Debug)]
91pub(crate) struct BorrowedReceiverFuture<'a, T> {
92 shared: &'a Shared<T>,
94 waker_key: Option<WakerKey>,
96}
97
98#[derive(Debug)]
109pub(crate) struct ReceiverFuture<T> {
110 shared: Arc<Shared<T>>,
112 waker_key: Option<WakerKey>,
114}
115
/// Marker left in `Shared::wakers` once the sender has taken all registered wakers to wake them.
#[derive(Copy, Clone, Debug)]
struct WakersAlreadyWoken;
124
125#[derive(Copy, Clone, Debug, thiserror::Error)]
127#[error("the message was already set")]
128struct MessageAlreadySet;
129
130#[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
132#[error("the sender was dropped")]
133pub(crate) struct SenderDropped;
134
135pub(crate) fn channel<T>() -> (Sender<T>, Receiver<T>) {
145 let shared = Arc::new(Shared {
146 msg: OnceLock::new(),
147 wakers: Mutex::new(Ok(DenseSlotMap::with_key())),
148 });
149
150 let sender = Sender {
151 shared: Arc::downgrade(&shared),
152 };
153
154 let receiver = Receiver { shared };
155
156 (sender, receiver)
157}
158
159impl<T> Sender<T> {
160 #[cfg_attr(not(test), allow(dead_code))]
164 pub(crate) fn send(self, msg: T) {
165 Self::send_and_wake(&self.shared, Ok(msg))
167 .expect("could not set the message");
171 }
172
173 fn send_and_wake(
179 shared: &Weak<Shared<T>>,
180 msg: Result<T, SenderDropped>,
181 ) -> Result<(), MessageAlreadySet> {
182 let Some(shared) = shared.upgrade() else {
187 return Ok(());
189 };
190
191 shared.msg.set(msg).or(Err(MessageAlreadySet))?;
193
194 let mut wakers = {
195 let mut wakers = shared.wakers.lock().expect("poisoned");
196 std::mem::replace(&mut *wakers, Err(WakersAlreadyWoken))
205 .expect("wakers were taken more than once")
206 };
207
208 for (_key, waker) in wakers.drain() {
218 waker.wake();
219 }
220
221 Ok(())
222 }
223
224 #[cfg_attr(not(test), allow(dead_code))]
234 pub(crate) fn is_cancelled(&self) -> bool {
235 self.shared.strong_count() == 0
236 }
237}
238
239impl<T> Drop for Sender<T> {
240 fn drop(&mut self) {
241 let _ = Self::send_and_wake(&self.shared, Err(SenderDropped));
245 }
246}
247
248impl<T> Receiver<T> {
249 #[cfg_attr(not(test), allow(dead_code))]
256 pub(crate) fn borrowed(&self) -> BorrowedReceiverFuture<'_, T> {
257 BorrowedReceiverFuture {
258 shared: &self.shared,
259 waker_key: None,
260 }
261 }
262
263 pub(crate) fn is_ready(&self) -> bool {
267 self.shared.msg.get().is_some()
268 }
269}
270
271impl<T: Clone> IntoFuture for Receiver<T> {
272 type Output = Result<T, SenderDropped>;
273 type IntoFuture = ReceiverFuture<T>;
274
275 fn into_future(self) -> Self::IntoFuture {
277 ReceiverFuture {
278 shared: self.shared,
279 waker_key: None,
280 }
281 }
282}
283
284impl<'a, T> Future for BorrowedReceiverFuture<'a, T> {
285 type Output = Result<&'a T, SenderDropped>;
286
287 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
288 let self_ = self.get_mut();
289 receiver_fut_poll(self_.shared, &mut self_.waker_key, cx.waker())
290 }
291}
292
293impl<T> Drop for BorrowedReceiverFuture<'_, T> {
294 fn drop(&mut self) {
295 receiver_fut_drop(self.shared, &mut self.waker_key);
296 }
297}
298
299impl<T: Clone> Future for ReceiverFuture<T> {
300 type Output = Result<T, SenderDropped>;
301
302 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
303 let self_ = self.get_mut();
304 let poll = receiver_fut_poll(&self_.shared, &mut self_.waker_key, cx.waker());
305 Poll::Ready(ready!(poll)).map_ok(Clone::clone)
306 }
307}
308
309impl<T> Drop for ReceiverFuture<T> {
310 fn drop(&mut self) {
311 receiver_fut_drop(&self.shared, &mut self.waker_key);
312 }
313}
314
315fn receiver_fut_poll<'a, T>(
317 shared: &'a Shared<T>,
318 waker_key: &mut Option<WakerKey>,
319 new_waker: &Waker,
320) -> Poll<Result<&'a T, SenderDropped>> {
321 if let Some(msg) = shared.msg.get() {
323 return Poll::Ready(msg.as_ref().or(Err(SenderDropped)));
324 }
325
326 let mut wakers = shared.wakers.lock().expect("poisoned");
327
328 if let Some(msg) = shared.msg.get() {
330 return Poll::Ready(msg.as_ref().or(Err(SenderDropped)));
331 }
332
333 let wakers = wakers.as_mut().expect("wakers were already woken");
337
338 match waker_key {
339 Some(waker_key) => {
341 let waker = wakers
343 .get_mut(*waker_key)
344 .expect("waker key is missing from map");
347 waker.clone_from(new_waker);
348 }
349 None => {
351 let new_key = wakers.insert(new_waker.clone());
353 *waker_key = Some(new_key);
354 }
355 }
356
357 Poll::Pending
358}
359
360fn receiver_fut_drop<T>(shared: &Shared<T>, waker_key: &mut Option<WakerKey>) {
362 if let Some(waker_key) = waker_key.take() {
363 let mut wakers = shared.wakers.lock().expect("poisoned");
364 if let Ok(wakers) = wakers.as_mut() {
365 let waker = wakers.remove(waker_key);
366 debug_assert!(waker.is_some(), "the waker key was not found");
369 }
370 }
371}
372
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]

    use super::*;

    use futures::future::FutureExt;
    use futures::task::SpawnExt;

    impl<T> Shared<T> {
        /// Number of currently registered wakers. This is 0 once the sender has taken them.
        fn count_wakers(&self) -> usize {
            let guard = self.wakers.lock().expect("poisoned");
            guard.as_ref().map_or(0, |map| map.len())
        }
    }

    #[test]
    fn standard_usage() {
        tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
            // Borrowing receiver.
            let (sender, receiver) = channel();
            sender.send(0_u8);
            assert_eq!(receiver.borrowed().await, Ok(&0));

            // Owning receiver.
            let (sender, receiver) = channel();
            sender.send(0_u8);
            assert_eq!(receiver.await, Ok(0));
        });
    }

    #[test]
    fn immediate_drop() {
        // Drop both halves at once, then each half first.
        let _ = channel::<()>();

        let (sender, receiver) = channel::<()>();
        drop(sender);
        drop(receiver);

        let (sender, receiver) = channel::<()>();
        drop(receiver);
        drop(sender);
    }

    #[test]
    fn drop_sender() {
        tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
            let (sender, first) = channel::<u8>();

            // Clones made both before and after the sender goes away see the error.
            let cloned_before = first.clone();
            drop(sender);
            let cloned_after = first.clone();

            for receiver in [&first, &cloned_before, &cloned_after] {
                assert_eq!(receiver.borrowed().await, Err(SenderDropped));
            }
        });
    }

    #[test]
    fn clone_before_send() {
        tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
            let (sender, original) = channel();

            let cloned = original.clone();
            sender.send(0_u8);
            for receiver in [&original, &cloned] {
                assert_eq!(receiver.borrowed().await, Ok(&0));
            }
        });
    }

    #[test]
    fn clone_after_send() {
        tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
            let (sender, original) = channel();

            sender.send(0_u8);
            let cloned = original.clone();
            for receiver in [&original, &cloned] {
                assert_eq!(receiver.borrowed().await, Ok(&0));
            }
        });
    }

    #[test]
    fn clone_after_borrowed() {
        tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
            let (sender, original) = channel();

            sender.send(0_u8);
            assert_eq!(original.borrowed().await, Ok(&0));
            let cloned = original.clone();
            assert_eq!(cloned.borrowed().await, Ok(&0));
        });
    }

    #[test]
    fn drop_one_receiver() {
        tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
            let (sender, original) = channel();

            let survivor = original.clone();
            drop(original);
            sender.send(0_u8);
            assert_eq!(survivor.borrowed().await, Ok(&0));
        });
    }

    #[test]
    fn drop_all_receivers() {
        let (sender, original) = channel();

        let cloned = original.clone();
        drop(original);
        drop(cloned);
        // Sending with nobody listening must not panic.
        sender.send(0_u8);
    }

    #[test]
    fn drop_fut() {
        // A borrowed future that is never polled registers no waker.
        let (_sender, receiver) = channel::<u8>();
        let fut = receiver.borrowed();
        assert_eq!(receiver.shared.count_wakers(), 0);
        drop(fut);
        assert_eq!(receiver.shared.count_wakers(), 0);

        // Same, with the message already sent.
        let (sender, receiver) = channel::<u8>();
        sender.send(0);
        let fut = receiver.borrowed();
        assert_eq!(receiver.shared.count_wakers(), 0);
        drop(fut);
        assert_eq!(receiver.shared.count_wakers(), 0);

        // A pending poll registers a waker; dropping the future removes it.
        let (_sender, receiver) = channel::<u8>();
        let mut fut = Box::pin(receiver.borrowed());
        assert_eq!(receiver.shared.count_wakers(), 0);
        assert_eq!(fut.as_mut().now_or_never(), None);
        assert_eq!(receiver.shared.count_wakers(), 1);
        drop(fut);
        assert_eq!(receiver.shared.count_wakers(), 0);

        // Sending takes the registered waker before the future is dropped.
        let (sender, receiver) = channel::<u8>();
        let mut fut = Box::pin(receiver.borrowed());
        assert_eq!(receiver.shared.count_wakers(), 0);
        assert_eq!(fut.as_mut().now_or_never(), None);
        assert_eq!(receiver.shared.count_wakers(), 1);
        sender.send(0);
        assert_eq!(receiver.shared.count_wakers(), 0);
        drop(fut);
    }

    #[test]
    fn drop_owned_fut() {
        // An owned future that is never polled registers no waker.
        let (_sender, receiver) = channel::<u8>();
        let fut = receiver.clone().into_future();
        assert_eq!(receiver.shared.count_wakers(), 0);
        drop(fut);
        assert_eq!(receiver.shared.count_wakers(), 0);

        // Same, with the message already sent.
        let (sender, receiver) = channel::<u8>();
        sender.send(0);
        let fut = receiver.clone().into_future();
        assert_eq!(receiver.shared.count_wakers(), 0);
        drop(fut);
        assert_eq!(receiver.shared.count_wakers(), 0);

        // A pending poll registers a waker; dropping the future removes it.
        let (_sender, receiver) = channel::<u8>();
        let mut fut = Box::pin(receiver.clone().into_future());
        assert_eq!(receiver.shared.count_wakers(), 0);
        assert_eq!(fut.as_mut().now_or_never(), None);
        assert_eq!(receiver.shared.count_wakers(), 1);
        drop(fut);
        assert_eq!(receiver.shared.count_wakers(), 0);

        // Sending takes the registered waker before the future is dropped.
        let (sender, receiver) = channel::<u8>();
        let mut fut = Box::pin(receiver.clone().into_future());
        assert_eq!(receiver.shared.count_wakers(), 0);
        assert_eq!(fut.as_mut().now_or_never(), None);
        assert_eq!(receiver.shared.count_wakers(), 1);
        sender.send(0);
        assert_eq!(receiver.shared.count_wakers(), 0);
        drop(fut);
    }

    #[test]
    fn is_ready_after_send() {
        let (sender, first) = channel();
        assert!(!first.is_ready());
        let second = first.clone();
        assert!(!second.is_ready());

        sender.send(0_u8);

        assert!(first.is_ready());
        assert!(second.is_ready());

        let third = first.clone();
        assert!(third.is_ready());
    }

    #[test]
    fn is_ready_after_drop() {
        let (sender, first) = channel::<u8>();
        assert!(!first.is_ready());
        let second = first.clone();
        assert!(!second.is_ready());

        drop(sender);

        assert!(first.is_ready());
        assert!(second.is_ready());

        let third = first.clone();
        assert!(third.is_ready());
    }

    #[test]
    fn is_cancelled() {
        // A single receiver.
        let (sender, receiver) = channel::<u8>();
        assert!(!sender.is_cancelled());
        drop(receiver);
        assert!(sender.is_cancelled());

        // Cancelled only once the last clone is gone.
        let (sender, first) = channel::<u8>();
        assert!(!sender.is_cancelled());
        let second = first.clone();
        drop(first);
        assert!(!sender.is_cancelled());
        drop(second);
        assert!(sender.is_cancelled());
    }

    #[test]
    fn recv_in_task() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            let (sender, receiver) = channel();

            let task = rt
                .spawn_with_handle(async move {
                    assert_eq!(receiver.borrowed().await, Ok(&0));
                    assert_eq!(receiver.await, Ok(0));
                })
                .unwrap();

            sender.send(0_u8);

            task.await;
        });
    }

    #[test]
    fn recv_multiple_in_task() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            let (sender, receiver) = channel();
            let for_borrowing = receiver.clone();
            let for_owning = receiver.clone();

            let borrowing_task = rt
                .spawn_with_handle(async move {
                    assert_eq!(for_borrowing.borrowed().await, Ok(&0));
                })
                .unwrap();
            let owning_task = rt
                .spawn_with_handle(async move {
                    assert_eq!(for_owning.await, Ok(0));
                })
                .unwrap();

            sender.send(0_u8);

            borrowing_task.await;
            owning_task.await;
            assert_eq!(receiver.borrowed().await, Ok(&0));
        });
    }

    #[test]
    fn recv_multiple_times() {
        tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
            let (sender, receiver) = channel();

            sender.send(0_u8);
            assert_eq!(receiver.borrowed().await, Ok(&0));
            assert_eq!(receiver.borrowed().await, Ok(&0));
            assert_eq!(receiver.clone().await, Ok(0));
            assert_eq!(receiver.await, Ok(0));
        });
    }

    #[test]
    fn stress() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            let (sender, receiver) = channel();

            // The sender yields for a while so that many receivers start waiting before the
            // message arrives.
            rt.spawn(async move {
                for _ in 0..20 {
                    tor_rtcompat::task::yield_now().await;
                }
                sender.send(0_u8);
            })
            .unwrap();

            let mut handles = Vec::new();
            for _ in 0..100 {
                let waiting = receiver.clone();
                let handle = rt
                    .spawn_with_handle(async move { waiting.borrowed().await.cloned() })
                    .unwrap();
                handles.push(handle);
                tor_rtcompat::task::yield_now().await;
            }

            for handle in handles {
                assert!(matches!(handle.await, Ok(0)));
            }
        });
    }
}