1
//! Cancellable futures.
2

            
3
use std::{
4
    pin::Pin,
5
    sync::{Arc, Mutex},
6
    task::{Context, Poll, Waker},
7
};
8

            
9
use futures::{future::FusedFuture, Future};
10
use pin_project::pin_project;
11

            
12
/// A cancellable future type, loosely influenced by `RemoteHandle`.
///
/// Wraps an inner future `F`. Polling a `Cancel<F>` resolves to `Ok(output)`
/// when the inner future completes, or to `Err(Cancelled)` if the associated
/// `CancelHandle` has cancelled it before it is polled to completion.
///
/// This type is useful for cases when we can't cancel a future simply by
/// dropping it, because the future is owned by some other object (like a
/// `FuturesUnordered`) that won't give it up.
//
// We could use `tokio_util`'s cancellable futures instead here, but I don't
// think we want an unconditional tokio_util dependency.
#[pin_project]
pub(crate) struct Cancel<F> {
    /// Shared state between the `Cancel` and the `CancelHandle`.
    //
    // It would be nice not to have to stick this behind a mutex, but that would
    // make it a bit tricky to manage the Waker.
    inner: Arc<Mutex<Inner>>,
    /// The inner future.
    ///
    /// Structurally pinned (`#[pin]`), so it is polled in place through the
    /// `pin_project` projection.
    #[pin]
    fut: F,
}
31

            
32
/// Inner state shared between `Cancel` and the `CancelHandle`.
///
/// Always accessed through an `Arc<Mutex<Inner>>` held by both sides.
struct Inner {
    /// True if this future has been cancelled.
    ///
    /// Starts out `false`; set to `true` by `CancelHandle::cancel`.
    cancelled: bool,
    /// A waker to use in telling this future that it's cancelled.
    ///
    /// Registered by the `Cancel` future when it is polled (and not yet
    /// cancelled); taken (leaving `None`) by `CancelHandle::cancel`.
    waker: Option<Waker>,
}
39

            
40
/// An object that can be used to cancel a future.
///
/// Created together with its `Cancel` future by `Cancel::new`. Cloning a
/// `CancelHandle` clones the inner `Arc`, so every clone refers to (and can
/// cancel) the same future.
#[derive(Clone)]
pub(crate) struct CancelHandle {
    /// The shared state for the cancellable future between `Cancel` and
    /// `CancelHandle`.
    inner: Arc<Mutex<Inner>>,
}
47

            
48
impl<F> Cancel<F> {
49
    /// Wrap `fut` in a new future that can be cancelled.
50
    ///
51
    /// Returns a handle to cancel the future, and the cancellable future.
52
6
    pub(crate) fn new(fut: F) -> (CancelHandle, Cancel<F>) {
53
6
        let inner = Arc::new(Mutex::new(Inner {
54
6
            cancelled: false,
55
6
            waker: None,
56
6
        }));
57
6
        let handle = CancelHandle {
58
6
            inner: inner.clone(),
59
6
        };
60
6
        let future = Cancel { inner, fut };
61
6
        (handle, future)
62
6
    }
63
}
64

            
65
impl CancelHandle {
66
    /// Cancel the associated future, if it has not already finished.
67
    #[allow(dead_code)] // TODO RPC
68
4
    pub(crate) fn cancel(&self) {
69
4
        let mut inner = self.inner.lock().expect("poisoned lock");
70
4
        inner.cancelled = true;
71
4
        if let Some(waker) = inner.waker.take() {
72
4
            waker.wake();
73
4
        }
74
4
    }
75
}
76

            
77
/// An error returned from a `Cancel` future if it is cancelled.
///
/// This is the `Err` value produced when a `Cancel` future is polled after
/// its `CancelHandle` has been used to cancel it.
#[derive(thiserror::Error, Clone, Debug)]
#[error("Future was cancelled")]
pub(crate) struct Cancelled;
81

            
82
impl<F: Future> Future for Cancel<F> {
83
    type Output = Result<F::Output, Cancelled>;
84

            
85
10
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
86
10
        {
87
10
            let mut inner = self.inner.lock().expect("lock poisoned");
88
10
            if inner.cancelled {
89
4
                return Poll::Ready(Err(Cancelled));
90
6
            }
91
6
            inner.waker = Some(cx.waker().clone());
92
6
        }
93
6
        let this = self.project();
94
6
        this.fut.poll(cx).map(Ok)
95
10
    }
96
}
97

            
98
impl<F: FusedFuture> FusedFuture for Cancel<F> {
    /// Report whether this future should no longer be polled.
    ///
    /// A cancelled future counts as terminated whatever state the wrapped
    /// future is in. Otherwise we defer to the wrapped future.
    fn is_terminated(&self) -> bool {
        // Read the flag in its own statement so that the temporary guard is
        // dropped before we consult the wrapped future.
        let was_cancelled = self.inner.lock().expect("lock poisoned").cancelled;
        was_cancelled || self.fut.is_terminated()
    }
}
109

            
110
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use futures_await_test::async_test;
    use oneshot_fused_workaround as oneshot;

    /// A future that is never cancelled resolves to `Ok` with the inner output.
    #[async_test]
    async fn not_cancelled() {
        let (_handle, wrapped) = Cancel::new(futures::future::ready("hello"));
        let output = wrapped.await.unwrap();
        assert_eq!(output, "hello");
    }

    /// Cancelling a pending future makes it resolve to `Err(Cancelled)`.
    #[async_test]
    async fn cancelled() {
        // A future that would otherwise never complete.
        let (handle, wrapped) = Cancel::new(futures::future::pending::<()>());
        let (outcome, ()) = futures::join!(wrapped, async { handle.cancel() });
        assert!(matches!(outcome, Err(Cancelled)));

        // A oneshot receiver whose sender is kept alive (not dropped), so the
        // only way for the wrapper to finish is through cancellation.
        let (_sender, receiver) = oneshot::channel::<()>();
        let (handle, wrapped) = Cancel::new(receiver);
        let (outcome, ()) = futures::join!(wrapped, async { handle.cancel() });
        assert!(matches!(outcome, Err(Cancelled)));
    }
}