tor_proto/stream/data.rs

//! Declare DataStream, a type that wraps RawCellStream so as to be useful
//! for byte-oriented communication.

use crate::{Error, Result};
use static_assertions::assert_impl_all;
use tor_cell::relaycell::msg::EndReason;
use tor_cell::relaycell::{RelayCellFormat, RelayCmd};

use futures::io::{AsyncRead, AsyncWrite};
use futures::stream::StreamExt;
use futures::task::{Context, Poll};
use futures::{Future, Stream};
use pin_project::pin_project;
use postage::watch;

#[cfg(feature = "tokio")]
use tokio_crate::io::ReadBuf;
#[cfg(feature = "tokio")]
use tokio_crate::io::{AsyncRead as TokioAsyncRead, AsyncWrite as TokioAsyncWrite};
#[cfg(feature = "tokio")]
use tokio_util::compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt};
use tor_cell::restricted_msg;

use std::fmt::Debug;
use std::io::Result as IoResult;
use std::num::NonZero;
use std::pin::Pin;
#[cfg(any(feature = "stream-ctrl", feature = "experimental-api"))]
use std::sync::Arc;
#[cfg(feature = "stream-ctrl")]
use std::sync::{Mutex, Weak};

use educe::Educe;

#[cfg(any(feature = "experimental-api", feature = "stream-ctrl"))]
use crate::tunnel::circuit::ClientCirc;

use crate::memquota::StreamAccount;
use crate::stream::{StreamRateLimit, StreamReceiver};
use crate::tunnel::StreamTarget;
use crate::util::token_bucket::dynamic_writer::DynamicRateLimitedWriter;
use crate::util::token_bucket::writer::{RateLimitedWriter, RateLimitedWriterConfig};
use tor_basic_utils::skip_fmt;
use tor_cell::relaycell::msg::Data;
use tor_error::internal;
use tor_rtcompat::{CoarseTimeProvider, DynTimeProvider, SleepProvider};

use super::AnyCmdChecker;

/// A stream of [`RateLimitedWriterConfig`] used to update a [`DynamicRateLimitedWriter`].
///
/// Unfortunately we need to store the result of a [`StreamExt::map`] and [`StreamExt::fuse`] in
/// [`DataWriter`], which leaves us with this ugly type.
/// We use a type alias to make `DataWriter` a little nicer.
type RateConfigStream = futures::stream::Map<
    futures::stream::Fuse<watch::Receiver<StreamRateLimit>>,
    fn(StreamRateLimit) -> RateLimitedWriterConfig,
>;

/// An anonymized stream over the Tor network.
///
/// For most purposes, you can think of this type as an anonymized
/// TCP stream: it can read and write data, and get closed when it's done.
///
/// [`DataStream`] implements [`futures::io::AsyncRead`] and
/// [`futures::io::AsyncWrite`], so you can use it anywhere that those
/// traits are expected.
///
/// # Examples
///
/// Connecting to an HTTP server and sending a request, using
/// [`AsyncWriteExt::write_all`](futures::io::AsyncWriteExt::write_all):
///
/// ```ignore
/// let mut stream = tor_client.connect(("icanhazip.com", 80), None).await?;
///
/// use futures::io::AsyncWriteExt;
///
/// stream
///     .write_all(b"GET / HTTP/1.1\r\nHost: icanhazip.com\r\nConnection: close\r\n\r\n")
///     .await?;
///
/// // Flushing the stream is important; see below!
/// stream.flush().await?;
/// ```
///
/// Reading the result, using [`AsyncReadExt::read_to_end`](futures::io::AsyncReadExt::read_to_end):
///
/// ```ignore
/// use futures::io::AsyncReadExt;
///
/// let mut buf = Vec::new();
/// stream.read_to_end(&mut buf).await?;
///
/// println!("{}", String::from_utf8_lossy(&buf));
/// ```
///
/// # Usage with Tokio
///
/// If the `tokio` crate feature is enabled, this type also implements
/// [`tokio::io::AsyncRead`](tokio_crate::io::AsyncRead) and
/// [`tokio::io::AsyncWrite`](tokio_crate::io::AsyncWrite) for easier integration
/// with code that expects those traits.
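///
/// A sketch of the same request/response exchange, assuming the `tokio`
/// feature is enabled and `stream` is an open `DataStream`:
///
/// ```ignore
/// use tokio::io::{AsyncReadExt, AsyncWriteExt};
///
/// stream.write_all(b"GET / HTTP/1.0\r\n\r\n").await?;
/// stream.flush().await?;
///
/// let mut buf = Vec::new();
/// stream.read_to_end(&mut buf).await?;
/// ```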
///
/// # Remember to call `flush`!
///
/// DataStream buffers data internally, in order to write as few cells
/// as possible onto the network.  In order to make sure that your
/// data has actually been sent, you need to make sure that
/// [`AsyncWrite::poll_flush`] runs to completion: probably via
/// [`AsyncWriteExt::flush`](futures::io::AsyncWriteExt::flush).
///
/// # Splitting the type
///
/// This type is internally composed of a [`DataReader`] and a [`DataWriter`]; the
/// `DataStream::split` method can be used to split it into those two parts, for more
/// convenient usage with e.g. stream combinators.
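///
/// For example (assuming `stream` is an open `DataStream`):
///
/// ```ignore
/// let (mut reader, mut writer) = stream.split();
/// // The two halves can now be moved into separate tasks.
/// ```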
///
/// # How long does a stream live?
///
/// A `DataStream` will live until all references to it are dropped,
/// or until it is closed explicitly.
///
/// If you split the stream into a `DataReader` and a `DataWriter`, it
/// will survive until _both_ are dropped, or until it is closed
/// explicitly.
///
/// A stream can also close because of a network error,
/// or because the other side of the stream decided to close it.
///
// # Semver note
//
// Note that this type is re-exported as a part of the public API of
// the `arti-client` crate.  Any changes to its API here in
// `tor-proto` need to be reflected above.
#[derive(Debug)]
pub struct DataStream {
    /// Underlying writer for this stream
    w: DataWriter,
    /// Underlying reader for this stream
    r: DataReader,
    /// A control object that can be used to monitor and control this stream
    /// without needing to own it.
    #[cfg(feature = "stream-ctrl")]
    ctrl: std::sync::Arc<ClientDataStreamCtrl>,
}
assert_impl_all! { DataStream: Send, Sync }

/// An object used to control and monitor a data stream.
///
/// # Notes
///
/// This is a separate type from [`DataStream`] because it's useful to have
/// multiple references to this object, whereas a [`DataReader`] and [`DataWriter`]
/// need to have a single owner for the `AsyncRead` and `AsyncWrite` APIs to
/// work correctly.
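///
/// # Example
///
/// A usage sketch (assuming the `stream-ctrl` feature is enabled and `stream`
/// is an open `DataStream`):
///
/// ```ignore
/// let ctrl = std::sync::Arc::clone(stream.client_stream_ctrl().expect("ctrl is present"));
/// // `ctrl` can be held onto even after the stream halves are moved elsewhere.
/// if ctrl.is_connected() {
///     // ...
/// }
/// ```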
#[cfg(feature = "stream-ctrl")]
#[derive(Debug)]
pub struct ClientDataStreamCtrl {
    /// The circuit to which this stream is attached.
    ///
    /// Note that the stream's reader and writer halves each contain a `StreamTarget`,
    /// which in turn has a strong reference to the `ClientCirc`.  So as long as any
    /// one of those is alive, this reference will be present.
    ///
    /// We make this a Weak reference so that once the stream itself is closed,
    /// we can't leak circuits.
    // TODO(conflux): use ClientTunnel
    circuit: Weak<ClientCirc>,

    /// Shared user-visible information about the state of this stream.
    ///
    /// TODO RPC: This will probably want to be a `postage::Watch` or something
    /// similar, if and when it stops moving around.
    #[cfg(feature = "stream-ctrl")]
    status: Arc<Mutex<DataStreamStatus>>,

    /// The memory quota account that should be used for this stream's data
    ///
    /// Exists to keep the account alive
    _memquota: StreamAccount,
}

/// The inner writer for [`DataWriter`].
///
/// This type is responsible for taking bytes and packaging them into cells.
/// Rate limiting is implemented in [`DataWriter`] to avoid making this type more complex.
#[derive(Debug)]
struct DataWriterInner {
    /// Internal state for this writer
    ///
    /// This is stored in an Option so that we can mutate it in the
    /// AsyncWrite functions.  It might be possible to do better here,
    /// and we should refactor if so.
    state: Option<DataWriterState>,

    /// The memory quota account that should be used for this stream's data
    ///
    /// Exists to keep the account alive
    // If we liked, we could make this conditional; see DataReader.memquota
    _memquota: StreamAccount,

    /// A control object that can be used to monitor and control this stream
    /// without needing to own it.
    #[cfg(feature = "stream-ctrl")]
    ctrl: std::sync::Arc<ClientDataStreamCtrl>,
}

/// The write half of a [`DataStream`], implementing [`futures::io::AsyncWrite`].
///
/// See the [`DataStream`] docs for more information. In particular, note
/// that this writer requires `poll_flush` to complete in order to guarantee that
/// all data has been written.
///
/// # Usage with Tokio
///
/// If the `tokio` crate feature is enabled, this type also implements
/// [`tokio::io::AsyncWrite`](tokio_crate::io::AsyncWrite) for easier integration
/// with code that expects that trait.
///
/// # Drop and close
///
/// Note that dropping a `DataWriter` has no special effect on its own:
/// if the `DataWriter` is dropped, the underlying stream will still remain open
/// until the `DataReader` is also dropped.
///
/// If you want the stream to close earlier, use [`close`](futures::io::AsyncWriteExt::close)
/// (or [`shutdown`](tokio_crate::io::AsyncWriteExt::shutdown) with `tokio`).
///
/// Remember that Tor does not support half-open streams:
/// If you `close` or `shutdown` a stream,
/// the other side will not see the stream as half-open,
/// and so will (probably) not finish sending you any in-progress data.
/// Do not use `close`/`shutdown` to communicate anything besides
/// "I am done using this stream."
///
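/// # Example
///
/// A sketch of an orderly shutdown (assuming `writer` is the write half
/// obtained from [`DataStream::split`]):
///
/// ```ignore
/// use futures::io::AsyncWriteExt;
///
/// writer.write_all(b"final bytes").await?;
/// writer.flush().await?; // make sure all buffered data is actually sent
/// writer.close().await?; // tell the other side we are done
/// ```
///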
// # Semver note
//
// Note that this type is re-exported as a part of the public API of
// the `arti-client` crate.  Any changes to its API here in
// `tor-proto` need to be reflected above.
#[derive(Debug)]
pub struct DataWriter {
    /// A wrapper around [`DataWriterInner`] that adds rate limiting.
    writer: DynamicRateLimitedWriter<DataWriterInner, RateConfigStream, DynTimeProvider>,
}

impl DataWriter {
    /// Create a new rate-limited [`DataWriter`] from a [`DataWriterInner`].
    fn new(
        inner: DataWriterInner,
        rate_limit_updates: watch::Receiver<StreamRateLimit>,
        time_provider: DynTimeProvider,
    ) -> Self {
        /// Converts a `rate` into a `RateLimitedWriterConfig`.
        fn rate_to_config(rate: StreamRateLimit) -> RateLimitedWriterConfig {
            let rate = rate.bytes_per_sec();
            RateLimitedWriterConfig {
                rate,        // bytes per second
                burst: rate, // bytes
                // This number is chosen arbitrarily, but the idea is that we want to balance
                // between throughput and latency. Assume the user tries to write a large buffer
                // (~600 bytes). If we set this too small (for example 1), we'll be waking up
                // frequently and writing a small number of bytes each time to the
                // `DataWriterInner`, even if this isn't enough bytes to send a cell. If we set this
                // too large (for example 510), we'll be waking up infrequently to write a larger
                // number of bytes each time. So even if the `DataWriterInner` has almost a full
                // cell's worth of data queued (for example 490) and only needs 509-490=19 more
                // bytes before a cell can be sent, it will block until the rate limiter allows 510
                // more bytes.
                //
                // TODO(arti#2028): Is there an optimal value here?
                wake_when_bytes_available: NonZero::new(200).expect("200 != 0"), // bytes
            }
        }

        // get the current rate from the `watch::Receiver`, which we'll use as the initial rate
        let initial_rate: StreamRateLimit = *rate_limit_updates.borrow();

        // map the rate update stream to the type required by `DynamicRateLimitedWriter`
        let rate_limit_updates = rate_limit_updates.fuse().map(rate_to_config as fn(_) -> _);

        // build the rate limiter
        let writer = RateLimitedWriter::new(inner, &rate_to_config(initial_rate), time_provider);
        let writer = DynamicRateLimitedWriter::new(writer, rate_limit_updates);

        Self { writer }
    }

    /// Return a [`ClientDataStreamCtrl`] object that can be used to monitor and
    /// interact with this stream without holding the stream itself.
    #[cfg(feature = "stream-ctrl")]
    pub fn client_stream_ctrl(&self) -> Option<&Arc<ClientDataStreamCtrl>> {
        Some(self.writer.inner().client_stream_ctrl())
    }
}

impl AsyncWrite for DataWriter {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<IoResult<usize>> {
        AsyncWrite::poll_write(Pin::new(&mut self.writer), cx, buf)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        AsyncWrite::poll_flush(Pin::new(&mut self.writer), cx)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        AsyncWrite::poll_close(Pin::new(&mut self.writer), cx)
    }
}

#[cfg(feature = "tokio")]
impl TokioAsyncWrite for DataWriter {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
        TokioAsyncWrite::poll_write(Pin::new(&mut self.compat_write()), cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        TokioAsyncWrite::poll_flush(Pin::new(&mut self.compat_write()), cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        TokioAsyncWrite::poll_shutdown(Pin::new(&mut self.compat_write()), cx)
    }
}

/// The read half of a [`DataStream`], implementing [`futures::io::AsyncRead`].
///
/// See the [`DataStream`] docs for more information.
///
/// # Usage with Tokio
///
/// If the `tokio` crate feature is enabled, this type also implements
/// [`tokio::io::AsyncRead`](tokio_crate::io::AsyncRead) for easier integration
/// with code that expects that trait.
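///
/// A sketch (assuming the `tokio` feature is enabled, and `reader` is the read
/// half obtained from [`DataStream::split`]):
///
/// ```ignore
/// use tokio::io::AsyncReadExt;
///
/// let mut buf = [0_u8; 1024];
/// let n = reader.read(&mut buf).await?;
/// ```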
//
// # Semver note
//
// Note that this type is re-exported as a part of the public API of
// the `arti-client` crate.  Any changes to its API here in
// `tor-proto` need to be reflected above.
#[derive(Debug)]
pub struct DataReader {
    /// Internal state for this reader.
    ///
    /// This is stored in an Option so that we can mutate it in
    /// poll_read().  It might be possible to do better here, and we
    /// should refactor if so.
    state: Option<DataReaderState>,

    /// The memory quota account that should be used for this stream's data
    ///
    /// Exists to keep the account alive
    // If we liked, we could make this conditional on not(cfg(feature = "stream-ctrl")),
    // since ClientDataStreamCtrl contains a StreamAccount clone too.  But that seems fragile.
    _memquota: StreamAccount,

    /// A control object that can be used to monitor and control this stream
    /// without needing to own it.
    #[cfg(feature = "stream-ctrl")]
    ctrl: std::sync::Arc<ClientDataStreamCtrl>,
}

/// Shared status flags for tracking the status of a `DataStream`.
///
/// We expect to refactor this a bit, so it's not exposed at all.
//
// TODO RPC: Possibly instead of manipulating the fields of DataStreamStatus
// from various points in this module, we should instead construct
// DataStreamStatus as needed from information available elsewhere.  In any
// case, we should really eliminate as much duplicate state here as we can.
// (See discussions at !1198 for some challenges with this.)
#[cfg(feature = "stream-ctrl")]
#[derive(Clone, Debug, Default)]
struct DataStreamStatus {
    /// True if we've received a CONNECTED message.
    //
    // TODO: This is redundant with `connected` in DataReaderImpl and
    // `expecting_connected` in DataCmdChecker.
    received_connected: bool,
    /// True if we have decided to send an END message.
    //
    // TODO RPC: There is not an easy way to set this from this module!  Really,
    // the decision to send an "end" is made when the StreamTarget object is
    // dropped, but we don't currently have any way to see when that happens.
    // Perhaps we need a different shared StreamStatus object that the
    // StreamTarget holds?
    sent_end: bool,
    /// True if we have received an END message telling us to close the stream.
    received_end: bool,
    /// True if we have received an error.
    ///
    /// (This is not a subset or superset of received_end; some errors are END
    /// messages but some aren't; some END messages are errors but some aren't.)
    received_err: bool,
}

#[cfg(feature = "stream-ctrl")]
impl DataStreamStatus {
    /// Remember that we've received a connected message.
    fn record_connected(&mut self) {
        self.received_connected = true;
    }

    /// Remember that we've received an error of some kind.
    fn record_error(&mut self, e: &Error) {
        // TODO: Probably we should remember the actual error in a box or
        // something.  But that means making a redundant copy of the error
        // even if nobody will want it.  Do we care?
        match e {
            Error::EndReceived(EndReason::DONE) => self.received_end = true,
            Error::EndReceived(_) => {
                self.received_end = true;
                self.received_err = true;
            }
            _ => self.received_err = true,
        }
    }
}

restricted_msg! {
    /// An allowable incoming message on a data stream.
    enum DataStreamMsg:RelayMsg {
        // SENDME is handled by the reactor.
        Data, End, Connected,
    }
}

// TODO RPC: Should we also implement this trait for everything that holds a
// ClientDataStreamCtrl?
#[cfg(feature = "stream-ctrl")]
impl super::ctrl::ClientStreamCtrl for ClientDataStreamCtrl {
    // TODO(conflux): use ClientTunnel
    fn circuit(&self) -> Option<Arc<ClientCirc>> {
        self.circuit.upgrade()
    }
}

#[cfg(feature = "stream-ctrl")]
impl ClientDataStreamCtrl {
    /// Return true if the underlying stream is connected. (That is, if it has
    /// received a `CONNECTED` message, and has not been closed.)
    pub fn is_connected(&self) -> bool {
        let s = self.status.lock().expect("poisoned lock");
        s.received_connected && !(s.sent_end || s.received_end || s.received_err)
    }

    // TODO RPC: Add more functions once we have the desired API more nailed
    // down.
}

impl DataStream {
    /// Wrap raw stream receiver and target parts as a DataStream.
    ///
    /// For a non-optimistic stream, `wait_for_connection` must be called
    /// afterwards, to make sure that a CONNECTED message is received.
    pub(crate) fn new<P: SleepProvider + CoarseTimeProvider>(
        time_provider: P,
        receiver: StreamReceiver,
        target: StreamTarget,
        memquota: StreamAccount,
    ) -> Self {
        Self::new_inner(time_provider, receiver, target, false, memquota)
    }

    /// Wrap raw stream receiver and target parts as a connected DataStream.
    ///
    /// Unlike [`DataStream::new`], this creates a `DataStream` that does not expect to receive a
    /// CONNECTED cell.
    ///
    /// This is used by hidden services, exit relays, and directory servers to accept streams.
    #[cfg(feature = "hs-service")]
    pub(crate) fn new_connected<P: SleepProvider + CoarseTimeProvider>(
        time_provider: P,
        receiver: StreamReceiver,
        target: StreamTarget,
        memquota: StreamAccount,
    ) -> Self {
        Self::new_inner(time_provider, receiver, target, true, memquota)
    }

    /// The shared implementation of the `new*()` functions.
    fn new_inner<P: SleepProvider + CoarseTimeProvider>(
        time_provider: P,
        receiver: StreamReceiver,
        target: StreamTarget,
        connected: bool,
        memquota: StreamAccount,
    ) -> Self {
        let relay_cell_format = target.relay_cell_format();
        let out_buf_len = Data::max_body_len(relay_cell_format);
        let rate_limit_stream = target.rate_limit_stream().clone();

        #[cfg(feature = "stream-ctrl")]
        let status = {
            let mut data_stream_status = DataStreamStatus::default();
            if connected {
                data_stream_status.record_connected();
            }
            Arc::new(Mutex::new(data_stream_status))
        };

        #[cfg(feature = "stream-ctrl")]
        let ctrl = Arc::new(ClientDataStreamCtrl {
            circuit: Arc::downgrade(target.circuit()),
            status: status.clone(),
            _memquota: memquota.clone(),
        });
        let r = DataReader {
            state: Some(DataReaderState::Open(DataReaderImpl {
                s: receiver,
                pending: Vec::new(),
                offset: 0,
                connected,
                #[cfg(feature = "stream-ctrl")]
                status: status.clone(),
            })),
            _memquota: memquota.clone(),
            #[cfg(feature = "stream-ctrl")]
            ctrl: ctrl.clone(),
        };
        let w = DataWriterInner {
            state: Some(DataWriterState::Ready(DataWriterImpl {
                s: target,
                buf: vec![0; out_buf_len].into_boxed_slice(),
                n_pending: 0,
                #[cfg(feature = "stream-ctrl")]
                status,
                relay_cell_format,
            })),
            _memquota: memquota,
            #[cfg(feature = "stream-ctrl")]
            ctrl: ctrl.clone(),
        };

        let time_provider = DynTimeProvider::new(time_provider);

        DataStream {
            w: DataWriter::new(w, rate_limit_stream, time_provider),
            r,
            #[cfg(feature = "stream-ctrl")]
            ctrl,
        }
    }

    /// Divide this DataStream into its constituent parts.
    pub fn split(self) -> (DataReader, DataWriter) {
        (self.r, self.w)
    }

    /// Wait until a CONNECTED cell is received, or some other cell
    /// is received to indicate an error.
    ///
    /// Does nothing if this stream is already connected.
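    ///
    /// A usage sketch (assuming `stream` was created with [`DataStream::new`],
    /// i.e. without optimistic data):
    ///
    /// ```ignore
    /// stream.wait_for_connection().await?;
    /// // From here on, the stream is known to be connected.
    /// ```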
    pub async fn wait_for_connection(&mut self) -> Result<()> {
        // We must put state back before returning
        let state = self.r.state.take().expect("Missing state in DataReader");

        if let DataReaderState::Open(mut imp) = state {
            let result = if imp.connected {
                Ok(())
            } else {
                // This succeeds if the cell is CONNECTED, and fails otherwise.
                std::future::poll_fn(|cx| Pin::new(&mut imp).read_cell(cx)).await
            };
            self.r.state = Some(match result {
                Err(_) => DataReaderState::Closed,
                Ok(_) => DataReaderState::Open(imp),
            });
            result
        } else {
            Err(Error::from(internal!(
                "Expected ready state, got {:?}",
                state
            )))
        }
    }

    /// Return a [`ClientDataStreamCtrl`] object that can be used to monitor and
    /// interact with this stream without holding the stream itself.
    #[cfg(feature = "stream-ctrl")]
    pub fn client_stream_ctrl(&self) -> Option<&Arc<ClientDataStreamCtrl>> {
        Some(&self.ctrl)
    }
}

impl AsyncRead for DataStream {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<IoResult<usize>> {
        AsyncRead::poll_read(Pin::new(&mut self.r), cx, buf)
    }
}

#[cfg(feature = "tokio")]
impl TokioAsyncRead for DataStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<IoResult<()>> {
        TokioAsyncRead::poll_read(Pin::new(&mut self.compat()), cx, buf)
    }
}

impl AsyncWrite for DataStream {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<IoResult<usize>> {
        AsyncWrite::poll_write(Pin::new(&mut self.w), cx, buf)
    }
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        AsyncWrite::poll_flush(Pin::new(&mut self.w), cx)
    }
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        AsyncWrite::poll_close(Pin::new(&mut self.w), cx)
    }
}

#[cfg(feature = "tokio")]
impl TokioAsyncWrite for DataStream {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
        TokioAsyncWrite::poll_write(Pin::new(&mut self.compat()), cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        TokioAsyncWrite::poll_flush(Pin::new(&mut self.compat()), cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        TokioAsyncWrite::poll_shutdown(Pin::new(&mut self.compat()), cx)
    }
}

/// Helper type: Like BoxFuture, but also requires that the future be Sync.
type BoxSyncFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + Sync + 'a>>;

/// An enumeration for the state of a DataWriter.
///
/// We have to use an enum here because, for as long as we're waiting
/// for a flush operation to complete, the future returned by
/// `flush_cell()` owns the DataWriterImpl.
#[derive(Educe)]
#[educe(Debug)]
enum DataWriterState {
    /// The writer has closed or gotten an error: nothing more to do.
    Closed,
    /// The writer is not currently flushing; more data can get queued
    /// immediately.
    Ready(DataWriterImpl),
    /// The writer is flushing a cell.
    Flushing(
        #[educe(Debug(method = "skip_fmt"))] //
        BoxSyncFuture<'static, (DataWriterImpl, Result<()>)>,
    ),
}

/// Internal: the write part of a DataStream
#[derive(Educe)]
#[educe(Debug)]
struct DataWriterImpl {
    /// The underlying StreamTarget object.
    s: StreamTarget,

    /// Buffered data to send over the connection.
    ///
    /// This buffer is currently allocated using a number of bytes
    /// equal to the maximum that we can package at a time.
    //
    // TODO: this buffer is probably smaller than we want, but it's good
    // enough for now.  If we _do_ make it bigger, we'll have to change
    // our use of Data::split_from to handle the case where we can't fit
    // all the data.
    #[educe(Debug(method = "skip_fmt"))]
    buf: Box<[u8]>,

    /// Number of unflushed bytes in buf.
    n_pending: usize,

    /// Relay cell format in use
    relay_cell_format: RelayCellFormat,

    /// Shared user-visible information about the state of this stream.
    #[cfg(feature = "stream-ctrl")]
    status: Arc<Mutex<DataStreamStatus>>,
}

impl DataWriterInner {
    /// See [`DataWriter::client_stream_ctrl`].
    #[cfg(feature = "stream-ctrl")]
    fn client_stream_ctrl(&self) -> &Arc<ClientDataStreamCtrl> {
        &self.ctrl
    }

    /// Helper for poll_flush() and poll_close(): Performs a flush, then
    /// closes the stream if should_close is true.
    fn poll_flush_impl(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        should_close: bool,
    ) -> Poll<IoResult<()>> {
        let state = self.state.take().expect("Missing state in DataWriter");

        // TODO: this whole function is a bit copy-pasted.
        let mut future: BoxSyncFuture<_> = match state {
            DataWriterState::Ready(imp) => {
                if imp.n_pending == 0 {
                    // Nothing to flush!
                    if should_close {
                        // We need to actually continue with this function to do the closing.
                        // Thus, make a future that does nothing and is ready immediately.
                        Box::pin(futures::future::ready((imp, Ok(()))))
                    } else {
                        // There's nothing more to do; we can return.
                        self.state = Some(DataWriterState::Ready(imp));
                        return Poll::Ready(Ok(()));
                    }
                } else {
                    // We need to flush the buffer's contents; Make a future for that.
                    Box::pin(imp.flush_buf())
                }
            }
            DataWriterState::Flushing(fut) => fut,
            DataWriterState::Closed => {
                self.state = Some(DataWriterState::Closed);
                return Poll::Ready(Err(Error::NotConnected.into()));
            }
        };

        match future.as_mut().poll(cx) {
            Poll::Ready((_imp, Err(e))) => {
                self.state = Some(DataWriterState::Closed);
                Poll::Ready(Err(e.into()))
            }
            Poll::Ready((mut imp, Ok(()))) => {
                if should_close {
                    // Tell the StreamTarget to close, so that the reactor
                    // realizes that we are done sending. (Dropping `imp.s` does not
                    // suffice, since there may be other clones of it.  In particular,
                    // the StreamReceiver has one, which it uses to keep the stream
                    // open, among other things.)
                    imp.s.close();

                    #[cfg(feature = "stream-ctrl")]
                    {
                        // TODO RPC:  This is not sufficient to track every case
                        // where we might have sent an End.  See note on the
                        // `sent_end` field.
                        imp.status.lock().expect("lock poisoned").sent_end = true;
                    }
                    self.state = Some(DataWriterState::Closed);
                } else {
                    self.state = Some(DataWriterState::Ready(imp));
                }
                Poll::Ready(Ok(()))
            }
            Poll::Pending => {
                self.state = Some(DataWriterState::Flushing(future));
                Poll::Pending
            }
        }
    }
}

impl AsyncWrite for DataWriterInner {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<IoResult<usize>> {
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        let state = self.state.take().expect("Missing state in DataWriter");

        let mut future = match state {
            DataWriterState::Ready(mut imp) => {
                let n_queued = imp.queue_bytes(buf);
                if n_queued != 0 {
                    self.state = Some(DataWriterState::Ready(imp));
                    return Poll::Ready(Ok(n_queued));
                }
                // we couldn't queue anything, so the current cell must be full.
                Box::pin(imp.flush_buf())
            }
            DataWriterState::Flushing(fut) => fut,
            DataWriterState::Closed => {
                self.state = Some(DataWriterState::Closed);
                return Poll::Ready(Err(Error::NotConnected.into()));
            }
        };

        match future.as_mut().poll(cx) {
            Poll::Ready((_imp, Err(e))) => {
                #[cfg(feature = "stream-ctrl")]
                {
                    _imp.status.lock().expect("lock poisoned").record_error(&e);
                }
                self.state = Some(DataWriterState::Closed);
                Poll::Ready(Err(e.into()))
            }
            Poll::Ready((mut imp, Ok(()))) => {
                // Great!  We're done flushing.  Queue as much as we can of this
                // cell.
                let n_queued = imp.queue_bytes(buf);
                self.state = Some(DataWriterState::Ready(imp));
                Poll::Ready(Ok(n_queued))
            }
            Poll::Pending => {
                self.state = Some(DataWriterState::Flushing(future));
                Poll::Pending
            }
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        self.poll_flush_impl(cx, false)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        self.poll_flush_impl(cx, true)
    }
}

#[cfg(feature = "tokio")]
impl TokioAsyncWrite for DataWriterInner {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
        TokioAsyncWrite::poll_write(Pin::new(&mut self.compat_write()), cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        TokioAsyncWrite::poll_flush(Pin::new(&mut self.compat_write()), cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        TokioAsyncWrite::poll_shutdown(Pin::new(&mut self.compat_write()), cx)
    }
}

impl DataWriterImpl {
    /// Try to flush the current buffer contents as a data cell.
    async fn flush_buf(mut self) -> (Self, Result<()>) {
        let result = if let Some((cell, remainder)) =
            Data::try_split_from(self.relay_cell_format, &self.buf[..self.n_pending])
        {
            // TODO: Eventually we may want a larger buffer; if we do,
            // this invariant will become false.
            assert!(remainder.is_empty());
            self.n_pending = 0;
            self.s.send(cell.into()).await
        } else {
            Ok(())
        };

        (self, result)
    }

    /// Add as many bytes as possible from `b` to our internal buffer;
    /// return the number we were able to add.
    fn queue_bytes(&mut self, b: &[u8]) -> usize {
        let empty_space = &mut self.buf[self.n_pending..];
        if empty_space.is_empty() {
            // that is, len == 0
            return 0;
        }

        let n_to_copy = std::cmp::min(b.len(), empty_space.len());
        empty_space[..n_to_copy].copy_from_slice(&b[..n_to_copy]);
        self.n_pending += n_to_copy;
        n_to_copy
    }
}

impl DataReader {
    /// Return a [`ClientDataStreamCtrl`] object that can be used to monitor and
    /// interact with this stream without holding the stream itself.
    #[cfg(feature = "stream-ctrl")]
    pub fn client_stream_ctrl(&self) -> Option<&Arc<ClientDataStreamCtrl>> {
        Some(&self.ctrl)
    }
}

/// An enumeration for the state of a [`DataReader`].
// TODO: We don't need to implement the state in this way anymore now that we've removed the saved
// future. There are a few ways we could simplify this. See:
// https://gitlab.torproject.org/tpo/core/arti/-/merge_requests/3076#note_3218210
#[derive(Educe)]
#[educe(Debug)]
enum DataReaderState {
    /// In this state we have received an end cell or an error.
    Closed,
    /// In this state the reader is open.
    Open(DataReaderImpl),
}

/// Wrapper for the read part of a [`DataStream`].
#[derive(Educe)]
#[educe(Debug)]
#[pin_project]
struct DataReaderImpl {
    /// The underlying StreamReceiver object.
    #[educe(Debug(method = "skip_fmt"))]
    #[pin]
    s: StreamReceiver,

    /// If present, data that we received on this stream but have not
    /// been able to send to the caller yet.
    // TODO: This data structure is probably not what we want, but
    // it's good enough for now.
    #[educe(Debug(method = "skip_fmt"))]
    pending: Vec<u8>,

    /// Index into pending to show what we've already read.
    offset: usize,

    /// If true, we have received a CONNECTED cell on this stream.
    connected: bool,

    /// Shared user-visible information about the state of this stream.
    #[cfg(feature = "stream-ctrl")]
    status: Arc<Mutex<DataStreamStatus>>,
}

impl AsyncRead for DataReader {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<IoResult<usize>> {
        // We're pulling the state object out of the reader.  We MUST
        // put it back before this function returns.
        let mut state = self.state.take().expect("Missing state in DataReader");

        loop {
            let mut imp = match state {
                DataReaderState::Open(mut imp) => {
                    // There may be data to read already.
                    let n_copied = imp.extract_bytes(buf);
                    if n_copied != 0 || buf.is_empty() {
                        // We read data into the buffer, or the buffer was 0-len to begin with.
                        // Tell the caller.
                        self.state = Some(DataReaderState::Open(imp));
                        return Poll::Ready(Ok(n_copied));
                    }

                    // No data available!  We have to try reading.
                    imp
                }
                DataReaderState::Closed => {
                    // TODO: Why are we returning an error rather than continuing to return EOF?
                    self.state = Some(DataReaderState::Closed);
                    return Poll::Ready(Err(Error::NotConnected.into()));
                }
            };

            // See if a cell is ready.
            match Pin::new(&mut imp).read_cell(cx) {
                Poll::Ready(Err(e)) => {
                    // There aren't any survivable errors in the current
                    // design.
                    self.state = Some(DataReaderState::Closed);
                    #[cfg(feature = "stream-ctrl")]
                    {
                        imp.status.lock().expect("lock poisoned").record_error(&e);
                    }
                    let result = if matches!(e, Error::EndReceived(EndReason::DONE)) {
                        Ok(0)
                    } else {
                        Err(e.into())
                    };
                    return Poll::Ready(result);
                }
                Poll::Ready(Ok(())) => {
                    // It read a cell!  Continue the loop.
                    state = DataReaderState::Open(imp);
                }
                Poll::Pending => {
                    // No cells ready, so tell the
                    // caller to get back to us later.
                    self.state = Some(DataReaderState::Open(imp));
                    return Poll::Pending;
                }
            }
        }
    }
}

#[cfg(feature = "tokio")]
impl TokioAsyncRead for DataReader {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<IoResult<()>> {
        TokioAsyncRead::poll_read(Pin::new(&mut self.compat()), cx, buf)
    }
}

impl DataReaderImpl {
    /// Pull as many bytes as we can off of self.pending, and return that
    /// number of bytes.
    fn extract_bytes(&mut self, buf: &mut [u8]) -> usize {
        let remainder = &self.pending[self.offset..];
        let n_to_copy = std::cmp::min(buf.len(), remainder.len());
        buf[..n_to_copy].copy_from_slice(&remainder[..n_to_copy]);
        self.offset += n_to_copy;

        n_to_copy
    }

    /// Return true iff there are no buffered bytes here to yield
    fn buf_is_empty(&self) -> bool {
        self.pending.len() == self.offset
    }

    /// Load self.pending with the contents of a new data cell.
    fn read_cell(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        use DataStreamMsg::*;
        let msg = match self.as_mut().project().s.poll_next(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Some(Ok(unparsed))) => match unparsed.decode::<DataStreamMsg>() {
                Ok(cell) => cell.into_msg(),
                Err(e) => {
                    self.s.protocol_error();
                    return Poll::Ready(Err(Error::from_bytes_err(e, "message on a data stream")));
                }
            },
            Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
            // TODO: This doesn't seem right to me, but seems to be the behaviour of the code before
            // the refactoring, so I've kept the same behaviour. I think if the cell stream is
            // terminated, we should be returning `None` here and not considering it as an error.
            // The `StreamReceiver` will have already returned an error if the cell stream was
            // terminated without an END message.
            Poll::Ready(None) => return Poll::Ready(Err(Error::NotConnected)),
        };

        let result = match msg {
            Connected(_) if !self.connected => {
                self.connected = true;
                #[cfg(feature = "stream-ctrl")]
                {
                    self.status
                        .lock()
                        .expect("poisoned lock")
                        .record_connected();
                }
                Ok(())
            }
            Connected(_) => {
                self.s.protocol_error();
                Err(Error::StreamProto(
                    "Received a second connect cell on a data stream".to_string(),
                ))
            }
            Data(d) if self.connected => {
                self.add_data(d.into());
                Ok(())
            }
            Data(_) => {
                self.s.protocol_error();
                Err(Error::StreamProto(
                    "Received a data cell on an unconnected stream".to_string(),
                ))
            }
            End(e) => Err(Error::EndReceived(e.reason())),
        };

        Poll::Ready(result)
    }

    /// Add the data from `d` to the end of our pending bytes.
    fn add_data(&mut self, mut d: Vec<u8>) {
        if self.buf_is_empty() {
            // No data pending?  Just take d as the new pending.
            self.pending = d;
            self.offset = 0;
        } else {
            // TODO(nickm) This has potential to grow `pending` without bound.
            // Fortunately, we don't currently read cells or call this
            // `add_data` method when pending is nonempty, but if we do in the
            // future, we'll have to be careful here.
            self.pending.append(&mut d);
        }
    }
}

/// A `CmdChecker` that enforces invariants for outbound data streams.
#[derive(Debug)]
pub(crate) struct DataCmdChecker {
    /// True if we are expecting to receive a CONNECTED message on this stream.
    expecting_connected: bool,
}

impl Default for DataCmdChecker {
    fn default() -> Self {
        Self {
            expecting_connected: true,
        }
    }
}

impl super::CmdChecker for DataCmdChecker {
    fn check_msg(
        &mut self,
        msg: &tor_cell::relaycell::UnparsedRelayMsg,
    ) -> Result<super::StreamStatus> {
        use super::StreamStatus::*;
        match msg.cmd() {
            RelayCmd::CONNECTED => {
                if !self.expecting_connected {
                    Err(Error::StreamProto(
                        "Received CONNECTED twice on a stream.".into(),
                    ))
                } else {
                    self.expecting_connected = false;
                    Ok(Open)
                }
            }
            RelayCmd::DATA => {
                if !self.expecting_connected {
                    Ok(Open)
                } else {
                    Err(Error::StreamProto(
                        "Received DATA before CONNECTED on a stream".into(),
                    ))
                }
            }
            RelayCmd::END => Ok(Closed),
            _ => Err(Error::StreamProto(format!(
                "Unexpected {} on a data stream!",
                msg.cmd()
            ))),
        }
    }

    fn consume_checked_msg(&mut self, msg: tor_cell::relaycell::UnparsedRelayMsg) -> Result<()> {
        let _ = msg
            .decode::<DataStreamMsg>()
            .map_err(|err| Error::from_bytes_err(err, "cell on half-closed stream"))?;
        Ok(())
    }
}

impl DataCmdChecker {
    /// Return a new boxed `DataCmdChecker` in a state suitable for a newly
    /// constructed connection.
    pub(crate) fn new_any() -> AnyCmdChecker {
        Box::<Self>::default()
    }

    /// Return a new boxed `DataCmdChecker` in a state suitable for a
    /// connection where an initial CONNECTED cell is not expected.
    ///
    /// This is used by hidden services, exit relays, and directory servers
    /// to accept streams.
    #[cfg(feature = "hs-service")]
    pub(crate) fn new_connected() -> AnyCmdChecker {
        Box::new(Self {
            expecting_connected: false,
        })
    }
}