1
//! Internal: Declare the Writer type for tor-bytes
2

            
3
use std::marker::PhantomData;
4

            
5
use educe::Educe;
6

            
7
use crate::EncodeError;
8
use crate::EncodeResult;
9
use crate::Writeable;
10
use crate::WriteableOnce;
11

            
12
/// A byte-oriented trait for writing to small arrays.
13
///
14
/// Most code will want to use the fact that `Vec<u8>` implements this trait.
15
/// To define a new implementation, just define the write_all method.
16
///
17
/// # Examples
18
///
19
/// You can use a Writer to add bytes explicitly:
20
/// ```
21
/// use tor_bytes::Writer;
22
/// let mut w: Vec<u8> = Vec::new(); // Vec<u8> implements Writer.
23
/// w.write_u32(0x12345);
24
/// w.write_u8(0x22);
25
/// w.write_zeros(3);
26
/// assert_eq!(w, &[0x00, 0x01, 0x23, 0x45, 0x22, 0x00, 0x00, 0x00]);
27
/// ```
28
///
29
/// You can also use a Writer to encode things that implement the
30
/// Writeable trait:
31
///
32
/// ```
33
/// use tor_bytes::{Writer,Writeable};
34
/// let mut w: Vec<u8> = Vec::new();
35
/// w.write(&4_u16); // The unsigned types all implement Writeable.
36
///
37
/// // We also provide Writeable implementations for several important types.
38
/// use std::net::Ipv4Addr;
39
/// let ip = Ipv4Addr::new(127, 0, 0, 1);
40
/// w.write(&ip);
41
///
42
/// assert_eq!(w, &[0x00, 0x04, 0x7f, 0x00, 0x00, 0x01]);
43
/// ```
44
pub trait Writer {
45
    /// Append a slice to the end of this writer.
46
    fn write_all(&mut self, b: &[u8]);
47

            
48
    /// Append a single u8 to this writer.
49
3656
    fn write_u8(&mut self, x: u8) {
50
3656
        self.write_all(&[x]);
51
3656
    }
52
    /// Append a single u16 to this writer, encoded in big-endian order.
53
61854
    fn write_u16(&mut self, x: u16) {
54
61854
        self.write_all(&x.to_be_bytes());
55
61854
    }
56
    /// Append a single u32 to this writer, encoded in big-endian order.
57
27162
    fn write_u32(&mut self, x: u32) {
58
27162
        self.write_all(&x.to_be_bytes());
59
27162
    }
60
    /// Append a single u64 to this writer, encoded in big-endian order.
61
49790
    fn write_u64(&mut self, x: u64) {
62
49790
        self.write_all(&x.to_be_bytes());
63
49790
    }
64
    /// Append a single u128 to this writer, encoded in big-endian order.
65
2
    fn write_u128(&mut self, x: u128) {
66
2
        self.write_all(&x.to_be_bytes());
67
2
    }
68
    /// Write n bytes to this writer, all with the value zero.
69
    ///
70
    /// NOTE: This implementation is somewhat inefficient, since it allocates
71
    /// a vector.  You should probably replace it if you can.
72
20
    fn write_zeros(&mut self, n: usize) {
73
20
        let v = vec![0_u8; n];
74
20
        self.write_all(&v[..]);
75
20
    }
76

            
77
    /// Encode a Writeable object onto this writer, using its
78
    /// write_onto method.
79
115274
    fn write<E: Writeable + ?Sized>(&mut self, e: &E) -> EncodeResult<()> {
80
115274
        // TODO(nickm): should we recover from errors by undoing any partial
81
115274
        // writes that occurred?
82
115274
        e.write_onto(self)
83
115274
    }
84
    /// Encode a WriteableOnce object onto this writer, using its
85
    /// write_into method.
86
76
    fn write_and_consume<E: WriteableOnce>(&mut self, e: E) -> EncodeResult<()> {
87
76
        // TODO(nickm): should we recover from errors by undoing any partial
88
76
        // writes that occurred?
89
76
        e.write_into(self)
90
76
    }
91
    /// Arranges to write a u8 length, and some data whose encoding is that length
92
    ///
93
    /// Prefer to use this function, rather than manual length calculations
94
    /// and ad-hoc `write_u8`,
95
    /// Using this facility eliminates the need to separately keep track of the lengths.
96
    ///
97
    /// The returned `NestedWriter` should be used to write the contents,
98
    /// inside the byte-counted section.
99
    ///
100
    /// Then you **must** call `finish` to finalise the buffer.
101
30700
    fn write_nested_u8len(&mut self) -> NestedWriter<'_, Self, u8> {
102
30700
        write_nested_generic(self)
103
30700
    }
104
    /// Arranges to writes a u16 length and some data whose encoding is that length
105
30
    fn write_nested_u16len(&mut self) -> NestedWriter<'_, Self, u16> {
106
30
        write_nested_generic(self)
107
30
    }
108
    /// Arranges to writes a u32 length and some data whose encoding is that length
109
2
    fn write_nested_u32len(&mut self) -> NestedWriter<'_, Self, u32> {
110
2
        write_nested_generic(self)
111
2
    }
112
}
113

            
114
/// Work in progress state for writing a nested (length-counted) item
115
///
116
/// You must call `finish` !
117
30732
#[derive(Educe)]
118
#[educe(Deref, DerefMut)]
119
pub struct NestedWriter<'w, W, L>
120
where
121
    W: ?Sized,
122
{
123
    /// Variance doesn't matter since this is local to the module, but for form's sake:
124
    /// Be invariant in `L`, as maximally conservative.
125
    length_type: PhantomData<*mut L>,
126

            
127
    /// The outer writer
128
    outer: &'w mut W,
129

            
130
    /// Our inner buffer
131
    ///
132
    /// Caller can use us as `Writer` via `DerefMut`
133
    ///
134
    /// (An alternative would be to `impl Writer` but that involves recapitulating
135
    /// the impl for `Vec` and we do not have the `ambassador` crate to help us.
136
    /// Exposing this inner `Vec` is harmless.)
137
    ///
138
    /// We must allocate here because some `Writer`s are streaming
139
    #[educe(Deref, DerefMut)]
140
    inner: Vec<u8>,
141
}
142

            
143
/// Implementation of `write_nested_*` - generic over the length type
144
30732
fn write_nested_generic<W, L>(w: &mut W) -> NestedWriter<W, L>
145
30732
where
146
30732
    W: Writer + ?Sized,
147
30732
    L: Default + Copy + Sized + Writeable + TryFrom<usize>,
148
30732
{
149
30732
    NestedWriter {
150
30732
        length_type: PhantomData,
151
30732
        outer: w,
152
30732
        inner: vec![],
153
30732
    }
154
30732
}
155

            
156
impl<'w, W, L> NestedWriter<'w, W, L>
157
where
158
    W: Writer + ?Sized,
159
    L: Default + Copy + Sized + Writeable + TryFrom<usize> + std::ops::Not<Output = L>,
160
{
161
    /// Ends writing the nested data, and updates the length appropriately
162
    ///
163
    /// You must check the return value.
164
    /// It will only be `Err` if the amount you wrote doesn't fit into the length field.
165
    ///
166
    /// Sadly, you may well be implementing a `Writeable`, in which case you
167
    /// will have nothing good to do with the error, and must panic.
168
    /// In these cases you should have ensured, somehow, that overflow cannot happen.
169
    /// Ideally, by making your `Writeable` type incapable of holding values
170
    /// whose encoded length doesn't fit in the length field.
171
30732
    pub fn finish(self) -> Result<(), EncodeError> {
172
30732
        let length = self.inner.len();
173
30732
        let length: L = length.try_into().map_err(|_| EncodeError::BadLengthValue)?;
174
30730
        self.outer.write(&length)?;
175
30730
        self.outer.write(&self.inner)?;
176
30730
        Ok(())
177
30732
    }
178
}
179

            
180
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    #[test]
    fn write_ints() {
        let mut b = bytes::BytesMut::new();
        b.write_u8(1);
        b.write_u16(2);
        b.write_u32(3);
        b.write_u64(4);
        b.write_u128(5);

        assert_eq!(
            &b[..],
            &[
                1, 0, 2, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 5
            ]
        );
    }

    #[test]
    fn write_slice() {
        let mut v = Vec::new();
        v.write_u16(0x5468);
        v.write(&b"ey're good dogs, Bront"[..]).unwrap();

        assert_eq!(&v[..], &b"They're good dogs, Bront"[..]);
    }

    #[test]
    fn writeable() -> EncodeResult<()> {
        struct Sequence(u8);
        impl Writeable for Sequence {
            fn write_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
                for i in 0..self.0 {
                    b.write_u8(i);
                }
                Ok(())
            }
        }

        let mut v = Vec::new();
        v.write(&Sequence(6))?;
        assert_eq!(&v[..], &[0, 1, 2, 3, 4, 5]);

        v.write_and_consume(Sequence(3))?;
        assert_eq!(&v[..], &[0, 1, 2, 3, 4, 5, 0, 1, 2]);
        Ok(())
    }

    #[test]
    fn zeros() {
        let mut v: Vec<u8> = vec![7];

        // Writing no zeros is a no-op.
        v.write_zeros(0);
        assert_eq!(&v[..], &[7]);

        // A count larger than any internal chunk size still yields exactly
        // that many zero bytes, appended after the existing contents.
        v.write_zeros(1000);
        assert_eq!(v.len(), 1001);
        assert_eq!(v[0], 7);
        assert!(v[1..].iter().all(|&b| b == 0));
    }

    #[test]
    fn nested() {
        let mut v: Vec<u8> = b"abc".to_vec();

        let mut w = v.write_nested_u8len();
        w.write_u8(b'x');
        w.finish().unwrap();

        let mut w = v.write_nested_u16len();
        w.write_u8(b'y');
        w.finish().unwrap();

        let mut w = v.write_nested_u32len();
        w.write_u8(b'z');
        w.finish().unwrap();

        let expected = b"abc\x01x\0\x01y\0\0\0\x01z";
        assert_eq!(&v, expected);

        // 256 bytes do not fit in a u8 length...
        let mut w = v.write_nested_u8len();
        w.write_zeros(256);
        assert!(matches!(w.finish(), Err(EncodeError::BadLengthValue)));
        // ...and a failed `finish` leaves the outer writer untouched.
        assert_eq!(&v, expected);
    }

    #[test]
    fn nested_boundaries() {
        // 255 bytes is the largest body a u8 length can describe.
        let mut v: Vec<u8> = Vec::new();
        let mut w = v.write_nested_u8len();
        w.write_zeros(255);
        w.finish().unwrap();
        assert_eq!(v.len(), 1 + 255);
        assert_eq!(v[0], 255);

        // An empty nested section still writes its (zero) length.
        let mut v: Vec<u8> = Vec::new();
        v.write_nested_u16len().finish().unwrap();
        assert_eq!(&v[..], &[0, 0]);

        // 0x1_0000 bytes do not fit in a u16 length; nothing is written.
        let mut v: Vec<u8> = Vec::new();
        let mut w = v.write_nested_u16len();
        w.write_zeros(0x1_0000);
        assert!(matches!(w.finish(), Err(EncodeError::BadLengthValue)));
        assert!(v.is_empty());
    }
}