1
//! Helpers to manage lists of HS cell extensions.
//
// TODO: We might generalize this even more in the future to handle other
// similar lists in our cell protocol.
5

            
6
use derive_deftly::Deftly;
7
use tor_bytes::{EncodeError, EncodeResult, Readable, Reader, Result, Writeable, Writer};
8
use tor_memquota::{derive_deftly_template_HasMemoryCost, HasMemoryCostStructural};
9

            
10
/// A list of extensions, represented in a common format used by many HS-related
11
/// message.
12
///
13
/// The common format is:
14
/// ```text
15
///      N_EXTENSIONS     [1 byte]
16
///      N_EXTENSIONS times:
17
///           EXT_FIELD_TYPE [1 byte]
18
///           EXT_FIELD_LEN  [1 byte]
19
///           EXT_FIELD      [EXT_FIELD_LEN bytes]
20
/// ```
21
///
22
/// It is subject to the additional restraints:
23
///
24
/// * Each extension type SHOULD be sent only once in a message.
25
/// * Parties MUST ignore any occurrences all occurrences of an extension
26
///   with a given type after the first such occurrence.
27
/// * Extensions SHOULD be sent in numerically ascending order by type.
28
#[derive(Clone, Debug, derive_more::Deref, derive_more::DerefMut, Deftly)]
29
#[derive_deftly(HasMemoryCost)]
30
#[deftly(has_memory_cost(bounds = "T: HasMemoryCostStructural"))]
31
pub(super) struct ExtList<T> {
32
    /// The extensions themselves.
33
    extensions: Vec<T>,
34
}
35
impl<T> Default for ExtList<T> {
36
392
    fn default() -> Self {
37
392
        Self {
38
392
            extensions: Vec::new(),
39
392
        }
40
392
    }
41
}
42
/// An kind of extension that can be used with some kind of HS-related message.
43
///
44
/// Each extendible message will likely define its own enum,
45
/// implementing this trait,
46
/// representing the possible extensions.
47
pub(super) trait ExtGroup: Readable + Writeable {
48
    /// An identifier kind used with this sort of extension
49
    type Id: From<u8> + Into<u8> + Eq + PartialEq + Ord + Copy;
50
    /// The field-type id for this particular extension.
51
    fn type_id(&self) -> Self::Id;
52
}
53
/// A single typed extension that can be used with some kind of HS-related message.
54
pub(super) trait Ext: Sized {
55
    /// An identifier kind used with this sort of extension.
56
    ///
57
    /// Typically defined with caret_int.
58
    type Id: From<u8> + Into<u8>;
59
    /// The field-type id for this particular extension.
60
    fn type_id(&self) -> Self::Id;
61
    /// Extract the body (not the type or the length) from a single
62
    /// extension.
63
    fn take_body_from(b: &mut Reader<'_>) -> Result<Self>;
64
    /// Write the body (not the type or the length) for a single extension.
65
    fn write_body_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()>;
66
}
67
impl<T: ExtGroup> Readable for ExtList<T> {
68
490
    fn take_from(b: &mut Reader<'_>) -> Result<Self> {
69
490
        let n_extensions = b.take_u8()?;
70
490
        let extensions: Result<Vec<T>> = (0..n_extensions).map(|_| b.extract::<T>()).collect();
71
490
        Ok(Self {
72
490
            extensions: extensions?,
73
        })
74
490
    }
75
}
76
impl<T: ExtGroup> Writeable for ExtList<T> {
77
30
    fn write_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
78
30
        let n_extensions = self
79
30
            .extensions
80
30
            .len()
81
30
            .try_into()
82
30
            .map_err(|_| EncodeError::BadLengthValue)?;
83
30
        b.write_u8(n_extensions);
84
30
        let mut exts_sorted: Vec<&T> = self.extensions.iter().collect();
85
30
        exts_sorted.sort_by_key(|ext| ext.type_id());
86
30
        exts_sorted.iter().try_for_each(|ext| ext.write_onto(b))?;
87
30
        Ok(())
88
30
    }
89
}
90
impl<T: ExtGroup> ExtList<T> {
91
    /// Insert `ext` into this list of extensions, replacing any previous
92
    /// extension with the same field type ID.
93
147
    pub(super) fn replace_by_type(&mut self, ext: T) {
94
147
        self.retain(|e| e.type_id() != ext.type_id());
95
147
        self.push(ext);
96
147
    }
97
}
98

            
99
/// An unrecognized or unencoded extension for some HS-related message.
100
#[derive(Clone, Debug, Deftly)]
101
#[derive_deftly(HasMemoryCost)]
102
// Use `Copy + 'static` and `#[deftly(has_memory_cost(copy))]` so that we don't
103
// need to derive HasMemoryCost for the id types, which are indeed all Copy.
104
#[deftly(has_memory_cost(bounds = "ID: Copy + 'static"))]
105
pub struct UnrecognizedExt<ID> {
106
    /// The field type ID for this extension.
107
    #[deftly(has_memory_cost(copy))]
108
    pub(super) type_id: ID,
109
    /// The body of this extension.
110
    pub(super) body: Vec<u8>,
111
}
112

            
113
impl<ID> UnrecognizedExt<ID> {
114
    /// Return a new unrecognized extension with a given ID and body.
115
    ///
116
    /// NOTE: nothing actually enforces that this type ID is not
117
    /// recognized.
118
    ///
119
    /// NOTE: This function accepts bodies longer than 255 bytes, but
120
    /// it is not possible to encode them.
121
2
    pub fn new(type_id: ID, body: impl Into<Vec<u8>>) -> Self {
122
2
        Self {
123
2
            type_id,
124
2
            body: body.into(),
125
2
        }
126
2
    }
127
}
128

            
129
/// Declare an Extension group that takes a given identifier.
///
/// The input is an enum-like declaration of the form
/// `enum Name [ IdType ] { CaseA, CaseB, ... }`.  Each case name is used
/// both as the variant name and as the name of the type that the variant
/// wraps.
///
/// The expansion contains:
///
/// * the enum itself, with one variant per case plus an `Unrecognized`
///   variant holding an `UnrecognizedExt<IdType>`;
/// * `Readable` and `Writeable` implementations, which handle the type
///   byte and the length-prefixed body of a single extension;
/// * an `ExtGroup` implementation; and
/// * `From` conversions into the enum from each case type and from
///   `UnrecognizedExt<IdType>`.
//
// TODO: This is rather similar to restrict_msg(), isn't it?  Also, We use this
// pattern of (number, (cmd, length, body)*) a few of times in Tor outside the
// hs module.  Perhaps we can extend and unify our code here...
macro_rules! decl_extension_group {
    {
        $( #[$meta:meta] )*
        $v:vis enum $id:ident [ $type_id:ty ] {
            $(
                $(#[$cmeta:meta])*
                $case:ident),*
            $(,)?
        }
    } => {paste::paste!{
        $( #[$meta] )*
        $v enum $id {
            $( $(#[$cmeta])*
               $case($case),
            )*
            /// An extension of a type we do not recognize, or which we have not
            /// encoded.
            Unrecognized(UnrecognizedExt<$type_id>)
        }
        impl Readable for $id {
            fn take_from(b: &mut Reader<'_>) -> Result<Self> {
                let type_id = b.take_u8()?.into();
                Ok(match type_id {
                    $(
                        // The constant on the ID type is named after the
                        // case, converted to upper snake case.
                        $type_id::[< $case:snake:upper >] => {
                            Self::$case( b.read_nested_u8len(|r| $case::take_body_from(r))? )
                        }
                    )*
                    // Any other type: keep the body as opaque bytes.
                    _ => {
                        Self::Unrecognized(UnrecognizedExt {
                            type_id,
                            body: b.read_nested_u8len(|r| Ok(r.take_rest().into()))?,
                        })
                    }
                })
            }
        }
        impl Writeable for $id {
            fn write_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
                // Unused when the group declares no cases: then only the
                // `Unrecognized` arm exists, and it does not call `deref_mut`.
                #[allow(unused)]
                use std::ops::DerefMut;
                match self {
                    $(
                        Self::$case(val) => {
                            b.write_u8(val.type_id().into());
                            let mut nested = b.write_nested_u8len();
                            val.write_body_onto(nested.deref_mut())?;
                            nested.finish()?;
                        }
                    )*
                    Self::Unrecognized(unrecognized) => {
                        b.write_u8(unrecognized.type_id.into());
                        let mut nested = b.write_nested_u8len();
                        nested.write_all(&unrecognized.body[..]);
                        nested.finish()?;
                    }
                }
                Ok(())
            }
        }
        impl ExtGroup for $id {
            type Id = $type_id;
            fn type_id(&self) -> Self::Id {
                match self {
                    $(
                        Self::$case(val) => val.type_id(),
                    )*
                    Self::Unrecognized(unrecognized) => unrecognized.type_id,
                }
            }
        }
        $(
        impl From<$case> for $id {
            fn from(val: $case) -> $id {
                $id :: $case ( val )
            }
        }
        )*
        impl From<UnrecognizedExt<$type_id>> for $id {
            fn from(val: UnrecognizedExt<$type_id>) -> $id {
                $id :: Unrecognized(val)
            }
        }
}}
}
220
pub(super) use decl_extension_group;