1
//! Mocking helpers for testing with futures::io types.
2
//!
3
//! Note that some of this code might be of general use, but for now
4
//! we only use it for testing.
5

            
6
#![forbid(unsafe_code)] // if you remove this, enable (or write) miri tests (git grep miri)
7

            
8
use crate::util::mpsc_channel;
9
use futures::channel::mpsc;
10
use futures::io::{AsyncRead, AsyncWrite};
11
use futures::sink::{Sink, SinkExt};
12
use futures::stream::Stream;
13
use std::io::{Error as IoError, ErrorKind, Result as IoResult};
14
use std::pin::Pin;
15
use std::task::{Context, Poll};
16
use tor_rtcompat::{StreamOps, UnsupportedStreamOp};
17

            
18
/// Channel capacity for our internal MPSC channels.
///
/// We keep this intentionally low to make sure that some blocking
/// will occur.
const CAPACITY: usize = 4;

/// Maximum number of bytes placed in a single chunk queued on a local
/// stream's channel.
///
/// This size is deliberately odd (not a power of two), to help expose
/// bugs at chunk boundaries.
const CHUNKSZ: usize = 213;
28

            
29
/// Construct a new pair of linked LocalStream objects.
30
///
31
/// Any bytes written to one will be readable on the other, and vice
32
/// versa.  These streams will behave more or less like a socketpair,
33
/// except without actually going through the operating system.
34
///
35
/// Note that this implementation is intended for testing only, and
36
/// isn't optimized.
37
1296
pub fn stream_pair() -> (LocalStream, LocalStream) {
38
1296
    let (w1, r2) = mpsc_channel(CAPACITY);
39
1296
    let (w2, r1) = mpsc_channel(CAPACITY);
40
1296
    let s1 = LocalStream {
41
1296
        w: w1,
42
1296
        r: r1,
43
1296
        pending_bytes: Vec::new(),
44
1296
        tls_cert: None,
45
1296
    };
46
1296
    let s2 = LocalStream {
47
1296
        w: w2,
48
1296
        r: r2,
49
1296
        pending_bytes: Vec::new(),
50
1296
        tls_cert: None,
51
1296
    };
52
1296
    (s1, s2)
53
1296
}
54

            
55
/// One half of a pair of linked streams returned by [`stream_pair`].
56
//
57
// Implementation notes: linked streams are made out a pair of mpsc
58
// channels.  There's one channel for sending bytes in each direction.
59
// Bytes are sent as IoResult<Vec<u8>>: sending an error causes an error
60
// to occur on the other side.
61
pub struct LocalStream {
62
    /// The writing side of the channel that we use to implement this
63
    /// stream.
64
    ///
65
    /// The reading side is held by the other linked stream.
66
    w: mpsc::Sender<IoResult<Vec<u8>>>,
67
    /// The reading side of the channel that we use to implement this
68
    /// stream.
69
    ///
70
    /// The writing side is held by the other linked stream.
71
    r: mpsc::Receiver<IoResult<Vec<u8>>>,
72
    /// Bytes that we have read from `r` but not yet delivered.
73
    pending_bytes: Vec<u8>,
74
    /// Data about the other side of this stream's fake TLS certificate, if any.
75
    /// If this is present, I/O operations will fail with an error.
76
    ///
77
    /// How this is intended to work: things that return `LocalStream`s that could potentially
78
    /// be connected to a fake TLS listener should set this field. Then, a fake TLS wrapper
79
    /// type would clear this field (after checking its contents are as expected).
80
    ///
81
    /// FIXME(eta): this is a bit of a layering violation, but it's hard to do otherwise
82
    pub(crate) tls_cert: Option<Vec<u8>>,
83
}
84

            
85
/// Helper: pull bytes off the front of `pending_bytes` and put them
/// onto `buf`.  Return the number of bytes moved.
///
/// Exactly `min(buf.len(), pending_bytes.len())` bytes are moved into the
/// start of `buf`.  Any bytes that do not fit stay in `pending_bytes`, in
/// order, for a later call.
fn drain_helper(buf: &mut [u8], pending_bytes: &mut Vec<u8>) -> usize {
    let n_to_drain = buf.len().min(pending_bytes.len());
    buf[..n_to_drain].copy_from_slice(&pending_bytes[..n_to_drain]);
    pending_bytes.drain(..n_to_drain);
    n_to_drain
}
93

            
94
impl AsyncRead for LocalStream {
95
16366
    fn poll_read(
96
16366
        mut self: Pin<&mut Self>,
97
16366
        cx: &mut Context<'_>,
98
16366
        buf: &mut [u8],
99
16366
    ) -> Poll<IoResult<usize>> {
100
16366
        if buf.is_empty() {
101
            return Poll::Ready(Ok(0));
102
16366
        }
103
16366
        if self.tls_cert.is_some() {
104
            return Poll::Ready(Err(std::io::Error::new(
105
                std::io::ErrorKind::Other,
106
                "attempted to treat a TLS stream as non-TLS!",
107
            )));
108
16366
        }
109
16366
        if !self.pending_bytes.is_empty() {
110
6652
            return Poll::Ready(Ok(drain_helper(buf, &mut self.pending_bytes)));
111
9714
        }
112

            
113
9714
        match futures::ready!(Pin::new(&mut self.r).poll_next(cx)) {
114
2
            Some(Err(e)) => Poll::Ready(Err(e)),
115
5566
            Some(Ok(bytes)) => {
116
5566
                self.pending_bytes = bytes;
117
5566
                let n = drain_helper(buf, &mut self.pending_bytes);
118
5566
                Poll::Ready(Ok(n))
119
            }
120
686
            None => Poll::Ready(Ok(0)), // This is an EOF
121
        }
122
16366
    }
123
}
124

            
125
impl AsyncWrite for LocalStream {
126
6732
    fn poll_write(
127
6732
        mut self: Pin<&mut Self>,
128
6732
        cx: &mut Context<'_>,
129
6732
        buf: &[u8],
130
6732
    ) -> Poll<IoResult<usize>> {
131
6732
        if self.tls_cert.is_some() {
132
            return Poll::Ready(Err(std::io::Error::new(
133
                std::io::ErrorKind::Other,
134
                "attempted to treat a TLS stream as non-TLS!",
135
            )));
136
6732
        }
137

            
138
6732
        match futures::ready!(Pin::new(&mut self.w).poll_ready(cx)) {
139
5760
            Ok(()) => (),
140
2
            Err(e) => return Poll::Ready(Err(IoError::new(ErrorKind::BrokenPipe, e))),
141
        }
142

            
143
5760
        let buf = if buf.len() > CHUNKSZ {
144
4886
            &buf[..CHUNKSZ]
145
        } else {
146
874
            buf
147
        };
148
5760
        let len = buf.len();
149
5760
        match Pin::new(&mut self.w).start_send(Ok(buf.to_vec())) {
150
5760
            Ok(()) => Poll::Ready(Ok(len)),
151
            Err(e) => Poll::Ready(Err(IoError::new(ErrorKind::BrokenPipe, e))),
152
        }
153
6732
    }
154
414
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
155
414
        Pin::new(&mut self.w)
156
414
            .poll_flush(cx)
157
414
            .map_err(|e| IoError::new(ErrorKind::BrokenPipe, e))
158
414
    }
159
686
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
160
686
        Pin::new(&mut self.w)
161
686
            .poll_close(cx)
162
686
            .map_err(|e| IoError::new(ErrorKind::Other, e))
163
686
    }
164
}
165

            
166
impl StreamOps for LocalStream {
167
    fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
168
        Err(
169
            UnsupportedStreamOp::new("set_tcp_notsent_lowat", "unsupported on local streams")
170
                .into(),
171
        )
172
    }
173
}
174

            
175
/// An error generated by [`LocalStream::send_err`].
///
/// It carries no data of its own; the [`std::io::ErrorKind`] chosen by the
/// sender is stored in the surrounding [`std::io::Error`].
#[derive(Debug, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub struct SyntheticError;

impl std::fmt::Display for SyntheticError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Synthetic error")
    }
}

impl std::error::Error for SyntheticError {}
185

            
186
impl LocalStream {
187
    /// Send an error to the other linked local stream.
188
    ///
189
    /// When the other stream reads this message, it will generate a
190
    /// [`std::io::Error`] with the provided `ErrorKind`.
191
3
    pub async fn send_err(&mut self, kind: ErrorKind) {
192
2
        let _ignore = self.w.send(Err(IoError::new(kind, SyntheticError))).await;
193
2
    }
194
}
195

            
196
#[cfg(all(test, not(miri)))] // These tests are very slow under miri
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;

    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use futures_await_test::async_test;
    use rand::Rng;
    use tor_basic_utils::test_rng::testing_rng;

    /// Write a random buffer ten times on one side, then read everything back
    /// on the other and check that exactly the same bytes arrive.
    #[async_test]
    async fn basic_rw() {
        let (mut s1, mut s2) = stream_pair();
        let mut text1 = vec![0_u8; 9999];
        testing_rng().fill(&mut text1[..]);

        let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!(
            async {
                for _ in 0_u8..10 {
                    s1.write_all(&text1[..]).await?;
                }
                s1.close().await?;
                Ok(())
            },
            async {
                let mut text2: Vec<u8> = Vec::new();
                let mut buf = [0_u8; 33];
                loop {
                    let n = s2.read(&mut buf[..]).await?;
                    if n == 0 {
                        break;
                    }
                    text2.extend(&buf[..n]);
                }
                // Without this check, a short (or even empty) read would
                // pass the per-chunk comparison below.
                assert_eq!(text2.len(), text1.len() * 10);
                for ch in text2[..].chunks(text1.len()) {
                    assert_eq!(ch, &text1[..]);
                }
                Ok(())
            }
        );

        v1.unwrap();
        v2.unwrap();
    }

    /// Data written before `send_err` must still arrive, followed by the
    /// synthetic error itself.
    #[async_test]
    async fn send_error() {
        let (mut s1, mut s2) = stream_pair();

        let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!(
            async {
                s1.write_all(b"hello world").await?;
                s1.send_err(ErrorKind::PermissionDenied).await;
                Ok(())
            },
            async {
                let mut received: Vec<u8> = Vec::new();
                let mut buf = [0_u8; 33];
                let outcome = loop {
                    match s2.read(&mut buf[..]).await {
                        Ok(0) => break Ok(()),
                        Ok(n) => received.extend(&buf[..n]),
                        Err(e) => break Err(e),
                    }
                };
                assert_eq!(received, b"hello world");
                outcome
            }
        );

        v1.unwrap();
        let e = v2.err().unwrap();
        assert_eq!(e.kind(), ErrorKind::PermissionDenied);
        let synth = e.into_inner().unwrap();
        assert_eq!(synth.to_string(), "Synthetic error");
    }

    /// Writing to a stream whose peer has been dropped must fail with
    /// `BrokenPipe`.
    #[async_test]
    async fn drop_reader() {
        let (mut s1, s2) = stream_pair();

        let (v1, v2): (IoResult<()>, IoResult<()>) = futures::join!(
            async {
                for _ in 0_u16..1000 {
                    s1.write_all(&[9_u8; 9999]).await?;
                }
                Ok(())
            },
            async {
                drop(s2);
                Ok(())
            }
        );

        v2.unwrap();
        let e = v1.err().unwrap();
        assert_eq!(e.kind(), ErrorKind::BrokenPipe);
    }
}