tor_bytes/
writer.rs

1//! Internal: Declare the Writer type for tor-bytes
2
3use std::marker::PhantomData;
4
5use educe::Educe;
6
7use crate::EncodeError;
8use crate::EncodeResult;
9use crate::Writeable;
10use crate::WriteableOnce;
11
12/// A byte-oriented trait for writing to small arrays.
13///
14/// Most code will want to use the fact that `Vec<u8>` implements this trait.
15/// To define a new implementation, just define the write_all method.
16///
17/// # Examples
18///
19/// You can use a Writer to add bytes explicitly:
20/// ```
21/// use tor_bytes::Writer;
22/// let mut w: Vec<u8> = Vec::new(); // Vec<u8> implements Writer.
23/// w.write_u32(0x12345);
24/// w.write_u8(0x22);
25/// w.write_zeros(3);
26/// assert_eq!(w, &[0x00, 0x01, 0x23, 0x45, 0x22, 0x00, 0x00, 0x00]);
27/// ```
28///
29/// You can also use a Writer to encode things that implement the
30/// Writeable trait:
31///
32/// ```
33/// use tor_bytes::{Writer,Writeable};
34/// let mut w: Vec<u8> = Vec::new();
35/// w.write(&4_u16); // The unsigned types all implement Writeable.
36///
37/// // We also provide Writeable implementations for several important types.
38/// use std::net::Ipv4Addr;
39/// let ip = Ipv4Addr::new(127, 0, 0, 1);
40/// w.write(&ip);
41///
42/// assert_eq!(w, &[0x00, 0x04, 0x7f, 0x00, 0x00, 0x01]);
43/// ```
44pub trait Writer {
45    /// Append a slice to the end of this writer.
46    fn write_all(&mut self, b: &[u8]);
47
48    /// Append a single u8 to this writer.
49    fn write_u8(&mut self, x: u8) {
50        self.write_all(&[x]);
51    }
52    /// Append a single u16 to this writer, encoded in big-endian order.
53    fn write_u16(&mut self, x: u16) {
54        self.write_all(&x.to_be_bytes());
55    }
56    /// Append a single u32 to this writer, encoded in big-endian order.
57    fn write_u32(&mut self, x: u32) {
58        self.write_all(&x.to_be_bytes());
59    }
60    /// Append a single u64 to this writer, encoded in big-endian order.
61    fn write_u64(&mut self, x: u64) {
62        self.write_all(&x.to_be_bytes());
63    }
64    /// Append a single u128 to this writer, encoded in big-endian order.
65    fn write_u128(&mut self, x: u128) {
66        self.write_all(&x.to_be_bytes());
67    }
68    /// Write n bytes to this writer, all with the value zero.
69    ///
70    /// NOTE: This implementation is somewhat inefficient, since it allocates
71    /// a vector.  You should probably replace it if you can.
72    fn write_zeros(&mut self, n: usize) {
73        let v = vec![0_u8; n];
74        self.write_all(&v[..]);
75    }
76
77    /// Encode a Writeable object onto this writer, using its
78    /// write_onto method.
79    fn write<E: Writeable + ?Sized>(&mut self, e: &E) -> EncodeResult<()> {
80        // TODO(nickm): should we recover from errors by undoing any partial
81        // writes that occurred?
82        e.write_onto(self)
83    }
84    /// Encode a WriteableOnce object onto this writer, using its
85    /// write_into method.
86    fn write_and_consume<E: WriteableOnce>(&mut self, e: E) -> EncodeResult<()> {
87        // TODO(nickm): should we recover from errors by undoing any partial
88        // writes that occurred?
89        e.write_into(self)
90    }
91    /// Arranges to write a u8 length, and some data whose encoding is that length
92    ///
93    /// Prefer to use this function, rather than manual length calculations
94    /// and ad-hoc `write_u8`,
95    /// Using this facility eliminates the need to separately keep track of the lengths.
96    ///
97    /// The returned `NestedWriter` should be used to write the contents,
98    /// inside the byte-counted section.
99    ///
100    /// Then you **must** call `finish` to finalise the buffer.
101    fn write_nested_u8len(&mut self) -> NestedWriter<'_, Self, u8> {
102        write_nested_generic(self)
103    }
104    /// Arranges to writes a u16 length and some data whose encoding is that length
105    fn write_nested_u16len(&mut self) -> NestedWriter<'_, Self, u16> {
106        write_nested_generic(self)
107    }
108    /// Arranges to writes a u32 length and some data whose encoding is that length
109    fn write_nested_u32len(&mut self) -> NestedWriter<'_, Self, u32> {
110        write_nested_generic(self)
111    }
112}
113
114/// Work in progress state for writing a nested (length-counted) item
115///
116/// You must call `finish` !
117#[derive(Educe)]
118#[educe(Deref, DerefMut)]
119pub struct NestedWriter<'w, W, L>
120where
121    W: ?Sized,
122{
123    /// Variance doesn't matter since this is local to the module, but for form's sake:
124    /// Be invariant in `L`, as maximally conservative.
125    length_type: PhantomData<*mut L>,
126
127    /// The outer writer
128    outer: &'w mut W,
129
130    /// Our inner buffer
131    ///
132    /// Caller can use us as `Writer` via `DerefMut`
133    ///
134    /// (An alternative would be to `impl Writer` but that involves recapitulating
135    /// the impl for `Vec` and we do not have the `ambassador` crate to help us.
136    /// Exposing this inner `Vec` is harmless.)
137    ///
138    /// We must allocate here because some `Writer`s are streaming
139    #[educe(Deref, DerefMut)]
140    inner: Vec<u8>,
141}
142
143/// Implementation of `write_nested_*` - generic over the length type
144fn write_nested_generic<W, L>(w: &mut W) -> NestedWriter<W, L>
145where
146    W: Writer + ?Sized,
147    L: Default + Copy + Sized + Writeable + TryFrom<usize>,
148{
149    NestedWriter {
150        length_type: PhantomData,
151        outer: w,
152        inner: vec![],
153    }
154}
155
156impl<'w, W, L> NestedWriter<'w, W, L>
157where
158    W: Writer + ?Sized,
159    L: Default + Copy + Sized + Writeable + TryFrom<usize> + std::ops::Not<Output = L>,
160{
161    /// Ends writing the nested data, and updates the length appropriately
162    ///
163    /// You must check the return value.
164    /// It will only be `Err` if the amount you wrote doesn't fit into the length field.
165    ///
166    /// Sadly, you may well be implementing a `Writeable`, in which case you
167    /// will have nothing good to do with the error, and must panic.
168    /// In these cases you should have ensured, somehow, that overflow cannot happen.
169    /// Ideally, by making your `Writeable` type incapable of holding values
170    /// whose encoded length doesn't fit in the length field.
171    pub fn finish(self) -> Result<(), EncodeError> {
172        let length = self.inner.len();
173        let length: L = length.try_into().map_err(|_| EncodeError::BadLengthValue)?;
174        self.outer.write(&length)?;
175        self.outer.write(&self.inner)?;
176        Ok(())
177    }
178}
179
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    #[test]
    fn write_ints() {
        let mut b = bytes::BytesMut::new();
        b.write_u8(1);
        b.write_u16(2);
        b.write_u32(3);
        b.write_u64(4);
        b.write_u128(5);

        assert_eq!(
            &b[..],
            &[
                1, 0, 2, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 5
            ]
        );
    }

    #[test]
    fn write_slice() {
        let mut v = Vec::new();
        v.write_u16(0x5468);
        v.write(&b"ey're good dogs, Bront"[..]).unwrap();

        assert_eq!(&v[..], &b"They're good dogs, Bront"[..]);
    }

    #[test]
    fn write_zeros_various_lengths() {
        // Lengths around likely internal chunk boundaries, plus empty and large.
        for &n in &[0_usize, 1, 255, 256, 257, 4096, 10_000] {
            let mut v: Vec<u8> = vec![0xff];
            v.write_zeros(n);
            assert_eq!(v.len(), n + 1);
            assert_eq!(v[0], 0xff);
            assert!(v[1..].iter().all(|&b| b == 0));
        }
    }

    #[test]
    fn writeable() -> EncodeResult<()> {
        struct Sequence(u8);
        impl Writeable for Sequence {
            fn write_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
                for i in 0..self.0 {
                    b.write_u8(i);
                }
                Ok(())
            }
        }

        let mut v = Vec::new();
        v.write(&Sequence(6))?;
        assert_eq!(&v[..], &[0, 1, 2, 3, 4, 5]);

        v.write_and_consume(Sequence(3))?;
        assert_eq!(&v[..], &[0, 1, 2, 3, 4, 5, 0, 1, 2]);
        Ok(())
    }

    #[test]
    fn nested() {
        let mut v: Vec<u8> = b"abc".to_vec();

        let mut w = v.write_nested_u8len();
        w.write_u8(b'x');
        w.finish().unwrap();

        let mut w = v.write_nested_u16len();
        w.write_u8(b'y');
        w.finish().unwrap();

        let mut w = v.write_nested_u32len();
        w.write_u8(b'z');
        w.finish().unwrap();

        assert_eq!(&v, b"abc\x01x\0\x01y\0\0\0\x01z");

        let mut w = v.write_nested_u8len();
        w.write_zeros(256);
        assert!(matches!(
            w.finish().err().unwrap(),
            EncodeError::BadLengthValue
        ));
    }

    #[test]
    fn nested_boundaries() {
        // An empty nested section still emits its (zero) length prefix.
        let mut v: Vec<u8> = Vec::new();
        v.write_nested_u16len().finish().unwrap();
        assert_eq!(&v[..], &[0, 0]);

        // 255 bytes is the largest body a u8 length can describe.
        let mut v: Vec<u8> = Vec::new();
        let mut w = v.write_nested_u8len();
        w.write_zeros(255);
        w.finish().unwrap();
        assert_eq!(v.len(), 256);
        assert_eq!(v[0], 255);
        assert!(v[1..].iter().all(|&b| b == 0));
    }
}
257}