1
//! Types and encodings used during circuit extension.
2

            
3
use crate::{Error, Result};
4
use caret::caret_int;
5
use tor_bytes::{EncodeResult, Readable, Reader, Writeable, Writer};
6

            
7
caret_int! {
    /// Numeric identifier of an ntor v3 handshake extension: the
    /// `EXT_FIELD_TYPE` byte that precedes each extension's length and body.
    pub struct NtorV3ExtensionType(u8) {
        /// Client asks that congestion control be enabled on the circuit.
        CC_REQUEST = 1,
        /// Exit node acknowledges a congestion control request.
        CC_RESPONSE = 2
    }
}
16

            
17
/// A piece of extension data, to be encoded as the message in an ntor v3 handshake.
18
#[derive(Clone, Debug, PartialEq, Eq)]
19
#[non_exhaustive]
20
pub enum NtorV3Extension {
21
    /// Request congestion control be enabled for this circuit (client → exit node).
22
    ///
23
    /// (`EXT_FIELD_TYPE` = 01)
24
    RequestCongestionControl,
25
    /// Acknowledge a congestion control request (exit node → client).
26
    ///
27
    /// (`EXT_FIELD_TYPE` = 02)
28
    AckCongestionControl {
29
        /// The exit's current view of the `cc_sendme_inc` consensus parameter.
30
        sendme_inc: u8,
31
    },
32
    /// An unknown piece of extension data.
33
    Unrecognized {
34
        /// The extension type (`EXT_FIELD_TYPE`).
35
        field_type: NtorV3ExtensionType,
36
        /// The raw bytes of unrecognized extension data.
37
        data: Vec<u8>,
38
    },
39
}
40

            
41
impl NtorV3Extension {
42
    /// Encode a set of extensions into a "message" for an ntor v3 handshake.
43
40
    pub fn write_many_onto<W: Writer>(exts: &[NtorV3Extension], out: &mut W) -> EncodeResult<()> {
44
40
        let n_extensions =
45
40
            u8::try_from(exts.len()).map_err(|_| tor_bytes::EncodeError::BadLengthValue)?;
46
40
        out.write_u8(n_extensions);
47
40
        exts.iter().try_for_each(|x| x.write_onto(out))
48
40
    }
49

            
50
    /// Decode a slice of bytes representing the "message" of an ntor v3 handshake into a set of
51
    /// extensions.
52
980
    pub fn decode(message: &[u8]) -> Result<Vec<Self>> {
53
980
        let mut reader = Reader::from_slice(message);
54
980
        let mut ret = vec![];
55
980
        let n_extensions = reader.take_u8().map_err(|e| Error::BytesErr {
56
            err: e,
57
            parsed: "n_extensions",
58
980
        })?;
59
980
        for _ in 0..n_extensions {
60
98
            ret.push(
61
98
                NtorV3Extension::take_from(&mut reader).map_err(|err| Error::BytesErr {
62
                    err,
63
                    parsed: "an ntor extension",
64
98
                })?,
65
            );
66
        }
67
980
        if reader.remaining() > 0 {
68
            return Err(Error::BytesErr {
69
                err: tor_bytes::Error::ExtraneousBytes,
70
                parsed: "ntor extensions set",
71
            });
72
980
        }
73
980
        Ok(ret)
74
980
    }
75
}
76

            
77
impl Writeable for NtorV3Extension {
78
4
    fn write_onto<W: Writer + ?Sized>(&self, out: &mut W) -> EncodeResult<()> {
79
4
        match self {
80
2
            NtorV3Extension::RequestCongestionControl => {
81
2
                out.write_all(&[1, 0]);
82
2
            }
83
2
            NtorV3Extension::AckCongestionControl { sendme_inc } => {
84
2
                out.write_all(&[2, 1, *sendme_inc]);
85
2
            }
86
            NtorV3Extension::Unrecognized { field_type, data } => {
87
                // FIXME(eta): This will break if you try and fill `data` with more than 255 bytes.
88
                //             This is only a problem if you construct your own `Unrecognized`, though.
89
                out.write_all(&[field_type.get(), data.len() as u8]);
90
                out.write_all(data);
91
            }
92
        }
93
4
        Ok(())
94
4
    }
95
}
96

            
97
impl Readable for NtorV3Extension {
98
98
    fn take_from(reader: &mut Reader<'_>) -> tor_bytes::Result<Self> {
99
98
        let tag: NtorV3ExtensionType = reader.take_u8()?.into();
100
98
        let len = reader.take_u8()?;
101
98
        Ok(match tag {
102
            NtorV3ExtensionType::CC_REQUEST => {
103
49
                if len != 0 {
104
                    return Err(tor_bytes::Error::InvalidMessage(
105
                        "invalid length for RequestCongestionControl".into(),
106
                    ));
107
49
                }
108
49
                NtorV3Extension::RequestCongestionControl
109
            }
110
            NtorV3ExtensionType::CC_RESPONSE => {
111
49
                if len != 1 {
112
                    return Err(tor_bytes::Error::InvalidMessage(
113
                        "invalid length for AckCongestionControl".into(),
114
                    ));
115
49
                }
116
49
                let sendme_inc = reader.take_u8()?;
117
49
                NtorV3Extension::AckCongestionControl { sendme_inc }
118
            }
119
            x => {
120
                let mut data = vec![0; len as usize];
121
                reader.take_into(&mut data)?;
122
                NtorV3Extension::Unrecognized {
123
                    field_type: x,
124
                    data,
125
                }
126
            }
127
        })
128
98
    }
129
}