1//! Define helpers for working with types in constant time.
23use derive_deftly::{Deftly, define_derive_deftly};
4use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
5use zeroize::Zeroize;
67#[cfg(feature = "memquota-memcost")]
8use tor_memquota::derive_deftly_template_HasMemoryCost;
define_derive_deftly! {
    /// Derives [`subtle::ConstantTimeEq`] on structs for which all fields
    /// already implement it. Note that this does NOT work on fields which are
    /// arrays of type `T`, even if `T` implements [`subtle::ConstantTimeEq`].
    /// Arrays do not directly implement [`subtle::ConstantTimeEq`]; instead
    /// they are coerced to a slice, `[T]`, which does. The generated
    /// `$ftype : ConstantTimeEq` bound is therefore not satisfied by an array
    /// field. See subtle!114 for a possible future resolution.
    ///
    /// The generated `ct_eq` compares every field and combines the per-field
    /// results with the non-short-circuiting `&` operator on
    /// [`subtle::Choice`]. A struct with no fields always compares equal.
    ///
    /// This template only produces a `ConstantTimeEq` impl; it does not
    /// generate `PartialEq`.
    export ConstantTimeEq for struct:

    // One bound per field type: every field must itself be `ConstantTimeEq`.
    impl<$tgens> ConstantTimeEq for $ttype
    where $twheres
    $( $ftype : ConstantTimeEq , )
    {
        fn ct_eq(&self, other: &Self) -> subtle::Choice {
            match (self, other) {
                $(
                    // Destructure both sides. Each field is bound to a
                    // variable named `self_<field>` / `other_<field>`, so the
                    // two sides can be paired up by field name below.
                    (${vpat fprefix=self_}, ${vpat fprefix=other_}) => {
                        // `field_eq & field_eq & ... & Choice::from(1)`:
                        // the trailing `Choice::from(1)` terminates the chain
                        // and is the whole result when there are no fields.
                        $(
                            $<self_ $fname>.ct_eq($<other_ $fname>) &
                        )
                        subtle::Choice::from(1)
                    },
                )
            }
        }
    }
}
define_derive_deftly! {
    /// Derives [`core::cmp::PartialEq`] on types which implement
    /// [`subtle::ConstantTimeEq`] by calling [`subtle::ConstantTimeEq::ct_eq`].
    ///
    /// The type must already implement `ConstantTimeEq` (for example via the
    /// `ConstantTimeEq` template); this template does not provide it. No `Eq`
    /// impl is generated.
    export PartialEqFromCtEq:

    // The extra `$ttype : ConstantTimeEq` bound expresses the requirement
    // above, so a missing impl is reported at the derive site.
    impl<$tgens> PartialEq for $ttype
    where $twheres
    $ttype : ConstantTimeEq
    {
        fn eq(&self, other: &Self) -> bool {
            // The `Choice` is converted to a `bool` only at the very end.
            self.ct_eq(other).into()
        }
    }
}
51pub(crate) use {derive_deftly_template_ConstantTimeEq, derive_deftly_template_PartialEqFromCtEq};
5253/// A byte array of length N for which comparisons are performed in constant
54/// time.
55///
56/// # Limitations
57///
58/// It is possible to avoid constant time comparisons here, just by using the
59/// `as_ref()` and `as_mut()` methods. They should therefore be approached with
60/// some caution.
61///
62/// (The decision to avoid implementing `Deref`/`DerefMut` is deliberate.)
63#[allow(clippy::derived_hash_with_manual_eq)]
64#[derive(Clone, Copy, Debug, Hash, Zeroize, derive_more::Deref)]
65#[cfg_attr(
66 feature = "memquota-memcost",
67 derive(Deftly),
68 derive_deftly(HasMemoryCost)
69)]
70pub struct CtByteArray<const N: usize>([u8; N]);
7172impl<const N: usize> ConstantTimeEq for CtByteArray<N> {
73fn ct_eq(&self, other: &Self) -> Choice {
74self.0.ct_eq(&other.0)
75 }
76}
7778impl<const N: usize> PartialEq for CtByteArray<N> {
79fn eq(&self, other: &Self) -> bool {
80self.ct_eq(other).into()
81 }
82}
83impl<const N: usize> Eq for CtByteArray<N> {}
8485impl<const N: usize> From<[u8; N]> for CtByteArray<N> {
86fn from(value: [u8; N]) -> Self {
87Self(value)
88 }
89}
9091impl<const N: usize> From<CtByteArray<N>> for [u8; N] {
92fn from(value: CtByteArray<N>) -> Self {
93 value.0
94}
95}
/// Orders arrays lexicographically by their bytes: the first differing byte
/// decides, which is the same order as comparing the underlying `[u8; N]`
/// values. Every byte pair is examined, so the loop does not exit early at
/// the first difference.
impl<const N: usize> Ord for CtByteArray<N> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // At every point, this value will be set to:
        //   0 if a[i]==b[i] for all i considered so far.
        //   a[i] - b[i] for the lowest i that has a nonzero a[i] - b[i].
        // `i16` is used because the difference of two bytes ranges over
        // -255..=255, which fits in neither `u8` nor `i8`.
        let mut first_nonzero_difference = 0_i16;

        for (a, b) in self.0.iter().zip(other.0.iter()) {
            let difference = i16::from(*a) - i16::from(*b);

            // If it's already set to a nonzero value, this conditional
            // assignment does nothing. Otherwise, it sets it to `difference`.
            //
            // The use of conditional_assign and ct_eq ensures that the compiler
            // won't short-circuit our logic here and end the loop (or stop
            // computing differences) on the first nonzero difference.
            first_nonzero_difference
                .conditional_assign(&difference, first_nonzero_difference.ct_eq(&0));
        }

        // This comparison with zero is not itself constant-time, but that's
        // okay: we only want our Ord function not to leak the array values.
        // Negative => Less, zero => Equal, positive => Greater.
        first_nonzero_difference.cmp(&0)
    }
}
122123impl<const N: usize> PartialOrd for CtByteArray<N> {
124fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
125Some(self.cmp(other))
126 }
127}
128129impl<const N: usize> AsRef<[u8; N]> for CtByteArray<N> {
130fn as_ref(&self) -> &[u8; N] {
131&self.0
132}
133}
134135impl<const N: usize> AsMut<[u8; N]> for CtByteArray<N> {
136fn as_mut(&mut self) -> &mut [u8; N] {
137&mut self.0
138}
139}
140141/// Try to find an item in a slice without leaking where and whether the
142/// item was found.
143///
144/// If there is any item `x` in the `array` for which `matches(x)`
145/// is true, this function will return a reference to one such
146/// item. (We don't specify which.)
147///
148/// Otherwise, this function returns none.
149///
150/// We evaluate `matches` on every item of the array, and try not to
151/// leak by timing which element (if any) matched. Note that if
152/// `matches` itself has side channels, this function can't hide them.
153///
154/// Note that this doesn't necessarily do a constant-time comparison,
155/// and that it is not constant-time for the found/not-found case.
156pub fn ct_lookup<T, F>(array: &[T], matches: F) -> Option<&T>
157where
158F: Fn(&T) -> Choice,
159{
160// ConditionallySelectable isn't implemented for usize, so we need
161 // to use u64.
162let mut idx: u64 = 0;
163let mut found: Choice = 0.into();
164165for (i, x) in array.iter().enumerate() {
166let equal = matches(x);
167 idx.conditional_assign(&(i as u64), equal);
168 found.conditional_assign(&equal, equal);
169 }
170171if found.into() {
172Some(&array[idx as usize])
173 } else {
174None
175}
176}
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use rand::Rng;
    use tor_basic_utils::test_rng;

    /// Sorts 200 random `CtByteArray<32>` values, then checks that `==`,
    /// `!=`, `<`, `>` and `cmp` are mutually consistent, and that `cmp`
    /// agrees with the ordering of the underlying byte arrays.
    // `nonminimal_bool` is allowed because `!(a < a)` is written out
    // deliberately, to exercise the `<`/`>` operators directly.
    #[allow(clippy::nonminimal_bool)]
    #[test]
    fn test_comparisons() {
        let num = 200;
        let mut rng = test_rng::testing_rng();

        let mut array: Vec<CtByteArray<32>> =
            (0..num).map(|_| rng.random::<[u8; 32]>().into()).collect();
        array.sort();

        for i in 0..num {
            // Each element equals itself and is neither less nor greater.
            assert_eq!(array[i], array[i]);
            assert!(!(array[i] < array[i]));
            assert!(!(array[i] > array[i]));

            for j in (i + 1)..num {
                // Note that this test will behave incorrectly if the rng
                // generates the same 256-bit value twice, but that's
                // ridiculously implausible.
                assert!(array[i] < array[j]);
                assert_ne!(array[i], array[j]);
                assert!(array[j] > array[i]);
                // `cmp` must match comparing the raw `[u8; 32]` arrays
                // (computed here in the opposite direction, then reversed).
                assert_eq!(
                    array[i].cmp(&array[j]),
                    array[j].as_ref().cmp(array[i].as_ref()).reverse()
                );
            }
        }
    }

    /// Checks that `ct_lookup` returns the unique item matching a predicate,
    /// and `None` when no item matches.
    #[test]
    fn test_lookup() {
        use super::ct_lookup as lookup;
        use subtle::ConstantTimeEq;
        let items = vec![
            "One".to_string(),
            "word".to_string(),
            "of".to_string(),
            "every".to_string(),
            "length".to_string(),
        ];
        // Each word has a distinct length, so each length picks one word.
        let of_word = lookup(&items[..], |i| i.len().ct_eq(&2));
        let every_word = lookup(&items[..], |i| i.len().ct_eq(&5));
        let no_word = lookup(&items[..], |i| i.len().ct_eq(&99));
        assert_eq!(of_word.unwrap(), "of");
        assert_eq!(every_word.unwrap(), "every");
        assert_eq!(no_word, None);
    }
}