1
//! Helpers to manage lists of HS cell extensions.
2
//
3
// TODO: We might generalize this even more in the future to handle other
4
// similar lists in our cell protocol.
5

            
6
use tor_bytes::{EncodeError, EncodeResult, Readable, Reader, Result, Writeable, Writer};
7

            
8
/// A list of extensions, represented in a common format used by many HS-related
9
/// message.
10
///
11
/// The common format is:
12
/// ```text
13
///      N_EXTENSIONS     [1 byte]
14
///      N_EXTENSIONS times:
15
///           EXT_FIELD_TYPE [1 byte]
16
///           EXT_FIELD_LEN  [1 byte]
17
///           EXT_FIELD      [EXT_FIELD_LEN bytes]
18
/// ```
19
///
20
/// It is subject to the additional restraints:
21
///
22
/// * Each extension type SHOULD be sent only once in a message.
23
/// * Parties MUST ignore any occurrences all occurrences of an extension
24
///   with a given type after the first such occurrence.
25
/// * Extensions SHOULD be sent in numerically ascending order by type.
26
308
#[derive(Clone, Debug, derive_more::Deref, derive_more::DerefMut)]
27
pub(super) struct ExtList<T> {
28
    /// The extensions themselves.
29
    extensions: Vec<T>,
30
}
31
impl<T> Default for ExtList<T> {
32
376
    fn default() -> Self {
33
376
        Self {
34
376
            extensions: Vec::new(),
35
376
        }
36
376
    }
37
}
38
/// An kind of extension that can be used with some kind of HS-related message.
39
///
40
/// Each extendible message will likely define its own enum,
41
/// implementing this trait,
42
/// representing the possible extensions.
43
pub(super) trait ExtGroup: Readable + Writeable {
44
    /// An identifier kind used with this sort of extension
45
    type Id: From<u8> + Into<u8> + Eq + PartialEq + Ord + Copy;
46
    /// The field-type id for this particular extension.
47
    fn type_id(&self) -> Self::Id;
48
}
49
/// A single typed extension that can be used with some kind of HS-related message.
50
pub(super) trait Ext: Sized {
51
    /// An identifier kind used with this sort of extension.
52
    ///
53
    /// Typically defined with caret_int.
54
    type Id: From<u8> + Into<u8>;
55
    /// The field-type id for this particular extension.
56
    fn type_id(&self) -> Self::Id;
57
    /// Extract the body (not the type or the length) from a single
58
    /// extension.
59
    fn take_body_from(b: &mut Reader<'_>) -> Result<Self>;
60
    /// Write the body (not the type or the length) for a single extension.
61
    fn write_body_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()>;
62
}
63
impl<T: ExtGroup> Readable for ExtList<T> {
64
470
    fn take_from(b: &mut Reader<'_>) -> Result<Self> {
65
470
        let n_extensions = b.take_u8()?;
66
470
        let extensions: Result<Vec<T>> = (0..n_extensions).map(|_| b.extract::<T>()).collect();
67
470
        Ok(Self {
68
470
            extensions: extensions?,
69
        })
70
470
    }
71
}
72
impl<T: ExtGroup> Writeable for ExtList<T> {
73
30
    fn write_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
74
30
        let n_extensions = self
75
30
            .extensions
76
30
            .len()
77
30
            .try_into()
78
30
            .map_err(|_| EncodeError::BadLengthValue)?;
79
30
        b.write_u8(n_extensions);
80
30
        let mut exts_sorted: Vec<&T> = self.extensions.iter().collect();
81
30
        exts_sorted.sort_by_key(|ext| ext.type_id());
82
30
        exts_sorted.iter().try_for_each(|ext| ext.write_onto(b))?;
83
30
        Ok(())
84
30
    }
85
}
86
impl<T: ExtGroup> ExtList<T> {
87
    /// Insert `ext` into this list of extensions, replacing any previous
88
    /// extension with the same field type ID.
89
141
    pub(super) fn replace_by_type(&mut self, ext: T) {
90
141
        self.retain(|e| e.type_id() != ext.type_id());
91
141
        self.push(ext);
92
141
    }
93
}
94

            
95
/// An unrecognized or unencoded extension for some HS-related message.
96
4
#[derive(Clone, Debug)]
97
pub struct UnrecognizedExt<ID> {
98
    /// The field type ID for this extension.
99
    pub(super) type_id: ID,
100
    /// The body of this extension.
101
    pub(super) body: Vec<u8>,
102
}
103

            
104
impl<ID> UnrecognizedExt<ID> {
105
    /// Return a new unrecognized extension with a given ID and body.
106
    ///
107
    /// NOTE: nothing actually enforces that this type ID is not
108
    /// recognized.
109
    ///
110
    /// NOTE: This function accepts bodies longer than 255 bytes, but
111
    /// it is not possible to encode them.
112
2
    pub fn new(type_id: ID, body: impl Into<Vec<u8>>) -> Self {
113
2
        Self {
114
2
            type_id,
115
2
            body: body.into(),
116
2
        }
117
2
    }
118
}
119

            
120
/// Declare an Extension group that takes a given identifier.
121
//
122
// TODO: This is rather similar to restrict_msg(), isn't it?  Also, We use this
123
// pattern of (number, (cmd, length, body)*) a few of times in Tor outside the
124
// hs module.  Perhaps we can extend and unify our code here...
125
macro_rules! decl_extension_group {
126
    {
127
        $( #[$meta:meta] )*
128
        $v:vis enum $id:ident [ $type_id:ty ] {
129
            $(
130
                $(#[$cmeta:meta])*
131
                $case:ident),*
132
            $(,)?
133
        }
134
    } => {paste::paste!{
135
        $( #[$meta] )*
136
        $v enum $id {
137
            $( $(#[$cmeta])*
138
               $case($case),
139
            )*
140
            /// An extension of a type we do not recognize, or which we have not
141
            /// encoded.
142
            Unrecognized(UnrecognizedExt<$type_id>)
143
        }
144
        impl Readable for $id {
145
235
            fn take_from(b: &mut Reader<'_>) -> Result<Self> {
146
235
                let type_id = b.take_u8()?.into();
147
143
                Ok(match type_id {
148
94
                    $(
149
94
                        $type_id::[< $case:snake:upper >] => {
150
96
                            Self::$case( b.read_nested_u8len(|r| $case::take_body_from(r))? )
151
94
                        }
152
94
                    )*
153
94
                    _ => {
154
94
                        Self::Unrecognized(UnrecognizedExt {
155
95
                            type_id,
156
141
                            body: b.read_nested_u8len(|r| Ok(r.take_rest().into()))?,
157
                        })
158
                    }
159
                })
160
235
            }
161
        }
162
        impl Writeable for $id {
163
14
            fn write_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<
164
14
()> {
165
14
                #[allow(unused)]
166
14
                use std::ops::DerefMut;
167
14
                match self {
168
4
                    $(
169
10
                        Self::$case(val) => {
170
10
                            b.write_u8(val.type_id().into());
171
10
                            let mut nested = b.write_nested_u8len();
172
10
                            val.write_body_onto(nested.deref_mut())?;
173
10
                            nested.finish()?;
174
4
                        }
175
4
                    )*
176
6
                    Self::Unrecognized(unrecognized) => {
177
6
                        b.write_u8(unrecognized.type_id.into());
178
6
                        let mut nested = b.write_nested_u8len();
179
6
                        nested.write_all(&unrecognized.body[..]);
180
6
                        nested.finish()?;
181
                    }
182
                }
183
16
                Ok(())
184
16
            }
185
        }
186
        impl ExtGroup for $id {
187
            type Id = $type_id;
188
284
            fn type_id(&self) -> Self::Id {
189
284
                match self {
190
94
                    $(
191
143
                        Self::$case(val) => val.type_id(),
192
94
                    )*
193
143
                    Self::Unrecognized(unrecognized) => unrecognized.type_id,
194
94
                }
195
284
            }
196
        }
197
        $(
198
        impl From<$case> for $id {
199
94
            fn from(val: $case) -> $id {
200
94
                $id :: $case ( val )
201
94
            }
202
        }
203
        )*
204
        impl From<UnrecognizedExt<$type_id>> for $id {
205
47
            fn from(val: UnrecognizedExt<$type_id>) -> $id {
206
47
                $id :: Unrecognized(val)
207
47
            }
208
        }
209
}}
210
}
211
pub(super) use decl_extension_group;