1
//! Wrap `tor_cell::chancell::codec::ChannelCodec` for use with the
//! `asynchronous_codec` crate.
3
use std::{io::Error as IoError, marker::PhantomData};
4

            
5
use futures::{AsyncRead, AsyncWrite};
6
use tor_cell::chancell::{codec, ChanCell, ChanMsg};
7

            
8
use asynchronous_codec;
9
use bytes::BytesMut;
10

            
11
/// An error from a ChannelCodec.
12
///
13
/// This is a separate error type for now because I suspect that we'll want to
14
/// handle these differently in the rest of our channel code.
15
#[derive(Debug, thiserror::Error)]
16
pub(crate) enum CodecError {
17
    /// An error from the underlying IO stream underneath a codec.
18
    ///
19
    /// (This isn't wrapped in an Arc, because we don't need this type to be
20
    /// clone; it's crate-internal.)
21
    #[error("Io error reading or writing a channel cell")]
22
    Io(#[from] IoError),
23
    /// An error from the cell decoding logic.
24
    #[error("Error decoding an incoming channel cell")]
25
    DecCell(#[source] tor_cell::Error),
26
    /// An error from the cell encoding logic.
27
    #[error("Error encoding an outgoing channel cell")]
28
    EncCell(#[source] tor_cell::Error),
29
}
30

            
31
/// Asynchronous wrapper around ChannelCodec in tor_cell, with implementation
32
/// for use with asynchronous_codec.
33
///
34
/// This type lets us wrap a TLS channel (or some other secure
35
/// AsyncRead+AsyncWrite type) as a Sink and a Stream of ChanCell, so we can
36
/// forget about byte-oriented communication.
37
///
38
/// It's parameterized on two message types: one that we're allowed to receive
39
/// (`IN`), and one that we're allowed to send (`OUT`).
40
pub(crate) struct ChannelCodec<IN, OUT> {
41
    /// The cell codec that we'll use to encode and decode our cells.
42
    inner: codec::ChannelCodec,
43
    /// Tells the compiler that we're using IN, and we might
44
    /// consume values of type IN.
45
    _phantom_in: PhantomData<fn(IN)>,
46
    /// Tells the compiler that we're using OUT, and we might
47
    /// produce values of type OUT.
48
    _phantom_out: PhantomData<fn() -> OUT>,
49
}
50

            
51
impl<IN, OUT> ChannelCodec<IN, OUT> {
52
    /// Create a new ChannelCodec with a given link protocol.
53
60
    pub(crate) fn new(link_proto: u16) -> Self {
54
60
        ChannelCodec {
55
60
            inner: codec::ChannelCodec::new(link_proto),
56
60
            _phantom_in: PhantomData,
57
60
            _phantom_out: PhantomData,
58
60
        }
59
60
    }
60

            
61
    /// Consume this codec, and return a new one that sends and receives
62
    /// different message types.
63
8
    pub(crate) fn change_message_types<IN2, OUT2>(self) -> ChannelCodec<IN2, OUT2> {
64
8
        ChannelCodec {
65
8
            inner: self.inner,
66
8
            _phantom_in: PhantomData,
67
8
            _phantom_out: PhantomData,
68
8
        }
69
8
    }
70
}
71

            
72
impl<IN, OUT> asynchronous_codec::Encoder for ChannelCodec<IN, OUT>
73
where
74
    OUT: ChanMsg,
75
{
76
    type Item<'a> = ChanCell<OUT>;
77
    type Error = CodecError;
78

            
79
20
    fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
80
20
        self.inner
81
20
            .write_cell(item, dst)
82
20
            .map_err(CodecError::EncCell)?;
83
20
        Ok(())
84
20
    }
85
}
86

            
87
impl<IN, OUT> asynchronous_codec::Decoder for ChannelCodec<IN, OUT>
88
where
89
    IN: ChanMsg,
90
{
91
    type Item = ChanCell<IN>;
92
    type Error = CodecError;
93

            
94
120
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
95
120
        self.inner.decode_cell(src).map_err(CodecError::DecCell)
96
120
    }
97
}
98

            
99
/// Consume a [`Framed`](asynchronous_codec::Framed) codec user, and produce one that
100
/// sends and receives different message types.
101
8
pub(crate) fn change_message_types<T, IN, OUT, IN2, OUT2>(
102
8
    framed: asynchronous_codec::Framed<T, ChannelCodec<IN, OUT>>,
103
8
) -> asynchronous_codec::Framed<T, ChannelCodec<IN2, OUT2>>
104
8
where
105
8
    T: AsyncRead + AsyncWrite,
106
8
    IN: ChanMsg,
107
8
    OUT: ChanMsg,
108
8
    IN2: ChanMsg,
109
8
    OUT2: ChanMsg,
110
8
{
111
8
    asynchronous_codec::Framed::from_parts(
112
8
        framed
113
8
            .into_parts()
114
8
            .map_codec(ChannelCodec::change_message_types),
115
8
    )
116
8
}
117

            
118
#[cfg(test)]
pub(crate) mod test {
    #![allow(clippy::unwrap_used)]
    use futures::io::{AsyncRead, AsyncWrite, Cursor, Result};
    use futures::sink::SinkExt;
    use futures::stream::StreamExt;
    use futures::task::{Context, Poll};
    use hex_literal::hex;
    use std::pin::Pin;
    use tor_cell::chancell::msg::AnyChanMsg;
    use tor_rtcompat::StreamOps;

    use super::{asynchronous_codec, ChannelCodec};
    use tor_cell::chancell::{msg, AnyChanCell, ChanCmd, ChanMsg, CircId};

    /// In-memory byte pipe used in place of a real stream in tests.
    ///
    /// Reads are served from a fixed buffer given at construction time. Writes
    /// go into a separate buffer, retrievable via [`MsgBuf::into_response`].
    // TODO: We might want to move this
    pub(crate) struct MsgBuf {
        /// Bytes that readers of this object will receive.
        inbuf: futures::io::Cursor<Vec<u8>>,
        /// Bytes that writers have written into this object.
        outbuf: futures::io::Cursor<Vec<u8>>,
    }

    impl AsyncRead for MsgBuf {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<usize>> {
            // MsgBuf is Unpin, so we can safely get at the inner cursor.
            let this = self.get_mut();
            Pin::new(&mut this.inbuf).poll_read(cx, buf)
        }
    }

    impl AsyncWrite for MsgBuf {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize>> {
            let this = self.get_mut();
            Pin::new(&mut this.outbuf).poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
            let this = self.get_mut();
            Pin::new(&mut this.outbuf).poll_flush(cx)
        }

        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
            let this = self.get_mut();
            Pin::new(&mut this.outbuf).poll_close(cx)
        }
    }

    impl StreamOps for MsgBuf {}

    impl MsgBuf {
        /// Make a `MsgBuf` that will yield `output` to readers, with an empty
        /// write buffer.
        pub(crate) fn new<T: Into<Vec<u8>>>(output: T) -> Self {
            MsgBuf {
                inbuf: Cursor::new(output.into()),
                outbuf: Cursor::new(Vec::new()),
            }
        }

        /// Number of input bytes that have been read so far.
        pub(crate) fn consumed(&self) -> usize {
            let read_pos = self.inbuf.position();
            read_pos as usize
        }

        /// True when every input byte has been read.
        pub(crate) fn all_consumed(&self) -> bool {
            self.consumed() == self.inbuf.get_ref().len()
        }

        /// Consume this object and return everything that was written to it.
        pub(crate) fn into_response(self) -> Vec<u8> {
            let MsgBuf { outbuf, .. } = self;
            outbuf.into_inner()
        }
    }

    /// Wrap `mbuf` in a `Framed` using a link-protocol-4 codec for `AnyChanMsg`.
    fn frame_buf(
        mbuf: MsgBuf,
    ) -> asynchronous_codec::Framed<MsgBuf, ChannelCodec<AnyChanMsg, AnyChanMsg>> {
        let codec = ChannelCodec::new(4);
        asynchronous_codec::Framed::new(mbuf, codec)
    }

    #[test]
    fn check_encoding() {
        tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
            let mut framed = frame_buf(MsgBuf::new(&b""[..]));

            let destroy = msg::Destroy::new(2.into());
            let destroy_cell = AnyChanCell::new(CircId::new(7), destroy.into());
            framed.send(destroy_cell).await.unwrap();

            let no_certs = msg::Certs::new_empty();
            let certs_cell = AnyChanCell::new(None, no_certs.into());
            framed.send(certs_cell).await.unwrap();

            framed.flush().await.unwrap();

            let written = framed.into_inner().into_response();

            // Header of the DESTROY cell for circuit 7.
            assert_eq!(&written[0..10], &hex!("00000007 04 0200000000")[..]);

            // The CERTS cell begins right after the 514-byte DESTROY cell.
            assert_eq!(&written[514..], &hex!("00000000 81 0001 00")[..]);
        });
    }

    #[test]
    fn check_decoding() {
        tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
            // A DESTROY cell zero-padded to 514 bytes, then an empty CERTS cell.
            let mut input = hex!("00000007 04 0200000000").to_vec();
            input.resize(514, 0);
            input.extend_from_slice(&hex!("00000000 81 0001 00")[..]);
            let mut framed = frame_buf(MsgBuf::new(&input[..]));

            let destroy = framed.next().await.unwrap().unwrap();
            let nocerts = framed.next().await.unwrap().unwrap();

            assert_eq!(destroy.circid(), CircId::new(7));
            assert_eq!(destroy.msg().cmd(), ChanCmd::DESTROY);
            assert_eq!(nocerts.circid(), None);
            assert_eq!(nocerts.msg().cmd(), ChanCmd::CERTS);

            assert!(framed.into_inner().all_consumed());
        });
    }
}