//! Helpers to manage lists of extensions within relay messages.
//!
//! These are used widely throughout the onion service ("HS") code,
//! but also in the ntor-v3 handshake.

use derive_deftly::Deftly;
7
use tor_bytes::{EncodeError, EncodeResult, Readable, Reader, Result, Writeable, Writer};
8
use tor_memquota::{HasMemoryCostStructural, derive_deftly_template_HasMemoryCost};
9

            
10
/// A list of extensions, represented in a common format used by many messages.
11
///
12
/// The common format is:
13
/// ```text
14
///      N_EXTENSIONS     [1 byte]
15
///      N_EXTENSIONS times:
16
///           EXT_FIELD_TYPE [1 byte]
17
///           EXT_FIELD_LEN  [1 byte]
18
///           EXT_FIELD      [EXT_FIELD_LEN bytes]
19
/// ```
20
///
21
/// It is subject to the additional restraints:
22
///
23
/// * Each extension type SHOULD be sent only once in a message.
24
/// * Parties MUST ignore any occurrences all occurrences of an extension
25
///   with a given type after the first such occurrence.
26
/// * Extensions SHOULD be sent in numerically ascending order by type.
27
#[derive(Clone, Debug, derive_more::Deref, derive_more::DerefMut, Deftly)]
28
#[derive_deftly(HasMemoryCost)]
29
#[deftly(has_memory_cost(bounds = "T: HasMemoryCostStructural"))]
30
pub(super) struct ExtList<T> {
31
    /// The extensions themselves.
32
    pub(super) extensions: Vec<T>,
33
}
34
impl<T> Default for ExtList<T> {
35
456
    fn default() -> Self {
36
456
        Self {
37
456
            extensions: Vec::new(),
38
456
        }
39
456
    }
40
}
41

            
42
/// As ExtList, but held by reference.
43
#[derive(Clone, Debug, derive_more::Deref, derive_more::DerefMut, derive_more::From)]
44
pub(super) struct ExtListRef<'a, T> {
45
    /// A reference to a slice of extensions.
46
    extensions: &'a [T],
47
}
48

            
49
/// A kind of extension that can be used with some kind of relay message.
50
///
51
/// Each extendible message will likely define its own enum,
52
/// implementing this trait,
53
/// representing the possible extensions.
54
pub(super) trait ExtGroup: Readable + Writeable {
55
    /// An identifier kind used with this sort of extension
56
    type Id: From<u8> + Into<u8> + Eq + PartialEq + Ord + Copy;
57
    /// The field-type id for this particular extension.
58
    fn type_id(&self) -> Self::Id;
59
}
60
/// A single typed extension that can be used with some kind of relay message.
61
pub(super) trait Ext: Sized {
62
    /// An identifier kind used with this sort of extension.
63
    ///
64
    /// Typically defined with caret_int.
65
    type Id: From<u8> + Into<u8>;
66
    /// The field-type id for this particular extension.
67
    fn type_id(&self) -> Self::Id;
68
    /// Extract the body (not the type or the length) from a single
69
    /// extension.
70
    fn take_body_from(b: &mut Reader<'_>) -> Result<Self>;
71
    /// Write the body (not the type or the length) for a single extension.
72
    fn write_body_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()>;
73
}
74
impl<T: ExtGroup> Readable for ExtList<T> {
75
2850
    fn take_from(b: &mut Reader<'_>) -> Result<Self> {
76
2850
        let n_extensions = b.take_u8()?;
77
2850
        let extensions: Result<Vec<T>> = (0..n_extensions).map(|_| b.extract::<T>()).collect();
78
        Ok(Self {
79
2850
            extensions: extensions?,
80
        })
81
2850
    }
82
}
83
impl<'a, T: ExtGroup> Writeable for ExtListRef<'a, T> {
84
110
    fn write_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
85
110
        let n_extensions = self
86
110
            .extensions
87
110
            .len()
88
110
            .try_into()
89
110
            .map_err(|_| EncodeError::BadLengthValue)?;
90
110
        b.write_u8(n_extensions);
91
110
        let mut exts_sorted: Vec<&T> = self.extensions.iter().collect();
92
110
        exts_sorted.sort_by_key(|ext| ext.type_id());
93
110
        exts_sorted.iter().try_for_each(|ext| ext.write_onto(b))?;
94
110
        Ok(())
95
110
    }
96
}
97
impl<T: ExtGroup> Writeable for ExtList<T> {
98
30
    fn write_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
99
30
        ExtListRef::from(&self.extensions[..]).write_onto(b)
100
30
    }
101
}
102
impl<T: ExtGroup> ExtList<T> {
103
    /// Insert `ext` into this list of extensions, replacing any previous
104
    /// extension with the same field type ID.
105
    #[cfg(feature = "hs")] // currently, only used when "hs' is enabled.
106
171
    pub(super) fn replace_by_type(&mut self, ext: T) {
107
171
        self.retain(|e| e.type_id() != ext.type_id());
108
171
        self.push(ext);
109
171
    }
110
    /// Consume this ExtList and return its members as a vector.
111
2280
    pub(super) fn into_vec(self) -> Vec<T> {
112
2280
        self.extensions
113
2280
    }
114
}
115

            
116
/// An unrecognized or unencoded extension for some relay message.
117
#[derive(Clone, Debug, Deftly, Eq, PartialEq)]
118
#[derive_deftly(HasMemoryCost)]
119
// Use `Copy + 'static` and `#[deftly(has_memory_cost(copy))]` so that we don't
120
// need to derive HasMemoryCost for the id types, which are indeed all Copy.
121
#[deftly(has_memory_cost(bounds = "ID: Copy + 'static"))]
122
pub struct UnrecognizedExt<ID> {
123
    /// The field type ID for this extension.
124
    #[deftly(has_memory_cost(copy))]
125
    pub(super) type_id: ID,
126
    /// The body of this extension.
127
    pub(super) body: Vec<u8>,
128
}
129

            
130
impl<ID> UnrecognizedExt<ID> {
131
    /// Return a new unrecognized extension with a given ID and body.
132
    ///
133
    /// NOTE: nothing actually enforces that this type ID is not
134
    /// recognized.
135
    ///
136
    /// NOTE: This function accepts bodies longer than 255 bytes, but
137
    /// it is not possible to encode them.
138
2
    pub fn new(type_id: ID, body: impl Into<Vec<u8>>) -> Self {
139
2
        Self {
140
2
            type_id,
141
2
            body: body.into(),
142
2
        }
143
2
    }
144
}
145

            
146
/// Declare an Extension group that takes a given identifier.
147
//
148
// TODO: This is rather similar to restrict_msg(), isn't it?  Also, We use this
149
// pattern of (number, (cmd, length, body)*) a few of times in Tor outside the relaycell
150
// module.  Perhaps we can extend and unify our code here...
151
macro_rules! decl_extension_group {
152
    {
153
        $( #[$meta:meta] )*
154
        $v:vis enum $id:ident [ $type_id:ty ] {
155
            $(
156
                $(#[$cmeta:meta])*
157
                $([feature: #[$fmeta:meta]])?
158
                $case:ident),*
159
            $(,)?
160
        }
161
    } => {paste::paste!{
162
        $( #[$meta] )*
163
57
        $v enum $id {
164
57
            $( $(#[$cmeta])*
165
57
               $( #[$fmeta] )?
166
57
               $case($case),
167
57
            )*
168
57
            /// An extension of a type we do not recognize, or which we have not
169
57
            /// encoded.
170
57
            Unrecognized(crate::relaycell::extlist::UnrecognizedExt<$type_id>)
171
57
        }
172
57
        impl tor_bytes::Readable for $id {
173
1083
            fn take_from(b: &mut Reader<'_>) -> tor_bytes::Result<Self> {
174
57
                #[allow(unused)]
175
57
                use crate::relaycell::extlist::Ext as _;
176
1083
                let type_id = b.take_u8()?.into();
177
1083
                Ok(match type_id {
178
57
                    $(
179
57
                        $( #[$fmeta] )?
180
57
                        $type_id::[< $case:snake:upper >] => {
181
928
                            Self::$case( b.read_nested_u8len(|r| $case::take_body_from(r))? )
182
57
                        }
183
57
                    )*
184
57
                    _ => {
185
57
                        Self::Unrecognized(crate::relaycell::extlist::UnrecognizedExt {
186
171
                            type_id,
187
174
                            body: b.read_nested_u8len(|r| Ok(r.take_rest().into()))?,
188
57
                        })
189
57
                    }
190
57
                })
191
1083
            }
192
57
        }
193
57
        impl tor_bytes::Writeable for $id {
194
99
            fn write_onto<B: Writer + ?Sized>(&self, b: &mut B) -> tor_bytes::EncodeResult<
195
53
()> {
196
57
                #![allow(unused_imports)]
197
57
                use crate::relaycell::extlist::Ext as _;
198
57
                use tor_bytes::Writeable as _;
199
57
                use std::ops::DerefMut;
200
99
                match self {
201
57
                    $(
202
57
                        $( #[$fmeta] )?
203
91
                        Self::$case(val) => {
204
91
                            b.write_u8(val.type_id().into());
205
91
                            let mut nested = b.write_nested_u8len();
206
91
                            val.write_body_onto(nested.deref_mut())?;
207
91
                            nested.finish()?;
208
57
                        }
209
57
                    )*
210
64
                    Self::Unrecognized(unrecognized) => {
211
64
                        b.write_u8(unrecognized.type_id.into());
212
64
                        let mut nested = b.write_nested_u8len();
213
64
                        nested.write_all(&unrecognized.body[..]);
214
64
                        nested.finish()?;
215
57
                    }
216
57
                }
217
99
                Ok(())
218
99
            }
219
57
        }
220
57
        impl crate::relaycell::extlist::ExtGroup for $id {
221
57
            type Id = $type_id;
222
456
            fn type_id(&self) -> Self::Id {
223
57
                #![allow(unused_imports)]
224
57
                use crate::relaycell::extlist::Ext as _;
225
456
                match self {
226
57
                    $(
227
57
                        $( #[$fmeta] )?
228
171
                        Self::$case(val) => val.type_id(),
229
57
                    )*
230
285
                    Self::Unrecognized(unrecognized) => unrecognized.type_id,
231
57
                }
232
456
            }
233
57
        }
234
57
        $(
235
57
        $( #[$fmeta] )?
236
57
        impl From<$case> for $id {
237
114
            fn from(val: $case) -> $id {
238
114
                $id :: $case ( val )
239
114
            }
240
57
        }
241
57
        )*
242
57
        impl From<crate::relaycell::extlist::UnrecognizedExt<$type_id>> for $id {
243
57
            fn from(val: crate::relaycell::extlist::UnrecognizedExt<$type_id>) -> $id {
244
57
                $id :: Unrecognized(val)
245
57
            }
246
        }
247
}}
248
}
249
pub(super) use decl_extension_group;