// tor_cell/relaycell/extend.rs
1use super::extlist::{decl_extension_group, Ext, ExtList, ExtListRef};
4#[cfg(feature = "hs")]
5use super::hs::pow::ProofOfWork;
6use caret::caret_int;
7use tor_bytes::{EncodeResult, Reader, Writeable as _, Writer};
8
9caret_int! {
10 #[derive(PartialOrd,Ord)]
12 pub struct CircRequestExtType(u8) {
13 CC_REQUEST = 1,
15 PROOF_OF_WORK = 2,
18 SUBPROTOCOL_REQUEST = 3,
20 }
21}
22
23caret_int! {
24 #[derive(PartialOrd,Ord)]
26 pub struct CircResponseExtType(u8) {
27 CC_RESPONSE = 2
29 }
30}
31
/// The `CC_REQUEST` circuit-request extension.
///
/// It has no fields: its encoded body is empty.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct CcRequest {}
38
39impl Ext for CcRequest {
40 type Id = CircRequestExtType;
41 fn type_id(&self) -> Self::Id {
42 CircRequestExtType::CC_REQUEST
43 }
44 fn take_body_from(_b: &mut Reader<'_>) -> tor_bytes::Result<Self> {
45 Ok(Self {})
46 }
47 fn write_body_onto<B: Writer + ?Sized>(&self, _b: &mut B) -> EncodeResult<()> {
48 Ok(())
49 }
50}
51
/// The `CC_RESPONSE` circuit-response extension.
///
/// Its encoded body is exactly one byte: the `sendme_inc` value.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CcResponse {
    /// The single byte carried in this extension's body.
    sendme_inc: u8,
}
60
61impl CcResponse {
62 pub fn new(sendme_inc: u8) -> Self {
65 CcResponse { sendme_inc }
66 }
67
68 pub fn sendme_inc(&self) -> u8 {
70 self.sendme_inc
71 }
72}
73
74impl Ext for CcResponse {
75 type Id = CircResponseExtType;
76 fn type_id(&self) -> Self::Id {
77 CircResponseExtType::CC_RESPONSE
78 }
79
80 fn take_body_from(b: &mut Reader<'_>) -> tor_bytes::Result<Self> {
81 let sendme_inc = b.take_u8()?;
82 Ok(Self { sendme_inc })
83 }
84
85 fn write_body_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
86 b.write_u8(self.sendme_inc);
87 Ok(())
88 }
89}
90
91#[derive(Clone, Debug, PartialEq, Eq)]
93pub struct SubprotocolRequest {
94 protocols: Vec<tor_protover::NumberedSubver>,
96}
97
98impl<A> FromIterator<A> for SubprotocolRequest
99where
100 A: Into<tor_protover::NumberedSubver>,
101{
102 fn from_iter<T: IntoIterator<Item = A>>(iter: T) -> Self {
103 let mut protocols: Vec<_> = iter.into_iter().map(Into::into).collect();
104 protocols.sort();
105 protocols.dedup();
106 Self { protocols }
107 }
108}
109
110impl Ext for SubprotocolRequest {
111 type Id = CircRequestExtType;
112
113 fn type_id(&self) -> Self::Id {
114 CircRequestExtType::SUBPROTOCOL_REQUEST
115 }
116
117 fn take_body_from(b: &mut Reader<'_>) -> tor_bytes::Result<Self> {
118 let mut protocols = Vec::new();
119 while b.remaining() != 0 {
120 protocols.push(b.extract()?);
121 }
122 let protocols_orig = protocols.clone();
123 protocols.sort();
125 protocols.dedup();
126 if protocols_orig != protocols {
127 return Err(tor_bytes::Error::InvalidMessage(
128 "SubprotocolRequest not sorted and deduplicated.".into(),
129 ));
130 }
131 Ok(Self { protocols })
132 }
133
134 fn write_body_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
135 for p in self.protocols.iter() {
136 b.write(p)?;
137 }
138 Ok(())
139 }
140}
141impl SubprotocolRequest {
142 pub fn contains(&self, cap: tor_protover::NamedSubver) -> bool {
144 self.protocols.binary_search(&cap.into()).is_ok()
145 }
146
147 pub fn contains_only(&self, list: &tor_protover::Protocols) -> bool {
150 self.protocols
151 .iter()
152 .all(|p| list.supports_numbered_subver(*p))
153 }
154}
155
156decl_extension_group! {
157 #[derive(Debug,Clone,PartialEq)]
160 #[non_exhaustive]
161 pub enum CircRequestExt [ CircRequestExtType ] {
162 CcRequest,
164 [ feature: #[cfg(feature = "hs")] ]
166 ProofOfWork,
167 SubprotocolRequest,
169 }
170}
171
172decl_extension_group! {
173 #[derive(Debug,Clone,PartialEq)]
179 #[non_exhaustive]
180 pub enum CircResponseExt [ CircResponseExtType ] {
181 CcResponse,
183 }
184}
185
186macro_rules! impl_encode_decode {
187 ($extgroup:ty, $name:expr) => {
188 impl $extgroup {
189 pub fn write_many_onto<W: Writer>(exts: &[Self], out: &mut W) -> EncodeResult<()> {
191 ExtListRef::from(exts).write_onto(out)?;
192 Ok(())
193 }
194 pub fn decode(message: &[u8]) -> crate::Result<Vec<Self>> {
197 let err_cvt = |err| crate::Error::BytesErr { err, parsed: $name };
198 let mut r = tor_bytes::Reader::from_slice(message);
199 let list: ExtList<_> = r.extract().map_err(err_cvt)?;
200 r.should_be_exhausted().map_err(err_cvt)?;
201 Ok(list.into_vec())
202 }
203 }
204 };
205}
206
207impl_encode_decode!(CircRequestExt, "CREATE2 extension list");
208impl_encode_decode!(CircResponseExt, "CREATED2 extension list");
209
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;

    /// Parse `body` as a `SubprotocolRequest`, require that parsing fails,
    /// and require that the error message mentions `needle`.
    fn assert_rejected(body: &[u8], needle: &str) {
        let mut reader = Reader::from_slice(body);
        let err = SubprotocolRequest::take_body_from(&mut reader).unwrap_err();
        let message = dbg!(err.to_string());
        assert!(message.contains(needle));
    }

    #[test]
    fn subproto_ext_valid() {
        use tor_protover::named::*;
        let request: SubprotocolRequest =
            [RELAY_NTORV3, RELAY_NTORV3, LINK_V4].into_iter().collect();

        let mut encoded = Vec::new();
        request.write_body_onto(&mut encoded).unwrap();
        // The repeated RELAY_NTORV3 is dropped; these bytes pin the canonical encoding.
        assert_eq!(&encoded[..], [0, 4, 2, 4]);

        let decoded =
            SubprotocolRequest::take_body_from(&mut Reader::from_slice(&encoded[..])).unwrap();
        assert_eq!(request, decoded);
    }

    #[test]
    fn subproto_invalid() {
        // Truncated final entry.
        assert_rejected(&[0, 4, 2], "too short");
        // Same entry twice.
        assert_rejected(&[0, 4, 0, 4], "deduplicated");
        // Entries out of order.
        assert_rejected(&[2, 4, 0, 4], "sorted");
    }

    #[test]
    fn subproto_supported() {
        use tor_protover::named::*;
        let request: SubprotocolRequest =
            [RELAY_NTORV3, RELAY_NTORV3, LINK_V4].into_iter().collect();
        assert!(request.contains(LINK_V4));
        assert!(!request.contains(LINK_V2));

        let only_within = |caps: Vec<tor_protover::NamedSubver>| {
            request.contains_only(&caps.into_iter().collect())
        };
        assert!(only_within(vec![RELAY_NTORV3, LINK_V4, CONFLUX_BASE]));
        assert!(only_within(vec![RELAY_NTORV3, LINK_V4]));
        assert!(!only_within(vec![LINK_V4]));
        assert!(!only_within(vec![LINK_V4, CONFLUX_BASE]));
        assert!(!only_within(vec![CONFLUX_BASE]));
    }
}