arti_rpcserver/cancel.rs
1//! Cancellable futures.
2
3use std::{
4 pin::Pin,
5 sync::{Arc, Mutex},
6 task::{Context, Poll, Waker},
7};
8
9use futures::Future;
10use pin_project::pin_project;
11
12/// A cancellable future type, loosely influenced by `RemoteHandle`.
13///
14/// This type is useful for cases when we can't cancel a future simply by
15/// dropping it, because the future is owned by some other object (like a
16/// `FuturesUnordered`) that won't give it up.
17///
18/// # Limitations
19///
20/// Do not try to cancel a future from inside a cancellable future,
21/// including the future itself:
22/// this may cause a panic or deadlock.
23///
24/// In `arti-rpcserver`, we prevent this happening by ensuring that
25/// every method that calls `cancel()` is itself uncancellable.
26///
27// TODO: We should probably fix this limitation somehow before exposing
28// this code outside of this crate. But see comments inside `Cancel::poll`
29// for why we might not want to just drop the lock while polling.
30//
31// Also: We could use `tokio_util`'s cancellable futures instead here, but I don't
32// think we want an unconditional tokio_util dependency.
33#[pin_project]
34pub(crate) struct Cancel<F> {
35 /// Shared state between the `Cancel` and the `CancelHandle`.
36 //
37 // It would be nice not to have to stick this behind a mutex, but that would
38 // make it a bit tricky to manage the Waker.
39 inner: Arc<Mutex<Inner>>,
40 /// The inner future.
41 ///
42 /// TODO: Possibly we should move this into `inner`,
43 /// so that we can make sure that we don't execute the future without holding the lock,
44 /// and so we can drop the future immediately when it's cancelled.
45 /// But that would take some fairly tricky type erasure, so maybe it isn't worth it?
46 #[pin]
47 fut: F,
48}
49
/// The lifecycle state of a `Cancel` future.
#[derive(Debug, Clone, Copy)]
enum Status {
    /// Neither completed nor cancelled yet.
    Pending,
    /// Completed; cancelling it is no longer possible.
    Finished,
    /// Cancelled; the wrapped future should not be polled any more.
    Cancelled,
}
60
61/// Inner state shared between `Cancel` and the `CancelHandle.
62struct Inner {
63 /// Current status of the future.
64 status: Status,
65 /// A waker to use in telling this future that it's cancelled.
66 waker: Option<Waker>,
67}
68
69/// An object that can be used to cancel a future.
70#[derive(Clone)]
71pub(crate) struct CancelHandle {
72 /// The shared state for the cancellable future between `Cancel` and
73 /// `CancelHandle`.
74 inner: Arc<Mutex<Inner>>,
75}
76
77impl<F> Cancel<F> {
78 /// Wrap `fut` in a new future that can be cancelled.
79 ///
80 /// Returns a handle to cancel the future, and the cancellable future.
81 pub(crate) fn new(fut: F) -> (CancelHandle, Cancel<F>) {
82 let inner = Arc::new(Mutex::new(Inner {
83 status: Status::Pending,
84 waker: None,
85 }));
86 let handle = CancelHandle {
87 inner: inner.clone(),
88 };
89 let future = Cancel { inner, fut };
90 (handle, future)
91 }
92}
93
94impl CancelHandle {
95 /// Cancel the associated future, if it has not already finished.
96 ///
97 /// # Limitations
98 ///
99 /// This function may panic or deadlock if you call it from inside a `Cancel<F>`
100 /// future. See discussion in [`Cancel`] documentation.
101 pub(crate) fn cancel(&self) -> Result<(), CannotCancel> {
102 let mut inner = self.inner.lock().expect("poisoned lock");
103 match inner.status {
104 Status::Pending => inner.status = Status::Cancelled,
105 Status::Finished => return Err(CannotCancel::Finished),
106 Status::Cancelled => return Err(CannotCancel::Cancelled),
107 }
108 if let Some(waker) = inner.waker.take() {
109 drop(inner); // release lock.
110 waker.wake();
111 }
112 Ok(())
113 }
114}
115
116/// An error returned from a `Cancel` future if it is cancelled.
117#[derive(thiserror::Error, Clone, Debug)]
118#[error("Future was cancelled")]
119pub(crate) struct Cancelled;
120
121/// An error returned when we cannot cancel a future.
122#[derive(thiserror::Error, Clone, Debug)]
123pub(crate) enum CannotCancel {
124 /// This future was already cancelled, and can't be cancelled again.
125 #[error("Already cancelled")]
126 Cancelled,
127
128 /// This future has already completed, and can't be cancelled.
129 #[error("Already finished")]
130 Finished,
131}
132
133impl<F: Future> Future for Cancel<F> {
134 type Output = Result<F::Output, Cancelled>;
135
136 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
137 let this = self.project();
138
139 let mut inner = this.inner.lock().expect("lock poisoned");
140 match inner.status {
141 Status::Pending => {}
142 Status::Finished => {
143 // Yes, we do intentionally allow a finished future to be polled again.
144 // This does not violate our invariants.
145 // If you want to prevent this, you need to use Fuse or a similar mechanism.
146 }
147 Status::Cancelled => return Poll::Ready(Err(Cancelled)),
148 }
149 // Note that we're holding the mutex here while we poll the future.
150 // This guarantees that the future can't make _any_ progress after it has been
151 // cancelled. If we someday decide we don't care about that, we could release the mutex
152 // while polling, and pick it up again after we're done polling.
153 match this.fut.poll(cx) {
154 Poll::Ready(val) => {
155 inner.status = Status::Finished;
156 Poll::Ready(Ok(val))
157 }
158 Poll::Pending => {
159 if let Some(existing_waker) = &mut inner.waker {
160 // If we already have a waker, we use clone_from here,
161 // since that function knows to use will_wake
162 // to avoid a needless clone.
163 existing_waker.clone_from(cx.waker());
164 } else {
165 // Otherwise, we need to clone cx.waker().
166 inner.waker = Some(cx.waker().clone());
167 }
168 Poll::Pending
169 }
170 }
171 }
172}
173
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::{future, time::Duration};

    use super::*;
    use futures::{stream::FuturesUnordered, FutureExt as _, StreamExt as _};
    use futures_await_test::async_test;
    use oneshot_fused_workaround as oneshot;
    use tor_basic_utils::RngExt;
    use tor_rtcompat::SleepProvider as _;

    /// A cancellable future that is never cancelled resolves to `Ok` holding
    /// the wrapped future's output.
    #[async_test]
    async fn not_cancelled() {
        let f = futures::future::ready("hello");
        let (_h, f) = Cancel::new(f);
        assert_eq!(f.await.unwrap(), "hello");
    }

    /// Cancelling a future that cannot complete on its own makes it resolve to
    /// `Err(Cancelled)`.
    #[async_test]
    async fn cancelled() {
        // Case 1: a future that is always pending.
        let f = futures::future::pending::<()>();
        let (h, f) = Cancel::new(f);
        let (r, ()) = futures::join!(f, async {
            h.cancel().unwrap();
        });
        assert!(matches!(r, Err(Cancelled)));

        // Case 2: a oneshot receiver. `_tx` stays bound (and is never sent on)
        // for the rest of the test, so the receiver presumably cannot complete
        // by itself; only cancellation can finish it.
        let (_tx, rx) = oneshot::channel::<()>();
        let (h, f) = Cancel::new(rx);
        let (r, ()) = futures::join!(f, async {
            h.cancel().unwrap();
        });
        assert!(matches!(r, Err(Cancelled)));
    }

    #[test]
    fn cancelled_or_not() {
        // This looks pretty complicated! But really what we're doing is running a whole bunch
        // of tasks and cancelling them almost-immediately, to make sure that every task either
        // succeeds or fails.

        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            #[allow(deprecated)] // TODO #1885
            let rt = tor_rtmock::MockSleepRuntime::new(rt);

            const N_TRIES: usize = 1024;
            // Time is virtual here, so the interval doesn't matter.
            const SLEEP_CEIL: Duration = Duration::from_millis(1);
            // Per-index outcome slots: `work_succeeded[idx]` records whether
            // awaiting the cancellable work returned `Ok`, and
            // `cancel_succeeded[idx]` records whether `cancel()` returned `Ok`.
            // `None` means the corresponding task never recorded a result.
            let work_succeeded = Arc::new(Mutex::new([None; N_TRIES]));
            let cancel_succeeded = Arc::new(Mutex::new([None; N_TRIES]));

            let mut futs = FuturesUnordered::new();
            for idx in 0..N_TRIES {
                let work_succeeded = Arc::clone(&work_succeeded);
                let cancel_succeeded = Arc::clone(&cancel_succeeded);
                let rt1 = rt.clone();
                let rt2 = rt.clone();
                // Independent random delays, so that on different iterations the
                // cancel attempt and the await of the work happen in different orders.
                let t1 = rand::rng().gen_range_infallible(..=SLEEP_CEIL);
                let t2 = rand::rng().gen_range_infallible(..=SLEEP_CEIL);

                let work = future::ready(());
                let (handle, work) = Cancel::new(work);
                // f1: after sleeping for t1, try to cancel the work.
                let f1 = async move {
                    rt1.sleep(t1).await;
                    let r = handle.cancel();
                    cancel_succeeded.lock().unwrap()[idx] = Some(r.is_ok());
                };
                // f2: after sleeping for t2, await the (possibly cancelled) work.
                let f2 = async move {
                    rt2.sleep(t2).await;
                    let r = work.await;
                    work_succeeded.lock().unwrap()[idx] = Some(r.is_ok());
                };

                futs.push(f1.boxed());
                futs.push(f2.boxed());
            }

            // Drive every task to completion under the mock runtime.
            rt.wait_for(async { while let Some(()) = futs.next().await {} })
                .await;
            // For every index, exactly one of "the work completed" and
            // "the cancel succeeded" must hold, and both tasks must have reported.
            for idx in 0..N_TRIES {
                let ws = work_succeeded.lock().unwrap()[idx];
                let cs = cancel_succeeded.lock().unwrap()[idx];
                match (ws, cs) {
                    (Some(true), Some(false)) => {}
                    (Some(false), Some(true)) => {}
                    _ => panic!("incorrect values {:?}", (idx, ws, cs)),
                }
            }
        });
    }
}