tor_proto/channel/
codec.rs

//! Wrap `tor_cell::chancell::codec::ChannelCodec` for use with the
//! `asynchronous_codec` crate.
3use std::{io::Error as IoError, marker::PhantomData};
4
5use futures::{AsyncRead, AsyncWrite};
6use tor_cell::chancell::{codec, ChanCell, ChanMsg};
7
8use asynchronous_codec;
9use bytes::BytesMut;
10
11/// An error from a ChannelCodec.
12///
13/// This is a separate error type for now because I suspect that we'll want to
14/// handle these differently in the rest of our channel code.
15#[derive(Debug, thiserror::Error)]
16pub(crate) enum CodecError {
17    /// An error from the underlying IO stream underneath a codec.
18    ///
19    /// (This isn't wrapped in an Arc, because we don't need this type to be
20    /// clone; it's crate-internal.)
21    #[error("Io error reading or writing a channel cell")]
22    Io(#[from] IoError),
23    /// An error from the cell decoding logic.
24    #[error("Error decoding an incoming channel cell")]
25    DecCell(#[source] tor_cell::Error),
26    /// An error from the cell encoding logic.
27    #[error("Error encoding an outgoing channel cell")]
28    EncCell(#[source] tor_cell::Error),
29}
30
31/// Asynchronous wrapper around ChannelCodec in tor_cell, with implementation
32/// for use with asynchronous_codec.
33///
34/// This type lets us wrap a TLS channel (or some other secure
35/// AsyncRead+AsyncWrite type) as a Sink and a Stream of ChanCell, so we can
36/// forget about byte-oriented communication.
37///
38/// It's parameterized on two message types: one that we're allowed to receive
39/// (`IN`), and one that we're allowed to send (`OUT`).
40pub(crate) struct ChannelCodec<IN, OUT> {
41    /// The cell codec that we'll use to encode and decode our cells.
42    inner: codec::ChannelCodec,
43    /// Tells the compiler that we're using IN, and we might
44    /// consume values of type IN.
45    _phantom_in: PhantomData<fn(IN)>,
46    /// Tells the compiler that we're using OUT, and we might
47    /// produce values of type OUT.
48    _phantom_out: PhantomData<fn() -> OUT>,
49}
50
51impl<IN, OUT> ChannelCodec<IN, OUT> {
52    /// Create a new ChannelCodec with a given link protocol.
53    pub(crate) fn new(link_proto: u16) -> Self {
54        ChannelCodec {
55            inner: codec::ChannelCodec::new(link_proto),
56            _phantom_in: PhantomData,
57            _phantom_out: PhantomData,
58        }
59    }
60
61    /// Consume this codec, and return a new one that sends and receives
62    /// different message types.
63    pub(crate) fn change_message_types<IN2, OUT2>(self) -> ChannelCodec<IN2, OUT2> {
64        ChannelCodec {
65            inner: self.inner,
66            _phantom_in: PhantomData,
67            _phantom_out: PhantomData,
68        }
69    }
70}
71
72impl<IN, OUT> asynchronous_codec::Encoder for ChannelCodec<IN, OUT>
73where
74    OUT: ChanMsg,
75{
76    type Item<'a> = ChanCell<OUT>;
77    type Error = CodecError;
78
79    fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
80        self.inner
81            .write_cell(item, dst)
82            .map_err(CodecError::EncCell)?;
83        Ok(())
84    }
85}
86
87impl<IN, OUT> asynchronous_codec::Decoder for ChannelCodec<IN, OUT>
88where
89    IN: ChanMsg,
90{
91    type Item = ChanCell<IN>;
92    type Error = CodecError;
93
94    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
95        self.inner.decode_cell(src).map_err(CodecError::DecCell)
96    }
97}
98
99/// Consume a [`Framed`](asynchronous_codec::Framed) codec user, and produce one that
100/// sends and receives different message types.
101pub(crate) fn change_message_types<T, IN, OUT, IN2, OUT2>(
102    framed: asynchronous_codec::Framed<T, ChannelCodec<IN, OUT>>,
103) -> asynchronous_codec::Framed<T, ChannelCodec<IN2, OUT2>>
104where
105    T: AsyncRead + AsyncWrite,
106    IN: ChanMsg,
107    OUT: ChanMsg,
108    IN2: ChanMsg,
109    OUT2: ChanMsg,
110{
111    asynchronous_codec::Framed::from_parts(
112        framed
113            .into_parts()
114            .map_codec(ChannelCodec::change_message_types),
115    )
116}
117
#[cfg(test)]
pub(crate) mod test {
    #![allow(clippy::unwrap_used)]
    use futures::io::{AsyncRead, AsyncWrite, Cursor, Result};
    use futures::sink::SinkExt;
    use futures::stream::StreamExt;
    use futures::task::{Context, Poll};
    use hex_literal::hex;
    use std::pin::Pin;
    use tor_cell::chancell::msg::AnyChanMsg;
    use tor_rtcompat::StreamOps;

    use super::{asynchronous_codec, ChannelCodec};
    use tor_cell::chancell::{msg, AnyChanCell, ChanCmd, ChanMsg, CircId};

    /// In-memory stand-in for a byte stream, for use in tests.
    ///
    /// Reads are served from a fixed buffer supplied at construction time;
    /// writes are collected into a separate buffer that can be retrieved with
    /// [`MsgBuf::into_response`].
    // TODO: We might want to move this
    pub(crate) struct MsgBuf {
        /// Bytes handed out to whoever reads from this buffer.
        inbuf: futures::io::Cursor<Vec<u8>>,
        /// Bytes accumulated from whoever writes to this buffer.
        outbuf: futures::io::Cursor<Vec<u8>>,
    }

    impl AsyncRead for MsgBuf {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<usize>> {
            // MsgBuf is Unpin (it only holds Cursors), so we can unpin freely.
            let this = self.get_mut();
            Pin::new(&mut this.inbuf).poll_read(cx, buf)
        }
    }
    impl AsyncWrite for MsgBuf {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize>> {
            Pin::new(&mut self.get_mut().outbuf).poll_write(cx, buf)
        }
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
            Pin::new(&mut self.get_mut().outbuf).poll_flush(cx)
        }
        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
            Pin::new(&mut self.get_mut().outbuf).poll_close(cx)
        }
    }

    impl StreamOps for MsgBuf {}

    impl MsgBuf {
        /// Create a buffer whose readers will receive `input`, with nothing
        /// written to it yet.
        pub(crate) fn new<T: Into<Vec<u8>>>(input: T) -> Self {
            MsgBuf {
                inbuf: Cursor::new(input.into()),
                outbuf: Cursor::new(Vec::new()),
            }
        }

        /// Return the number of bytes read from this buffer so far.
        pub(crate) fn consumed(&self) -> usize {
            let position = self.inbuf.position();
            position as usize
        }

        /// Return true if every readable byte has been read.
        pub(crate) fn all_consumed(&self) -> bool {
            self.consumed() == self.inbuf.get_ref().len()
        }

        /// Consume this buffer, returning everything that was written to it.
        pub(crate) fn into_response(self) -> Vec<u8> {
            let MsgBuf { outbuf, .. } = self;
            outbuf.into_inner()
        }
    }

    /// Wrap `mbuf` in a framed codec for link protocol 4 that sends and
    /// receives any channel message.
    fn frame_buf(
        mbuf: MsgBuf,
    ) -> asynchronous_codec::Framed<MsgBuf, ChannelCodec<AnyChanMsg, AnyChanMsg>> {
        let codec = ChannelCodec::new(4);
        asynchronous_codec::Framed::new(mbuf, codec)
    }

    #[test]
    fn check_encoding() {
        tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
            let mut framed = frame_buf(MsgBuf::new(&b""[..]));

            let destroy = msg::Destroy::new(2.into());
            let destroy_cell = AnyChanCell::new(CircId::new(7), destroy.into());
            framed.send(destroy_cell).await.unwrap();

            let certs_cell = AnyChanCell::new(None, msg::Certs::new_empty().into());
            framed.send(certs_cell).await.unwrap();

            framed.flush().await.unwrap();

            let written = framed.into_inner().into_response();

            // The DESTROY cell is fixed-length (514 bytes with 4-byte circuit
            // IDs); the variable-length CERTS cell follows immediately after.
            assert_eq!(&written[..10], &hex!("00000007 04 0200000000")[..]);
            assert_eq!(&written[514..], &hex!("00000000 81 0001 00")[..]);
        });
    }

    #[test]
    fn check_decoding() {
        tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
            // A zero-padded 514-byte DESTROY cell, then an empty CERTS cell.
            let mut wire = hex!("00000007 04 0200000000").to_vec();
            wire.resize(514, 0);
            wire.extend_from_slice(&hex!("00000000 81 0001 00")[..]);
            let mut framed = frame_buf(MsgBuf::new(&wire[..]));

            let first = framed.next().await.unwrap().unwrap();
            let second = framed.next().await.unwrap().unwrap();

            assert_eq!(first.circid(), CircId::new(7));
            assert_eq!(first.msg().cmd(), ChanCmd::DESTROY);
            assert_eq!(second.circid(), None);
            assert_eq!(second.msg().cmd(), ChanCmd::CERTS);

            assert!(framed.into_inner().all_consumed());
        });
    }
}
244}