tor_proto/channel/
codec.rs
1use std::{io::Error as IoError, marker::PhantomData};
4
5use futures::{AsyncRead, AsyncWrite};
6use tor_cell::chancell::{codec, ChanCell, ChanMsg};
7
8use asynchronous_codec;
9use bytes::BytesMut;
10
11#[derive(Debug, thiserror::Error)]
16pub(crate) enum CodecError {
17 #[error("Io error reading or writing a channel cell")]
22 Io(#[from] IoError),
23 #[error("Error decoding an incoming channel cell")]
25 DecCell(#[source] tor_cell::Error),
26 #[error("Error encoding an outgoing channel cell")]
28 EncCell(#[source] tor_cell::Error),
29}
30
31pub(crate) struct ChannelCodec<IN, OUT> {
41 inner: codec::ChannelCodec,
43 _phantom_in: PhantomData<fn(IN)>,
46 _phantom_out: PhantomData<fn() -> OUT>,
49}
50
51impl<IN, OUT> ChannelCodec<IN, OUT> {
52 pub(crate) fn new(link_proto: u16) -> Self {
54 ChannelCodec {
55 inner: codec::ChannelCodec::new(link_proto),
56 _phantom_in: PhantomData,
57 _phantom_out: PhantomData,
58 }
59 }
60
61 pub(crate) fn change_message_types<IN2, OUT2>(self) -> ChannelCodec<IN2, OUT2> {
64 ChannelCodec {
65 inner: self.inner,
66 _phantom_in: PhantomData,
67 _phantom_out: PhantomData,
68 }
69 }
70}
71
72impl<IN, OUT> asynchronous_codec::Encoder for ChannelCodec<IN, OUT>
73where
74 OUT: ChanMsg,
75{
76 type Item<'a> = ChanCell<OUT>;
77 type Error = CodecError;
78
79 fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
80 self.inner
81 .write_cell(item, dst)
82 .map_err(CodecError::EncCell)?;
83 Ok(())
84 }
85}
86
87impl<IN, OUT> asynchronous_codec::Decoder for ChannelCodec<IN, OUT>
88where
89 IN: ChanMsg,
90{
91 type Item = ChanCell<IN>;
92 type Error = CodecError;
93
94 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
95 self.inner.decode_cell(src).map_err(CodecError::DecCell)
96 }
97}
98
99pub(crate) fn change_message_types<T, IN, OUT, IN2, OUT2>(
102 framed: asynchronous_codec::Framed<T, ChannelCodec<IN, OUT>>,
103) -> asynchronous_codec::Framed<T, ChannelCodec<IN2, OUT2>>
104where
105 T: AsyncRead + AsyncWrite,
106 IN: ChanMsg,
107 OUT: ChanMsg,
108 IN2: ChanMsg,
109 OUT2: ChanMsg,
110{
111 asynchronous_codec::Framed::from_parts(
112 framed
113 .into_parts()
114 .map_codec(ChannelCodec::change_message_types),
115 )
116}
117
#[cfg(test)]
pub(crate) mod test {
    #![allow(clippy::unwrap_used)]
    use std::pin::Pin;

    use futures::io::{AsyncRead, AsyncWrite, Cursor, Result};
    use futures::sink::SinkExt;
    use futures::stream::StreamExt;
    use futures::task::{Context, Poll};
    use hex_literal::hex;
    use tor_cell::chancell::msg::AnyChanMsg;
    use tor_cell::chancell::{msg, AnyChanCell, ChanCmd, ChanMsg, CircId};
    use tor_rtcompat::StreamOps;

    use super::{asynchronous_codec, ChannelCodec};

    /// An in-memory stand-in for a stream, used to exercise the codec.
    ///
    /// Reads come from a buffer supplied at construction time. Writes are
    /// collected in a separate buffer, which [`MsgBuf::into_response`] returns.
    pub(crate) struct MsgBuf {
        /// Source of the bytes returned by `poll_read`.
        inbuf: Cursor<Vec<u8>>,
        /// Destination of the bytes passed to `poll_write`.
        outbuf: Cursor<Vec<u8>>,
    }

    impl AsyncRead for MsgBuf {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<usize>> {
            let this = self.get_mut();
            Pin::new(&mut this.inbuf).poll_read(cx, buf)
        }
    }

    impl AsyncWrite for MsgBuf {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize>> {
            let this = self.get_mut();
            Pin::new(&mut this.outbuf).poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
            let this = self.get_mut();
            Pin::new(&mut this.outbuf).poll_flush(cx)
        }

        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
            let this = self.get_mut();
            Pin::new(&mut this.outbuf).poll_close(cx)
        }
    }

    impl StreamOps for MsgBuf {}

    impl MsgBuf {
        /// Create a buffer whose reads yield the bytes of `output` and whose
        /// write side starts empty.
        ///
        /// `output` is the data the simulated peer "outputs", i.e. what reads
        /// from this buffer will return.
        pub(crate) fn new<T>(output: T) -> Self
        where
            T: Into<Vec<u8>>,
        {
            Self {
                inbuf: Cursor::new(output.into()),
                outbuf: Cursor::new(Vec::new()),
            }
        }

        /// Return the current read position of the input cursor.
        pub(crate) fn consumed(&self) -> usize {
            self.inbuf.position() as usize
        }

        /// Return whether the read position has reached the end of the input.
        pub(crate) fn all_consumed(&self) -> bool {
            self.consumed() == self.inbuf.get_ref().len()
        }

        /// Consume the buffer and return everything that was written to it.
        pub(crate) fn into_response(self) -> Vec<u8> {
            let MsgBuf { outbuf, .. } = self;
            outbuf.into_inner()
        }
    }

    /// Wrap `mbuf` in a framed codec that uses link protocol 4 and carries any
    /// channel message in either direction.
    fn frame_buf(
        mbuf: MsgBuf,
    ) -> asynchronous_codec::Framed<MsgBuf, ChannelCodec<AnyChanMsg, AnyChanMsg>> {
        let codec = ChannelCodec::new(4);
        asynchronous_codec::Framed::new(mbuf, codec)
    }

    #[test]
    fn check_encoding() {
        tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
            let mut framed = frame_buf(MsgBuf::new(&b""[..]));

            // First a DESTROY on circuit 7, then a CERTS with no circuit ID and
            // no certificates.
            let destroy = msg::Destroy::new(2.into());
            framed
                .send(AnyChanCell::new(CircId::new(7), destroy.into()))
                .await
                .unwrap();
            let certs = msg::Certs::new_empty();
            framed
                .send(AnyChanCell::new(None, certs.into()))
                .await
                .unwrap();
            framed.flush().await.unwrap();

            let written = framed.into_inner().into_response();

            // Only the first 10 bytes of the DESTROY cell are checked.
            assert_eq!(&written[0..10], &hex!("00000007 04 0200000000")[..]);
            // The CERTS cell starts at offset 514 and runs to the end of the output.
            assert_eq!(&written[514..], &hex!("00000000 81 0001 00")[..]);
        });
    }

    #[test]
    fn check_decoding() {
        tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
            // A DESTROY cell zero-filled out to 514 bytes, followed by an
            // empty CERTS cell.
            let mut wire = hex!("00000007 04 0200000000").to_vec();
            wire.resize(514, 0);
            wire.extend_from_slice(&hex!("00000000 81 0001 00"));
            let mut framed = frame_buf(MsgBuf::new(wire));

            let destroy = framed.next().await.unwrap().unwrap();
            let nocerts = framed.next().await.unwrap().unwrap();

            assert_eq!(destroy.circid(), CircId::new(7));
            assert_eq!(destroy.msg().cmd(), ChanCmd::DESTROY);
            assert_eq!(nocerts.circid(), None);
            assert_eq!(nocerts.msg().cmd(), ChanCmd::CERTS);

            // Together, the two cells must account for every input byte.
            assert!(framed.into_inner().all_consumed());
        });
    }
}