1
//! Types and encodings used during circuit extension.
2

            
3
use super::extlist::{decl_extension_group, Ext, ExtList, ExtListRef};
4
#[cfg(feature = "hs")]
5
use super::hs::pow::ProofOfWork;
6
use caret::caret_int;
7
use tor_bytes::{EncodeResult, Reader, Writeable as _, Writer};
8

            
9
caret_int! {
    /// A type of circuit request extension data (`EXT_FIELD_TYPE`).
    ///
    /// Identifies which kind of extension appears in a [`CircRequestExt`].
    #[derive(PartialOrd,Ord)]
    pub struct CircRequestExtType(u8) {
        /// Request congestion control be enabled for a circuit.
        CC_REQUEST = 1,
        /// HS only: provide a completed proof-of-work solution for denial-of-service
        /// mitigation.
        PROOF_OF_WORK = 2,
        /// Request that certain subprotocol features be enabled.
        SUBPROTOCOL_REQUEST = 3,
    }
}
22

            
23
caret_int! {
    /// A type of circuit response extension data (`EXT_FIELD_TYPE`).
    ///
    /// Identifies which kind of extension appears in a [`CircResponseExt`].
    /// Note that this numbering is separate from [`CircRequestExtType`]:
    /// the same numeric value can mean different things in each direction.
    #[derive(PartialOrd,Ord)]
    pub struct CircResponseExtType(u8) {
        /// Acknowledge a congestion control request.
        CC_RESPONSE = 2
    }
}
31

            
32
/// Request congestion control be enabled for this circuit (client → exit node).
33
///
34
/// (`EXT_FIELD_TYPE` = 01)
35
#[derive(Clone, Debug, PartialEq, Eq, Default)]
36
#[non_exhaustive]
37
pub struct CcRequest {}
38

            
39
impl Ext for CcRequest {
40
    type Id = CircRequestExtType;
41
275
    fn type_id(&self) -> Self::Id {
42
275
        CircRequestExtType::CC_REQUEST
43
275
    }
44
275
    fn take_body_from(_b: &mut Reader<'_>) -> tor_bytes::Result<Self> {
45
275
        Ok(Self {})
46
275
    }
47
10
    fn write_body_onto<B: Writer + ?Sized>(&self, _b: &mut B) -> EncodeResult<()> {
48
10
        Ok(())
49
10
    }
50
}
51

            
52
/// Acknowledge a congestion control request (exit node → client).
53
///
54
/// (`EXT_FIELD_TYPE` = 02)
55
#[derive(Clone, Debug, PartialEq, Eq)]
56
pub struct CcResponse {
57
    /// The exit's current view of the `cc_sendme_inc` consensus parameter.
58
    sendme_inc: u8,
59
}
60

            
61
impl CcResponse {
62
    /// Create a new AckCongestionControl with a given value for the
63
    /// `sendme_inc` parameter.
64
275
    pub fn new(sendme_inc: u8) -> Self {
65
275
        CcResponse { sendme_inc }
66
275
    }
67

            
68
    /// Return the value of the `sendme_inc` parameter for this extension.
69
220
    pub fn sendme_inc(&self) -> u8 {
70
220
        self.sendme_inc
71
220
    }
72
}
73

            
74
impl Ext for CcResponse {
75
    type Id = CircResponseExtType;
76
275
    fn type_id(&self) -> Self::Id {
77
275
        CircResponseExtType::CC_RESPONSE
78
275
    }
79

            
80
275
    fn take_body_from(b: &mut Reader<'_>) -> tor_bytes::Result<Self> {
81
275
        let sendme_inc = b.take_u8()?;
82
275
        Ok(Self { sendme_inc })
83
275
    }
84

            
85
10
    fn write_body_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
86
10
        b.write_u8(self.sendme_inc);
87
10
        Ok(())
88
10
    }
89
}
90

            
91
/// A request that a certain set of protocols should be enabled. (client to server)
92
#[derive(Clone, Debug, PartialEq, Eq)]
93
pub struct SubprotocolRequest {
94
    /// The protocols to enable.
95
    protocols: Vec<tor_protover::NumberedSubver>,
96
}
97

            
98
impl<A> FromIterator<A> for SubprotocolRequest
99
where
100
    A: Into<tor_protover::NumberedSubver>,
101
{
102
4
    fn from_iter<T: IntoIterator<Item = A>>(iter: T) -> Self {
103
4
        let mut protocols: Vec<_> = iter.into_iter().map(Into::into).collect();
104
4
        protocols.sort();
105
4
        protocols.dedup();
106
4
        Self { protocols }
107
4
    }
108
}
109

            
110
impl Ext for SubprotocolRequest {
111
    type Id = CircRequestExtType;
112

            
113
    fn type_id(&self) -> Self::Id {
114
        CircRequestExtType::SUBPROTOCOL_REQUEST
115
    }
116

            
117
8
    fn take_body_from(b: &mut Reader<'_>) -> tor_bytes::Result<Self> {
118
8
        let mut protocols = Vec::new();
119
22
        while b.remaining() != 0 {
120
16
            protocols.push(b.extract()?);
121
        }
122
6
        let protocols_orig = protocols.clone();
123
6
        // TODO MSRV 1.82: Use is_sorted, and avoid creating protocols_orig.
124
6
        protocols.sort();
125
6
        protocols.dedup();
126
6
        if protocols_orig != protocols {
127
4
            return Err(tor_bytes::Error::InvalidMessage(
128
4
                "SubprotocolRequest not sorted and deduplicated.".into(),
129
4
            ));
130
2
        }
131
2
        Ok(Self { protocols })
132
8
    }
133

            
134
2
    fn write_body_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
135
4
        for p in self.protocols.iter() {
136
4
            b.write(p)?;
137
        }
138
2
        Ok(())
139
2
    }
140
}
141
impl SubprotocolRequest {
142
    /// Return true if this [`SubprotocolRequest`] contains the listed capability.
143
4
    pub fn contains(&self, cap: tor_protover::NamedSubver) -> bool {
144
4
        self.protocols.binary_search(&cap.into()).is_ok()
145
4
    }
146

            
147
    /// Return true if this [`SubprotocolRequest`] contains no other
148
    /// capabilities except those listed in `list`.
149
10
    pub fn contains_only(&self, list: &tor_protover::Protocols) -> bool {
150
10
        self.protocols
151
10
            .iter()
152
23
            .all(|p| list.supports_numbered_subver(*p))
153
10
    }
154
}
155

            
156
decl_extension_group! {
    /// An extension to be sent along with a circuit extension request
    /// (CREATE2, EXTEND2, or INTRODUCE.)
    ///
    /// Each variant wraps the extension type of the same name. The extension
    /// is keyed by a [`CircRequestExtType`] value.
    #[derive(Debug,Clone,PartialEq)]
    #[non_exhaustive]
    pub enum CircRequestExt [ CircRequestExtType ] {
        /// Request to enable congestion control.
        CcRequest,
        /// HS-only: Provide a proof-of-work solution.
        /// (Only present when the `hs` feature is enabled.)
        [ feature: #[cfg(feature = "hs")] ]
        ProofOfWork,
        /// Request to enable one or more subprotocol capabilities.
        SubprotocolRequest,
    }
}
171

            
172
decl_extension_group! {
    /// An extension to be sent along with a circuit extension response
    /// (CREATED2 or EXTENDED2.)
    ///
    /// Each variant wraps the extension type of the same name. The extension
    /// is keyed by a [`CircResponseExtType`] value.
    ///
    /// RENDEZVOUS is not currently supported, but once we replace hs-ntor
    /// with something better, extensions will be possible there too.
    #[derive(Debug,Clone,PartialEq)]
    #[non_exhaustive]
    pub enum CircResponseExt [ CircResponseExtType ] {
        /// Response indicating that congestion control is enabled.
        CcResponse,
    }
}
185

            
186
macro_rules! impl_encode_decode {
187
    ($extgroup:ty, $name:expr) => {
188
        impl $extgroup {
189
            /// Encode a set of extensions into a "message" for a circuit handshake.
190
56
            pub fn write_many_onto<W: Writer>(exts: &[Self], out: &mut W) -> EncodeResult<()> {
191
56
                ExtListRef::from(exts).write_onto(out)?;
192
56
                Ok(())
193
56
            }
194
            /// Decode a slice of bytes representing the "message" of a circuit handshake into a set of
195
            /// extensions.
196
1540
            pub fn decode(message: &[u8]) -> crate::Result<Vec<Self>> {
197
1540
                let err_cvt = |err| crate::Error::BytesErr { err, parsed: $name };
198
1540
                let mut r = tor_bytes::Reader::from_slice(message);
199
1540
                let list: ExtList<_> = r.extract().map_err(err_cvt)?;
200
1540
                r.should_be_exhausted().map_err(err_cvt)?;
201
1540
                Ok(list.into_vec())
202
1540
            }
203
        }
204
    };
205
}
206

            
207
impl_encode_decode!(CircRequestExt, "CREATE2 extension list");
208
impl_encode_decode!(CircResponseExt, "CREATED2 extension list");
209

            
210
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;

    /// A request collected from an iterator (with a duplicate) encodes to the
    /// expected bytes and decodes back to an equal request.
    #[test]
    fn subproto_ext_valid() {
        use tor_protover::named::*;
        // RELAY_NTORV3 appears twice; collecting sorts and deduplicates.
        let sp: SubprotocolRequest = [RELAY_NTORV3, RELAY_NTORV3, LINK_V4].into_iter().collect();
        let mut v = Vec::new();
        sp.write_body_onto(&mut v).unwrap();
        // Two capabilities remain, two bytes each, written in sorted order.
        assert_eq!(&v[..], [0, 4, 2, 4]);

        // Round trip: decoding the encoded body yields an identical request.
        let mut r = Reader::from_slice(&v[..]);
        let sp2: SubprotocolRequest = SubprotocolRequest::take_body_from(&mut r).unwrap();
        assert_eq!(sp, sp2);
    }

    /// Malformed request bodies are rejected with a descriptive error.
    #[test]
    fn subproto_invalid() {
        // Odd length: the body ends partway through the second two-byte capability.
        let mut r = Reader::from_slice(&[0, 4, 2]);
        let e = SubprotocolRequest::take_body_from(&mut r).unwrap_err();
        dbg!(e.to_string());
        assert!(e.to_string().contains("too short"));

        // Duplicate protocols.
        let mut r = Reader::from_slice(&[0, 4, 0, 4]);
        let e = SubprotocolRequest::take_body_from(&mut r).unwrap_err();
        dbg!(e.to_string());
        assert!(e.to_string().contains("deduplicated"));

        // not-sorted protocols.
        let mut r = Reader::from_slice(&[2, 4, 0, 4]);
        let e = SubprotocolRequest::take_body_from(&mut r).unwrap_err();
        dbg!(e.to_string());
        assert!(e.to_string().contains("sorted"));
    }

    /// Membership (`contains`) and subset (`contains_only`) queries.
    #[test]
    fn subproto_supported() {
        use tor_protover::named::*;
        let sp: SubprotocolRequest = [RELAY_NTORV3, RELAY_NTORV3, LINK_V4].into_iter().collect();
        // "contains" tells us if a subprotocol capability is a member of the request.
        assert!(sp.contains(LINK_V4));
        assert!(!sp.contains(LINK_V2));

        // contains_only tells us if there are any subprotocol capabilities in the request
        // other than those listed.
        assert!(sp.contains_only(&[RELAY_NTORV3, LINK_V4, CONFLUX_BASE].into_iter().collect()));
        assert!(sp.contains_only(&[RELAY_NTORV3, LINK_V4].into_iter().collect()));
        assert!(!sp.contains_only(&[LINK_V4].into_iter().collect()));
        assert!(!sp.contains_only(&[LINK_V4, CONFLUX_BASE].into_iter().collect()));
        assert!(!sp.contains_only(&[CONFLUX_BASE].into_iter().collect()));
    }
}