1
//! Internal: Declare the Writer type for tor-bytes
2

            
3
use std::marker::PhantomData;
4

            
5
use educe::Educe;
6

            
7
use crate::EncodeError;
8
use crate::EncodeResult;
9
use crate::Writeable;
10
use crate::WriteableOnce;
11

            
12
/// A byte-oriented trait for writing to small arrays.
13
///
14
/// Most code will want to use the fact that `Vec<u8>` implements this trait.
15
/// To define a new implementation, just define the write_all method.
16
///
17
/// # Examples
18
///
19
/// You can use a Writer to add bytes explicitly:
20
/// ```
21
/// use tor_bytes::Writer;
22
/// let mut w: Vec<u8> = Vec::new(); // Vec<u8> implements Writer.
23
/// w.write_u32(0x12345);
24
/// w.write_u8(0x22);
25
/// w.write_zeros(3);
26
/// assert_eq!(w, &[0x00, 0x01, 0x23, 0x45, 0x22, 0x00, 0x00, 0x00]);
27
/// ```
28
///
29
/// You can also use a Writer to encode things that implement the
30
/// Writeable trait:
31
///
32
/// ```
33
/// use tor_bytes::{Writer,Writeable};
34
/// let mut w: Vec<u8> = Vec::new();
35
/// w.write(&4_u16); // The unsigned types all implement Writeable.
36
///
37
/// // We also provide Writeable implementations for several important types.
38
/// use std::net::Ipv4Addr;
39
/// let ip = Ipv4Addr::new(127, 0, 0, 1);
40
/// w.write(&ip);
41
///
42
/// assert_eq!(w, &[0x00, 0x04, 0x7f, 0x00, 0x00, 0x01]);
43
/// ```
44
pub trait Writer {
45
    /// Append a slice to the end of this writer.
46
    fn write_all(&mut self, b: &[u8]);
47

            
48
    /// Append a single u8 to this writer.
49
5384
    fn write_u8(&mut self, x: u8) {
50
5384
        self.write_all(&[x]);
51
5384
    }
52
    /// Append a single u16 to this writer, encoded in big-endian order.
53
37956
    fn write_u16(&mut self, x: u16) {
54
37956
        self.write_all(&x.to_be_bytes());
55
37956
    }
56
    /// Append a single u32 to this writer, encoded in big-endian order.
57
4932
    fn write_u32(&mut self, x: u32) {
58
4932
        self.write_all(&x.to_be_bytes());
59
4932
    }
60
    /// Append a single u64 to this writer, encoded in big-endian order.
61
50038
    fn write_u64(&mut self, x: u64) {
62
50038
        self.write_all(&x.to_be_bytes());
63
50038
    }
64
    /// Append a single u128 to this writer, encoded in big-endian order.
65
2
    fn write_u128(&mut self, x: u128) {
66
2
        self.write_all(&x.to_be_bytes());
67
2
    }
68
    /// Write n bytes to this writer, all with the value zero.
69
    ///
70
    /// NOTE: This implementation is somewhat inefficient, since it allocates
71
    /// a vector.  You should probably replace it if you can.
72
20
    fn write_zeros(&mut self, n: usize) {
73
20
        let v = vec![0_u8; n];
74
20
        self.write_all(&v[..]);
75
20
    }
76

            
77
    /// Encode a Writeable object onto this writer, using its
78
    /// write_onto method.
79
29650
    fn write<E: Writeable + ?Sized>(&mut self, e: &E) -> EncodeResult<()> {
80
29650
        // TODO(nickm): should we recover from errors by undoing any partial
81
29650
        // writes that occurred?
82
29650
        e.write_onto(self)
83
29650
    }
84
    /// Encode a WriteableOnce object onto this writer, using its
85
    /// write_into method.
86
76
    fn write_and_consume<E: WriteableOnce>(&mut self, e: E) -> EncodeResult<()> {
87
76
        // TODO(nickm): should we recover from errors by undoing any partial
88
76
        // writes that occurred?
89
76
        e.write_into(self)
90
76
    }
91
    /// Arranges to write a u8 length, and some data whose encoding is that length
92
    ///
93
    /// Prefer to use this function, rather than manual length calculations
94
    /// and ad-hoc `write_u8`,
95
    /// Using this facility eliminates the need to separately keep track of the lengths.
96
    ///
97
    /// The returned `NestedWriter` should be used to write the contents,
98
    /// inside the byte-counted section.
99
    ///
100
    /// Then you **must** call `finish` to finalise the buffer.
101
244
    fn write_nested_u8len(&mut self) -> NestedWriter<'_, Self, u8> {
102
244
        write_nested_generic(self)
103
244
    }
104
    /// Arranges to writes a u16 length and some data whose encoding is that length
105
30
    fn write_nested_u16len(&mut self) -> NestedWriter<'_, Self, u16> {
106
30
        write_nested_generic(self)
107
30
    }
108
    /// Arranges to writes a u32 length and some data whose encoding is that length
109
2
    fn write_nested_u32len(&mut self) -> NestedWriter<'_, Self, u32> {
110
2
        write_nested_generic(self)
111
2
    }
112
}
113

            
114
/// Work in progress state for writing a nested (length-counted) item
115
///
116
/// You must call `finish` !
117
276
#[derive(Educe)]
118
#[educe(Deref, DerefMut)]
119
pub struct NestedWriter<'w, W, L>
120
where
121
    W: ?Sized,
122
{
123
    /// Variance doesn't matter since this is local to the module, but for form's sake:
124
    /// Be invariant in `L`, as maximally conservative.
125
    length_type: PhantomData<*mut L>,
126

            
127
    /// The outer writer
128
    outer: &'w mut W,
129

            
130
    /// Our inner buffer
131
    ///
132
    /// Caller can use us as `Writer` via `DerefMut`
133
    ///
134
    /// (An alternative would be to `impl Writer` but that involves recapitulating
135
    /// the impl for `Vec` and we do not have the `ambassador` crate to help us.
136
    /// Exposing this inner `Vec` is harmless.)
137
    ///
138
    /// We must allocate here because some `Writer`s are streaming
139
    #[educe(Deref, DerefMut)]
140
    inner: Vec<u8>,
141
}
142

            
143
/// Implementation of `write_nested_*` - generic over the length type
144
276
fn write_nested_generic<W, L>(w: &mut W) -> NestedWriter<W, L>
145
276
where
146
276
    W: Writer + ?Sized,
147
276
    L: Default + Copy + Sized + Writeable + TryFrom<usize>,
148
276
{
149
276
    NestedWriter {
150
276
        length_type: PhantomData,
151
276
        outer: w,
152
276
        inner: vec![],
153
276
    }
154
276
}
155

            
156
impl<'w, W, L> NestedWriter<'w, W, L>
157
where
158
    W: Writer + ?Sized,
159
    L: Default + Copy + Sized + Writeable + TryFrom<usize> + std::ops::Not<Output = L>,
160
{
161
    /// Ends writing the nested data, and updates the length appropriately
162
    ///
163
    /// You must check the return value.
164
    /// It will only be `Err` if the amount you wrote doesn't fit into the length field.
165
    ///
166
    /// Sadly, you may well be implementing a `Writeable`, in which case you
167
    /// will have nothing good to do with the error, and must panic.
168
    /// In these cases you should have ensured, somehow, that overflow cannot happen.
169
    /// Ideally, by making your `Writeable` type incapable of holding values
170
    /// whose encoded length doesn't fit in the length field.
171
276
    pub fn finish(self) -> Result<(), EncodeError> {
172
276
        let length = self.inner.len();
173
276
        let length: L = length.try_into().map_err(|_| EncodeError::BadLengthValue)?;
174
274
        self.outer.write(&length)?;
175
274
        self.outer.write(&self.inner)?;
176
274
        Ok(())
177
276
    }
178
}
179

            
180
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Every fixed-width integer writer appends its value in big-endian order.
    #[test]
    fn write_ints() {
        let mut buf = bytes::BytesMut::new();
        buf.write_u8(1);
        buf.write_u16(2);
        buf.write_u32(3);
        buf.write_u64(4);
        buf.write_u128(5);

        let expected: [u8; 31] = [
            1, // u8
            0, 2, // u16
            0, 0, 0, 3, // u32
            0, 0, 0, 0, 0, 0, 0, 4, // u64
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, // u128
        ];
        assert_eq!(&buf[..], &expected[..]);
    }

    /// A byte slice can be written through `write`.
    #[test]
    fn write_slice() {
        let mut out: Vec<u8> = Vec::new();
        out.write_u16(0x5468);
        let tail: &[u8] = &b"ey're good dogs, Bront"[..];
        out.write(tail).unwrap();

        assert_eq!(&out[..], &b"They're good dogs, Bront"[..]);
    }

    /// A user-defined `Writeable` works with both `write` and `write_and_consume`.
    #[test]
    fn writeable() -> EncodeResult<()> {
        /// Encodes as the bytes `0..n`.
        struct Sequence(u8);
        impl Writeable for Sequence {
            fn write_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
                (0..self.0).for_each(|i| b.write_u8(i));
                Ok(())
            }
        }

        let mut out: Vec<u8> = Vec::new();
        out.write(&Sequence(6))?;
        assert_eq!(&out[..], &[0, 1, 2, 3, 4, 5]);

        out.write_and_consume(Sequence(3))?;
        assert_eq!(&out[..], &[0, 1, 2, 3, 4, 5, 0, 1, 2]);
        Ok(())
    }

    /// Nested writers emit a length prefix of the requested width, and
    /// reject bodies that are too long for the prefix.
    #[test]
    fn nested() {
        let mut out: Vec<u8> = b"abc".to_vec();

        let mut nested8 = out.write_nested_u8len();
        nested8.write_u8(b'x');
        nested8.finish().unwrap();

        let mut nested16 = out.write_nested_u16len();
        nested16.write_u8(b'y');
        nested16.finish().unwrap();

        let mut nested32 = out.write_nested_u32len();
        nested32.write_u8(b'z');
        nested32.finish().unwrap();

        assert_eq!(&out[..], &b"abc\x01x\0\x01y\0\0\0\x01z"[..]);

        // 256 bytes cannot be described by a u8 length.
        let mut too_long = out.write_nested_u8len();
        too_long.write_zeros(256);
        assert!(matches!(
            too_long.finish(),
            Err(EncodeError::BadLengthValue)
        ));
    }
}