tor_cell/relaycell/
extend.rs
        
        
        
        1use super::extlist::{Ext, ExtList, ExtListRef, decl_extension_group};
4#[cfg(feature = "hs")]
5use super::hs::pow::ProofOfWork;
6use caret::caret_int;
7use itertools::Itertools as _;
8use tor_bytes::{EncodeResult, Reader, Writeable as _, Writer};
9use tor_protover::NumberedSubver;
10
11caret_int! {
12    #[derive(PartialOrd,Ord)]
14    pub struct CircRequestExtType(u8) {
15        CC_REQUEST = 1,
17        PROOF_OF_WORK = 2,
20        SUBPROTOCOL_REQUEST = 3,
22    }
23}
24
25caret_int! {
26    #[derive(PartialOrd,Ord)]
28    pub struct CircResponseExtType(u8) {
29        CC_RESPONSE = 2
31    }
32}
33
/// The body of a `CC_REQUEST` circuit-request extension.
///
/// This extension has an empty body: its encoding neither reads nor writes
/// any bytes.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[non_exhaustive]
pub struct CcRequest {}
40
41impl Ext for CcRequest {
42    type Id = CircRequestExtType;
43    fn type_id(&self) -> Self::Id {
44        CircRequestExtType::CC_REQUEST
45    }
46    fn take_body_from(_b: &mut Reader<'_>) -> tor_bytes::Result<Self> {
47        Ok(Self {})
48    }
49    fn write_body_onto<B: Writer + ?Sized>(&self, _b: &mut B) -> EncodeResult<()> {
50        Ok(())
51    }
52}
53
/// The body of a `CC_RESPONSE` circuit-response extension.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CcResponse {
    /// The `sendme_inc` value; encoded on the wire as a single byte.
    sendme_inc: u8,
}
62
63impl CcResponse {
64    pub fn new(sendme_inc: u8) -> Self {
67        CcResponse { sendme_inc }
68    }
69
70    pub fn sendme_inc(&self) -> u8 {
72        self.sendme_inc
73    }
74}
75
76impl Ext for CcResponse {
77    type Id = CircResponseExtType;
78    fn type_id(&self) -> Self::Id {
79        CircResponseExtType::CC_RESPONSE
80    }
81
82    fn take_body_from(b: &mut Reader<'_>) -> tor_bytes::Result<Self> {
83        let sendme_inc = b.take_u8()?;
84        Ok(Self { sendme_inc })
85    }
86
87    fn write_body_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
88        b.write_u8(self.sendme_inc);
89        Ok(())
90    }
91}
92
93#[derive(Clone, Debug, PartialEq, Eq)]
95pub struct SubprotocolRequest {
96    protocols: Vec<tor_protover::NumberedSubver>,
98}
99
100impl<A> FromIterator<A> for SubprotocolRequest
101where
102    A: Into<tor_protover::NumberedSubver>,
103{
104    fn from_iter<T: IntoIterator<Item = A>>(iter: T) -> Self {
105        let mut protocols: Vec<_> = iter.into_iter().map(Into::into).collect();
106        protocols.sort();
107        protocols.dedup();
108        Self { protocols }
109    }
110}
111
112impl Ext for SubprotocolRequest {
113    type Id = CircRequestExtType;
114
115    fn type_id(&self) -> Self::Id {
116        CircRequestExtType::SUBPROTOCOL_REQUEST
117    }
118
119    fn take_body_from(b: &mut Reader<'_>) -> tor_bytes::Result<Self> {
120        let mut protocols = Vec::new();
121        while b.remaining() != 0 {
122            protocols.push(b.extract()?);
123        }
124
125        if !is_strictly_ascending(&protocols) {
126            return Err(tor_bytes::Error::InvalidMessage(
127                "SubprotocolRequest not sorted and deduplicated.".into(),
128            ));
129        }
130
131        Ok(Self { protocols })
132    }
133
134    fn write_body_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
135        for p in self.protocols.iter() {
136            b.write(p)?;
137        }
138        Ok(())
139    }
140}
141impl SubprotocolRequest {
142    pub fn contains(&self, cap: tor_protover::NamedSubver) -> bool {
144        self.protocols.binary_search(&cap.into()).is_ok()
145    }
146
147    pub fn contains_only(&self, list: &tor_protover::Protocols) -> bool {
150        self.protocols
151            .iter()
152            .all(|p| list.supports_numbered_subver(*p))
153    }
154}
155
156decl_extension_group! {
157    #[derive(Debug,Clone,PartialEq)]
160    #[non_exhaustive]
161    pub enum CircRequestExt [ CircRequestExtType ] {
162        CcRequest,
164        [ feature: #[cfg(feature = "hs")] ]
166        ProofOfWork,
167        SubprotocolRequest,
169    }
170}
171
172decl_extension_group! {
173    #[derive(Debug,Clone,PartialEq)]
179    #[non_exhaustive]
180    pub enum CircResponseExt [ CircResponseExtType ] {
181        CcResponse,
183    }
184}
185
186macro_rules! impl_encode_decode {
187    ($extgroup:ty, $name:expr) => {
188        impl $extgroup {
189            pub fn write_many_onto<W: Writer>(exts: &[Self], out: &mut W) -> EncodeResult<()> {
191                ExtListRef::from(exts).write_onto(out)?;
192                Ok(())
193            }
194            pub fn decode(message: &[u8]) -> crate::Result<Vec<Self>> {
197                let err_cvt = |err| crate::Error::BytesErr { err, parsed: $name };
198                let mut r = tor_bytes::Reader::from_slice(message);
199                let list: ExtList<_> = r.extract().map_err(err_cvt)?;
200                r.should_be_exhausted().map_err(err_cvt)?;
201                Ok(list.into_vec())
202            }
203        }
204    };
205}
206
207impl_encode_decode!(CircRequestExt, "CREATE2 extension list");
208impl_encode_decode!(CircResponseExt, "CREATED2 extension list");
209
210fn is_strictly_ascending(vers: &[NumberedSubver]) -> bool {
212    vers.iter().tuple_windows().all(|(a, b)| a < b)
214}
215
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;

    /// Decode `bytes` as a `SubprotocolRequest` body, expecting failure, and
    /// return the error's display text (also printed for debugging).
    fn parse_err(bytes: &[u8]) -> String {
        let mut reader = Reader::from_slice(bytes);
        let message = SubprotocolRequest::take_body_from(&mut reader)
            .unwrap_err()
            .to_string();
        dbg!(&message);
        message
    }

    #[test]
    fn subproto_ext_valid() {
        use tor_protover::named::*;
        let request: SubprotocolRequest = [RELAY_NTORV3, RELAY_NTORV3, LINK_V4]
            .into_iter()
            .collect();

        let mut encoded = Vec::<u8>::new();
        request.write_body_onto(&mut encoded).unwrap();
        // The duplicate is dropped and entries come out in sorted order.
        assert_eq!(&encoded[..], [0, 4, 2, 4]);

        let mut reader = Reader::from_slice(&encoded[..]);
        let decoded: SubprotocolRequest = SubprotocolRequest::take_body_from(&mut reader).unwrap();
        assert_eq!(request, decoded);
    }

    #[test]
    fn subproto_invalid() {
        // A trailing partial entry.
        assert!(parse_err(&[0, 4, 2]).contains("too short"));
        // A repeated entry.
        assert!(parse_err(&[0, 4, 0, 4]).contains("deduplicated"));
        // Entries out of order.
        assert!(parse_err(&[2, 4, 0, 4]).contains("sorted"));
    }

    #[test]
    fn subproto_supported() {
        use tor_protover::named::*;
        let request: SubprotocolRequest = [RELAY_NTORV3, RELAY_NTORV3, LINK_V4]
            .into_iter()
            .collect();
        assert!(request.contains(LINK_V4));
        assert!(!request.contains(LINK_V2));

        // (capabilities offered, whether `request` fits entirely within them)
        let cases = [
            (vec![RELAY_NTORV3, LINK_V4, CONFLUX_BASE], true),
            (vec![RELAY_NTORV3, LINK_V4], true),
            (vec![LINK_V4], false),
            (vec![LINK_V4, CONFLUX_BASE], false),
            (vec![CONFLUX_BASE], false),
        ];
        for (idx, (caps, expected)) in cases.into_iter().enumerate() {
            let list: tor_protover::Protocols = caps.into_iter().collect();
            assert_eq!(request.contains_only(&list), expected, "case {idx}");
        }
    }
}