tor_proto/util/notify.rs
1//! An async notification channel.
2//!
3//! This channel allows one task to notify another. No data is passed from the sender to receiver. A
4//! [`NotifySender`] may send multiple notifications and a [`NotifyReceiver`] may receive multiple
5//! notifications. Notifications will be coalesced, so if a `NotifySender` sends multiple
6//! notifications, the `NotifyReceiver` may or may not receive all of the notifications. If there
7//! are multiple `NotifyReceiver`s, each will be notified.
8//!
9//! An optional type can be attached to the `NotifySender` and `NotifyReceiver` to identify the
10//! purpose of the notifications and to provide type checking.
11
12// TODO(arti#534): we expect to use this for flow control, so we should remove this later
13#![cfg_attr(not(test), expect(dead_code))]
14
15use std::marker::PhantomData;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19use educe::Educe;
20use futures::stream::{Fuse, FusedStream};
21use futures::{Stream, StreamExt};
22use pin_project::pin_project;
23use postage::watch;
24
/// A [`NotifySender`] which can notify [`NotifyReceiver`]s.
///
/// Receivers are created with [`NotifySender::subscribe`] and are woken by
/// [`NotifySender::notify`]. A notification carries no data.
///
/// See the [module documentation](self) for details.
#[derive(Educe)]
#[educe(Debug)]
pub(crate) struct NotifySender<T = ()> {
    /// The "sender" we use to implement the async behaviour.
    ///
    /// The channel's value type is `()`; only the "value changed" signal is of interest.
    sender: watch::Sender<()>,
    /// Allows the user to optionally attach a type marker to identify the purpose of the
    /// notifications.
    ///
    /// This is a `PhantomData`, so no value of type `T` is ever stored. It is excluded from the
    /// `Debug` output.
    #[educe(Debug(ignore))]
    _marker: PhantomData<fn() -> T>,
}
38
/// A [`NotifyReceiver`] which can receive notifications from a [`NotifySender`].
///
/// This type is a [`Stream`] whose items are `()`; each item represents one (possibly coalesced)
/// notification. It is obtained from [`NotifySender::subscribe`].
///
/// See the [module documentation](self) for details.
// We should theoretically be able to impl `Clone`, but `Fuse` does not implement `Clone` so we'd
// have to implement something manually. If we do want `Clone` in the future, be careful about the
// initial state of the new `NotifyReceiver` (see the `try_recv` in `NotifySender::subscribe`).
#[derive(Educe)]
#[educe(Debug)]
#[pin_project]
pub(crate) struct NotifyReceiver<T = ()> {
    /// The "receiver" we use to implement the async behaviour.
    ///
    /// It is wrapped in a `Fuse` so that this type can implement `FusedStream` by delegation.
    #[pin]
    receiver: Fuse<watch::Receiver<()>>,
    /// Allows the user to optionally attach a type marker to identify the purpose of the
    /// notifications.
    ///
    /// This is a `PhantomData`, so no value of type `T` is ever stored. It is excluded from the
    /// `Debug` output.
    #[educe(Debug(ignore))]
    _marker: PhantomData<fn() -> T>,
}
57
58impl NotifySender {
59 /// Create a new untyped [`NotifySender`].
60 pub(crate) fn new() -> Self {
61 Self::new_typed()
62 }
63}
64
65impl<T> NotifySender<T> {
66 /// Create a new typed [`NotifySender<T>`].
67 pub(crate) fn new_typed() -> Self {
68 let (sender, _receiver) = watch::channel();
69 Self {
70 sender,
71 _marker: Default::default(),
72 }
73 }
74
75 /// Notify all [`NotifyReceiver`]s.
76 pub(crate) fn notify(&mut self) {
77 // from `postage::watch::Sender`:
78 // > Mutably borrows the contained value, blocking the channel while the borrow is held.
79 // > After the borrow is released, receivers will be notified of a new value.
80 self.sender.borrow_mut();
81 }
82
83 /// Create a new [`NotifyReceiver`] for this [`NotifySender`].
84 ///
85 /// A new `NotifyReceiver` will not see any past notifications.
86 pub(crate) fn subscribe(&mut self) -> NotifyReceiver<T> {
87 let mut receiver = self.sender.subscribe();
88
89 // a `watch::Receiver` will always return the existing status of the `watch::Sender` as the
90 // first stream item, so we need to recv and discard it so that this `NotifyReceiver` begins
91 // in the "pending" state
92 use postage::stream::Stream as PostageStream;
93 use postage::stream::TryRecvError;
94 assert_eq!(PostageStream::try_recv(&mut receiver), Ok(()));
95 assert_eq!(
96 PostageStream::try_recv(&mut receiver),
97 Err(TryRecvError::Pending),
98 );
99
100 NotifyReceiver {
101 receiver: receiver.fuse(),
102 _marker: Default::default(),
103 }
104 }
105}
106
107impl<T> Stream for NotifyReceiver<T> {
108 type Item = ();
109
110 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
111 self.project().receiver.poll_next(cx)
112 }
113}
114
115// the `NotifyReceiver` stores a `Fuse`
116impl<T> FusedStream for NotifyReceiver<T> {
117 fn is_terminated(&self) -> bool {
118 self.receiver.is_terminated()
119 }
120}
121
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]

    use super::*;

    use futures::FutureExt;

    /// Poll `rx` for its next item exactly once, without waiting.
    ///
    /// Returns `None` if the stream is pending, and `Some(item)` if it produced an item (where
    /// `item` is `None` once the stream has ended).
    fn poll_now(rx: &mut NotifyReceiver) -> Option<Option<()>> {
        rx.next().now_or_never()
    }

    #[test]
    fn notify() {
        tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
            let mut tx = NotifySender::new();
            let mut rx = tx.subscribe();

            // with no notification sent yet, the receiver must be pending
            assert_eq!(poll_now(&mut rx), None);
            assert_eq!(poll_now(&mut rx), None);

            tx.notify();

            // exactly one notification is delivered
            assert_eq!(poll_now(&mut rx), Some(Some(())));
            assert_eq!(poll_now(&mut rx), None);

            tx.notify();
            tx.notify();
            tx.notify();

            // several notifications in a row are coalesced into one
            assert_eq!(poll_now(&mut rx), Some(Some(())));
            assert_eq!(poll_now(&mut rx), None);

            tx.notify();
            drop(tx);

            // the final notification is still delivered; after that the stream reports that it
            // has finished because the sender is gone
            assert_eq!(poll_now(&mut rx), Some(Some(())));
            assert_eq!(poll_now(&mut rx), Some(None));
            assert_eq!(poll_now(&mut rx), Some(None));
        });
    }

    #[test]
    fn notify_multiple_receivers() {
        tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
            let mut tx = NotifySender::new();
            let mut rx_early_a = tx.subscribe();
            let mut rx_early_b = tx.subscribe();

            tx.notify();

            let mut rx_late = tx.subscribe();

            // both receivers that existed before the notification see it
            assert_eq!(poll_now(&mut rx_early_a), Some(Some(())));
            assert_eq!(poll_now(&mut rx_early_b), Some(Some(())));

            // the receiver created after the notification must not see it
            assert_eq!(poll_now(&mut rx_late), None);
        });
    }
}