1
//! Types and encodings used during circuit extension.
2

            
3
use super::extlist::{Ext, ExtList, ExtListRef, decl_extension_group};
4
#[cfg(feature = "hs")]
5
use super::hs::pow::ProofOfWork;
6
use caret::caret_int;
7
use itertools::Itertools as _;
8
use tor_bytes::{EncodeResult, Reader, Writeable as _, Writer};
9
use tor_protover::NumberedSubver;
10

            
11
caret_int! {
    /// Identifier for a kind of circuit-request extension (`EXT_FIELD_TYPE`).
    #[derive(PartialOrd, Ord)]
    pub struct CircRequestExtType(u8) {
        /// Ask for congestion control to be enabled on the circuit.
        CC_REQUEST = 1,
        /// Onion-service only: carries a completed proof-of-work solution,
        /// used for denial-of-service mitigation.
        PROOF_OF_WORK = 2,
        /// Ask for particular subprotocol features to be enabled.
        SUBPROTOCOL_REQUEST = 3,
    }
}
24

            
25
caret_int! {
    /// Identifier for a kind of circuit-response extension (`EXT_FIELD_TYPE`).
    #[derive(PartialOrd, Ord)]
    pub struct CircResponseExtType(u8) {
        /// Acknowledgement of an earlier congestion-control request.
        CC_RESPONSE = 2
    }
}
33

            
34
/// Client-to-exit request asking that congestion control be enabled on this
/// circuit.
///
/// Identified by `EXT_FIELD_TYPE` = 01. The extension has no body.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[non_exhaustive]
pub struct CcRequest {}
40

            
41
impl Ext for CcRequest {
42
    type Id = CircRequestExtType;
43
399
    fn type_id(&self) -> Self::Id {
44
399
        CircRequestExtType::CC_REQUEST
45
399
    }
46
399
    fn take_body_from(_b: &mut Reader<'_>) -> tor_bytes::Result<Self> {
47
399
        Ok(Self {})
48
399
    }
49
14
    fn write_body_onto<B: Writer + ?Sized>(&self, _b: &mut B) -> EncodeResult<()> {
50
14
        Ok(())
51
14
    }
52
}
53

            
54
/// Acknowledge a congestion control request (exit node → client).
///
/// (`EXT_FIELD_TYPE` = 02)
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CcResponse {
    /// The exit's current view of the `cc_sendme_inc` consensus parameter.
    sendme_inc: u8,
}

impl CcResponse {
    /// Create a new [`CcResponse`] carrying the given value for the
    /// `sendme_inc` parameter.
    pub fn new(sendme_inc: u8) -> Self {
        CcResponse { sendme_inc }
    }

    /// Return the value of the `sendme_inc` parameter carried by this
    /// extension (the exit's view of `cc_sendme_inc`).
    pub fn sendme_inc(&self) -> u8 {
        self.sendme_inc
    }
}
75

            
76
impl Ext for CcResponse {
77
    type Id = CircResponseExtType;
78
399
    fn type_id(&self) -> Self::Id {
79
399
        CircResponseExtType::CC_RESPONSE
80
399
    }
81

            
82
399
    fn take_body_from(b: &mut Reader<'_>) -> tor_bytes::Result<Self> {
83
399
        let sendme_inc = b.take_u8()?;
84
399
        Ok(Self { sendme_inc })
85
399
    }
86

            
87
14
    fn write_body_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
88
14
        b.write_u8(self.sendme_inc);
89
14
        Ok(())
90
14
    }
91
}
92

            
93
/// A request that a certain set of protocols should be enabled. (client to server)
94
#[derive(Clone, Debug, PartialEq, Eq)]
95
pub struct SubprotocolRequest {
96
    /// The protocols to enable.
97
    protocols: Vec<tor_protover::NumberedSubver>,
98
}
99

            
100
impl<A> FromIterator<A> for SubprotocolRequest
101
where
102
    A: Into<tor_protover::NumberedSubver>,
103
{
104
4
    fn from_iter<T: IntoIterator<Item = A>>(iter: T) -> Self {
105
4
        let mut protocols: Vec<_> = iter.into_iter().map(Into::into).collect();
106
4
        protocols.sort();
107
4
        protocols.dedup();
108
4
        Self { protocols }
109
4
    }
110
}
111

            
112
impl Ext for SubprotocolRequest {
113
    type Id = CircRequestExtType;
114

            
115
    fn type_id(&self) -> Self::Id {
116
        CircRequestExtType::SUBPROTOCOL_REQUEST
117
    }
118

            
119
8
    fn take_body_from(b: &mut Reader<'_>) -> tor_bytes::Result<Self> {
120
8
        let mut protocols = Vec::new();
121
22
        while b.remaining() != 0 {
122
16
            protocols.push(b.extract()?);
123
        }
124

            
125
6
        if !is_strictly_ascending(&protocols) {
126
4
            return Err(tor_bytes::Error::InvalidMessage(
127
4
                "SubprotocolRequest not sorted and deduplicated.".into(),
128
4
            ));
129
2
        }
130

            
131
2
        Ok(Self { protocols })
132
8
    }
133

            
134
2
    fn write_body_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
135
4
        for p in self.protocols.iter() {
136
4
            b.write(p)?;
137
        }
138
2
        Ok(())
139
2
    }
140
}
141
impl SubprotocolRequest {
142
    /// Return true if this [`SubprotocolRequest`] contains the listed capability.
143
4
    pub fn contains(&self, cap: tor_protover::NamedSubver) -> bool {
144
4
        self.protocols.binary_search(&cap.into()).is_ok()
145
4
    }
146

            
147
    /// Return true if this [`SubprotocolRequest`] contains no other
148
    /// capabilities except those listed in `list`.
149
10
    pub fn contains_only(&self, list: &tor_protover::Protocols) -> bool {
150
10
        self.protocols
151
10
            .iter()
152
23
            .all(|p| list.supports_numbered_subver(*p))
153
10
    }
154
}
155

            
156
decl_extension_group! {
    /// An extension carried in a circuit extension request
    /// (CREATE2, EXTEND2, or INTRODUCE).
    #[derive(Debug, Clone, PartialEq)]
    #[non_exhaustive]
    pub enum CircRequestExt [ CircRequestExtType ] {
        /// Ask for congestion control to be enabled.
        CcRequest,
        /// HS-only: carry a proof-of-work solution.
        [ feature: #[cfg(feature = "hs")] ]
        ProofOfWork,
        /// Ask for one or more subprotocol capabilities to be enabled.
        SubprotocolRequest,
    }
}
171

            
172
decl_extension_group! {
    /// An extension carried in a circuit extension response
    /// (CREATED2 or EXTENDED2).
    ///
    /// RENDEZVOUS does not support extensions yet. Once hs-ntor is replaced
    /// with a better handshake, extensions should become possible there too.
    #[derive(Debug, Clone, PartialEq)]
    #[non_exhaustive]
    pub enum CircResponseExt [ CircResponseExtType ] {
        /// Confirmation that congestion control is enabled.
        CcResponse,
    }
}
185

            
186
macro_rules! impl_encode_decode {
187
    ($extgroup:ty, $name:expr) => {
188
        impl $extgroup {
189
            /// Encode a set of extensions into a "message" for a circuit handshake.
190
80
            pub fn write_many_onto<W: Writer>(exts: &[Self], out: &mut W) -> EncodeResult<()> {
191
80
                ExtListRef::from(exts).write_onto(out)?;
192
80
                Ok(())
193
80
            }
194
            /// Decode a slice of bytes representing the "message" of a circuit handshake into a set of
195
            /// extensions.
196
2280
            pub fn decode(message: &[u8]) -> crate::Result<Vec<Self>> {
197
2280
                let err_cvt = |err| crate::Error::BytesErr { err, parsed: $name };
198
2280
                let mut r = tor_bytes::Reader::from_slice(message);
199
2280
                let list: ExtList<_> = r.extract().map_err(err_cvt)?;
200
2280
                r.should_be_exhausted().map_err(err_cvt)?;
201
2280
                Ok(list.into_vec())
202
2280
            }
203
        }
204
    };
205
}
206

            
207
impl_encode_decode!(CircRequestExt, "CREATE2 extension list");
208
impl_encode_decode!(CircResponseExt, "CREATED2 extension list");
209

            
210
/// Return true iff the list of protocol capabilities is strictly ascending.
211
6
fn is_strictly_ascending(vers: &[NumberedSubver]) -> bool {
212
    // We don't use is_sorted, since that doesn't detect duplicates.
213
9
    vers.iter().tuple_windows().all(|(a, b)| a < b)
214
6
}
215

            
216
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;

    /// Decode `body` as a [`SubprotocolRequest`], which must fail, and return
    /// the error's text.
    fn decode_failure(body: &[u8]) -> String {
        let mut reader = Reader::from_slice(body);
        let err = SubprotocolRequest::take_body_from(&mut reader).unwrap_err();
        dbg!(err.to_string());
        err.to_string()
    }

    #[test]
    fn subproto_ext_valid() {
        use tor_protover::named::*;
        // The duplicate is dropped and the entries come out sorted.
        let original: SubprotocolRequest =
            [RELAY_NTORV3, RELAY_NTORV3, LINK_V4].into_iter().collect();

        let mut encoded = Vec::new();
        original.write_body_onto(&mut encoded).unwrap();
        assert_eq!(&encoded[..], [0, 4, 2, 4]);

        let mut reader = Reader::from_slice(&encoded[..]);
        let decoded: SubprotocolRequest = SubprotocolRequest::take_body_from(&mut reader).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn subproto_invalid() {
        // Odd length: the last entry is cut off.
        assert!(decode_failure(&[0, 4, 2]).contains("too short"));

        // The same capability listed twice.
        assert!(decode_failure(&[0, 4, 0, 4]).contains("deduplicated"));

        // Capabilities listed in descending order.
        assert!(decode_failure(&[2, 4, 0, 4]).contains("sorted"));
    }

    #[test]
    fn subproto_supported() {
        use tor_protover::named::*;
        let request: SubprotocolRequest =
            [RELAY_NTORV3, RELAY_NTORV3, LINK_V4].into_iter().collect();

        // `contains` reports whether one capability is part of the request.
        assert!(request.contains(LINK_V4));
        assert!(!request.contains(LINK_V2));

        // `contains_only` holds when the request asks for nothing outside the
        // given set.
        assert!(request.contains_only(&[RELAY_NTORV3, LINK_V4, CONFLUX_BASE].into_iter().collect()));
        assert!(request.contains_only(&[RELAY_NTORV3, LINK_V4].into_iter().collect()));
        assert!(!request.contains_only(&[LINK_V4].into_iter().collect()));
        assert!(!request.contains_only(&[LINK_V4, CONFLUX_BASE].into_iter().collect()));
        assert!(!request.contains_only(&[CONFLUX_BASE].into_iter().collect()));
    }
}