tor_config/
flatten.rs

1//! Similar to `#[serde(flatten)]` but works with [`serde_ignored`]
2//!
3//! Our approach to deserialize a [`Flatten`] is as follows:
4//!
5//!  * We tell the input data format (underlying deserializer) that we want a map.
6//!  * In our visitor, we visit each key in the map in order
7//!  * For each key, we consult `Flattenable::has_field` to find out which child it's in
8//!    (fields in T shadow fields in U, as with serde),
9//!    and store the key and the value in the appropriate [`Portion`].
10//!    (We must store the value as a [`serde_value::Value`]
11//!    since we don't know what type it should be,
12//!    and can't know until we are ready to enter T and U's [`Deserialize`] impls.)
13//!  * If it's in neither T nor U, we explicitly ignore the value
14//!  * When we've processed all the fields, we call the actual deserialisers for T and U:
15//!    we take on the role of the data format, giving each of T and U a map.
16//!
17//! From the point of view of T and U, we each offer them a subset of the fields,
18//! having already rendered the keys to strings and the values to `Value`.
19//!
20//! From the point of view of the data format (which might be a `serde_ignored` proxy)
21//! we consume the union of the fields, and ignore the rest.
22//!
23//! ### Rationale and alternatives
24//!
25//! The key difficulty is this:
26//! we want to call [`Deserializer::deserialize_ignored_any`]
27//! on our input data format for precisely the fields which neither T nor U want.
28//! We must achieve this somehow using information from T or U.
29//! If we tried to use only the [`Deserialize`] impls,
30//! the only way to detect this is to call their `deserialize` methods
31//! and watch to see if they in turn call `deserialize_ignored_any`.
32//! But we need to be asking each of T and U this question for each field:
33//! the shape of [`MapAccess`] puts the data structure in charge of sequencing.
34//! So we would need to somehow suspend `T`'s deserialisation,
//! and call `U`'s, and then suspend `U`'s, and go back to `T`.
36//!
37//! Other possibilities that seemed worse:
38//!
39//!  * Use threads.
40//!    We could spawn a thread for each of `T` and `U`,
41//!    allowing us to run them in parallel and control their execution flow.
42//!
43//!  * Use coroutines eg. [corosensei](https://lib.rs/crates/corosensei)
44//!    (by Amanieu, author of hashbrown etc.)
45//!
46//!  * Instead of suspending and restarting `T` and `U`'s deserialisation,
47//!    discard the partially-deserialised `T` and `U` and restart them each time
48//!    (with cloned copies of the `Value`s).  This is O(n^2) and involves much boxing.
49//!
50//! # References
51//!
52//!  * Tickets against `serde-ignored`:
53//!    <https://github.com/dtolnay/serde-ignored/issues/17>
54//!    <https://github.com/dtolnay/serde-ignored/issues/10>
55//!
56//!  * Workaround with `HashMap` that doesn't quite work right:
57//!    <https://github.com/dtolnay/serde-ignored/issues/10#issuecomment-1044058310>
58//!    <https://github.com/serde-rs/serde/issues/2176>
59//!
60//!  * Discussion in Tor Project gitlab re Arti configuration:
61//!    <https://gitlab.torproject.org/tpo/core/arti/-/merge_requests/1599#note_2944510>
62
63use std::collections::VecDeque;
64use std::fmt::{self, Display};
65use std::marker::PhantomData;
66use std::mem;
67
68use derive_deftly::{define_derive_deftly, derive_deftly_adhoc, Deftly};
69use paste::paste;
70use serde::de::{self, DeserializeSeed, Deserializer, Error as _, IgnoredAny, MapAccess, Visitor};
71use serde::{Deserialize, Serialize, Serializer};
72use serde_value::Value;
73use thiserror::Error;
74
75// Must come first so we can refer to it in docs
define_derive_deftly! {
    /// Derives [`Flattenable`] for a struct
    ///
    /// # Limitations
    ///
    /// Some serde attributes might not be supported.
    /// For example, ones which make the type no longer deserialize as a named fields struct.
    /// This will be detected by a macro-generated test case,
    /// which fails (by panicking) for such a type.
    ///
    /// Most serde attributes (eg field renaming and ignoring) will be fine.
    ///
    /// # Example
    ///
    /// ```
    /// use serde::{Serialize, Deserialize};
    /// use derive_deftly::Deftly;
    /// use tor_config::derive_deftly_template_Flattenable;
    ///
    /// #[derive(Serialize, Deserialize, Debug, Deftly)]
    /// #[derive_deftly(Flattenable)]
    /// struct A {
    ///     a: i32,
    /// }
    /// ```
    //
    // Note re semver:
    //
    // We re-export derive-deftly's template engine, in the manner discussed by the d-a docs.
    // See
    //  https://docs.rs/derive-deftly/latest/derive_deftly/macro.define_derive_deftly.html#exporting-a-template-for-use-by-other-crates
    //
    // The semantic behaviour of the template *does* have semver implications.
    export Flattenable for struct, expect items:

    // The expansion names this crate by the plain path `tor_config`,
    // so that name must resolve wherever the derive is used.
    // (This file's own tests arrange that with `use crate as tor_config`.)
    impl tor_config::Flattenable for $ttype {
        fn has_field(s: &str) -> bool {
            // Ask serde which field names the type's `Deserialize` impl declares,
            // then report whether `s` is one of them.
            let fnames = tor_config::flattenable_extract_fields::<'_, Self>();
            IntoIterator::into_iter(fnames).any(|f| *f == s)

        }
    }

    // Detect if flattenable_extract_fields panics
    //
    // One such test is generated per deriving type;
    // its name is `flattenable_test_` followed by the type name in snake case.
    #[test]
    fn $<flattenable_test_ ${snake_case $tname}>() {
        // Using $ttype::has_field avoids writing out again
        // the call to flattenable_extract_fields, with all its generics,
        // and thereby ensures that we didn't have a mismatch that
        // allows broken impls to slip through.
        // (We know the type is at least similar because we go via the Flattenable impl.)
        let _: bool = <$ttype as tor_config::Flattenable>::has_field("");
    }
}
129
130/// Helper for flattening deserialisation, compatible with [`serde_ignored`]
131///
132/// A combination of two structs `T` and `U`.
133///
134/// The serde representation flattens both structs into a single, larger, struct.
135///
136/// Furthermore, unlike plain use of `#[serde(flatten)]`,
137/// `serde_ignored` will still detect fields which appear in serde input
138/// but which form part of neither `T` nor `U`.
139///
140/// `T` and `U` must both be [`Flattenable`].
141/// Usually that trait should be derived with
142/// the [`Flattenable macro`](derive_deftly_template_Flattenable).
143///
144/// If it's desired to combine more than two structs, `Flatten` can be nested.
145///
146/// # Limitations
147///
148/// Field name overlaps are not detected.
149/// Fields which appear in both structs
150/// will be processed as part of `T` during deserialization.
151/// They will be internally presented as duplicate fields during serialization,
152/// with the outcome depending on the data format implementation.
153///
154/// # Example
155///
156/// ```
157/// use serde::{Serialize, Deserialize};
158/// use derive_deftly::Deftly;
159/// use tor_config::{Flatten, derive_deftly_template_Flattenable};
160///
161/// #[derive(Serialize, Deserialize, Debug, Deftly, Eq, PartialEq)]
162/// #[derive_deftly(Flattenable)]
163/// struct A {
164///     a: i32,
165/// }
166///
167/// #[derive(Serialize, Deserialize, Debug, Deftly, Eq, PartialEq)]
168/// #[derive_deftly(Flattenable)]
169/// struct B {
170///     b: String,
171/// }
172///
173/// let combined: Flatten<A,B> = toml::from_str(r#"
174///     a = 42
175///     b = "hello"
176/// "#).unwrap();
177///
178/// assert_eq!(
179///    combined,
180///    Flatten(A { a: 42 }, B { b: "hello".into() }),
181/// );
182/// ```
//
// We derive Deftly on Flatten itself so we can use
// derive_deftly_adhoc! to iterate over Flatten's two fields.
// This avoids us accidentally (for example) checking T's fields for passing to U.
#[derive(Deftly, Debug, Clone, Copy, Hash, Ord, PartialOrd, Eq, PartialEq, Default)]
#[derive_deftly_adhoc]
// NOTE(review): presumably deliberately exhaustive, since both fields are `pub`
// and the doc example constructs `Flatten(a, b)` directly - confirm.
#[allow(clippy::exhaustive_structs)]
pub struct Flatten<T, U>(pub T, pub U);
191
/// Types that can be used with [`Flatten`]
///
/// Usually, derived with
/// the [`Flattenable derive-deftly macro`](derive_deftly_template_Flattenable).
pub trait Flattenable {
    /// Does this type have a field named `f` ?
    ///
    /// This is an associated function (it takes no `self`),
    /// so it can be asked before any value of the type exists.
    /// The `Deserialize` impl of [`Flatten`] calls it for each key in the input map,
    /// to decide which of its two children (if either) should receive that key.
    fn has_field(f: &str) -> bool;
}
200
201//========== local helper macros ==========
202
/// Implement `deserialize_$what` as a call to `deserialize_any`.
///
/// `$args`, if provided, are any other formal arguments, not including the `Visitor`
///
/// The extra arguments are captured as raw token trees *including their leading comma*,
/// and are spliced in directly after `self`.
/// So `call_any!(tuple, _: usize)` generates
/// `fn deserialize_tuple<V>(self, _: usize, visitor: V) -> ...`.
/// The extra arguments are accepted and ignored.
///
/// The expansion mentions `'de` and `Self::Error`,
/// so it is only meaningful inside an `impl<'de> Deserializer<'de> for ...` block.
macro_rules! call_any { { $what:ident $( $args:tt )* } => { paste!{
    fn [<deserialize_ $what>]<V>(self $( $args )*, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        // Whatever shape was asked for, let deserialize_any decide what to give the visitor
        self.deserialize_any(visitor)
    }
} } }
214
/// Implement most `deserialize_*` as calls to `deserialize_any`.
///
/// The exceptions are the ones we need to handle specially in any of our types,
/// namely `any` itself and `struct`.
///
/// Each `Deserializer` impl that uses this macro must therefore provide
/// `deserialize_any` and `deserialize_struct` itself
/// (the latter possibly via `call_any!(struct, ...)`).
macro_rules! call_any_for_rest { {} => {
    // Methods whose only argument besides `self` is the visitor
    call_any!(map);
    call_any!(bool);
    call_any!(byte_buf);
    call_any!(bytes);
    call_any!(char);
    call_any!(f32);
    call_any!(f64);
    call_any!(i128);
    call_any!(i16);
    call_any!(i32);
    call_any!(i64);
    call_any!(i8);
    call_any!(identifier);
    call_any!(ignored_any);
    call_any!(option);
    call_any!(seq);
    call_any!(str);
    call_any!(string);
    call_any!(u128);
    call_any!(u16);
    call_any!(u32);
    call_any!(u64);
    call_any!(u8);
    call_any!(unit);

    // Methods with extra (ignored) arguments ahead of the visitor
    call_any!(enum, _: &'static str, _: FieldList);
    call_any!(newtype_struct, _: &'static str );
    call_any!(tuple, _: usize );
    call_any!(tuple_struct, _: &'static str, _: usize );
    call_any!(unit_struct, _: &'static str );
} }
251
252//========== Implementations of Serialize and Flattenable ==========
253
derive_deftly_adhoc! {
    Flatten expect items:

    // `$( ... )` repeats its contents once for each field of `Flatten`,
    // so every bound and every field below is generated for both T and U.
    impl<T, U> Serialize for Flatten<T, U>
    where $( $ftype: Serialize, )
    {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where S: Serializer
        {
            /// version of outer `Flatten` containing references
            ///
            /// We give it the same name because the name is visible via serde
            ///
            /// The problems with `#[serde(flatten)]` don't apply to serialisation,
            /// because we're not trying to track ignored fields.
            /// But we can't just apply `#[serde(flatten)]` to `Flatten`
            /// since it doesn't work with tuple structs.
            #[derive(Serialize)]
            struct Flatten<'r, T, U> {
              $(
                #[serde(flatten)]
                $fpatname: &'r $ftype,
              )
            }

            // Borrow both children into the named-field shadow struct, and serialize that.
            Flatten {
              $(
                $fpatname: &self.$fname,
              )
            }
            .serialize(serializer)
        }
    }

    /// `Flatten` may be nested
    impl<T, U> Flattenable for Flatten<T, U>
    where $( $ftype: Flattenable, )
    {
        fn has_field(f: &str) -> bool {
            // Expands to `T::has_field(f) || U::has_field(f) || false`:
            // a field belongs to the combination if either child claims it.
            $(
                $ftype::has_field(f)
                    ||
              )
                false
        }
    }
}
301
302//========== Deserialize implementation ==========
303
/// The keys and values we are to direct to a particular child
///
/// See the module-level comment for the algorithm.
///
/// Entries are appended at the back while the input map is being read,
/// and consumed from the front when the child is deserialized:
/// `Portion` is itself both the `Deserializer` and the `MapAccess` handed to the child.
#[derive(Default)]
struct Portion(VecDeque<(String, Value)>);
309
/// [`de::Visitor`] for `Flatten`
///
/// Carries no data; the `PhantomData` only records which `T` and `U` are being built.
struct FlattenVisitor<T, U>(PhantomData<(T, U)>);
312
/// Wrapper for a field name, impls [`de::Deserializer`]
///
/// Used by `Portion` to present each stored key to the child's key seed;
/// it always offers the visitor an owned string.
struct Key(String);
315
/// Type alias for reified error
///
/// [`serde_value::DeserializerError`] has one variant
/// for each of the constructors of [`de::Error`].
///
/// This is the error type of our internal deserializers (`Portion` and `Key`).
/// `FlattenVisitor::visit_map` converts it into the input data format's own error type
/// with `de::Error::custom`.
type FlattenError = serde_value::DeserializerError;
321
322//----- part 1: disassembly -----
323
derive_deftly_adhoc! {
    Flatten expect items:

    // where constraint on the Deserialize impl
    ${define FLATTENABLE $( $ftype: Deserialize<'de> + Flattenable, )}

    impl<'de, T, U> Deserialize<'de> for Flatten<T, U>
    where $FLATTENABLE
    {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where D: Deserializer<'de>
        {
            // Tell the data format we want a map; all the real work is in the visitor.
            deserializer.deserialize_map(FlattenVisitor(PhantomData))
        }
    }

    impl<'de, T, U> Visitor<'de> for FlattenVisitor<T,U>
    where $FLATTENABLE
    {
        type Value = Flatten<T, U>;

        fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "map (for struct)")
        }

        fn visit_map<A>(self, mut mapa: A) -> Result<Self::Value, A::Error>
        where A: MapAccess<'de>
        {
            // See the module-level comment for an explanation.

            // $P is a local variable, one per field of `Flatten` (ie, one each for T and U).
            // Its name is `p_` pasted onto the field name
            // (for this tuple struct, presumably `p_0` and `p_1` - TODO confirm).
            ${define P $<p_ $fname>}

            ${for fields { let mut $P = Portion::default(); }}

            // The repeated `if ... else` below expands to a chain
            //    if T::has_field { .. } else if U::has_field { .. } else { ignore }
            // so T is consulted first, and a key claimed by both goes to T.
            #[allow(clippy::suspicious_else_formatting)] // this is the least bad layout
            while let Some(k) = mapa.next_key::<String>()? {
              $(
                 if $ftype::has_field(&k) {
                    // We can't know the value's real type yet, so keep it as a generic `Value`.
                    let v: Value = mapa.next_value()?;
                    $P.0.push_back((k, v));
                    continue;
                }
                else
              )
                {
                     // Neither child wants this key.
                     // Asking for `IgnoredAny` is what lets a `serde_ignored` proxy see it.
                     let _: IgnoredAny = mapa.next_value()?;
                }
            }

            // Run the children's own deserializers over what we collected,
            // converting our internal error type into the data format's.
            Flatten::assemble( ${for fields { $P, }} )
                .map_err(A::Error::custom)
        }
    }
}
379
380//----- part 2: reassembly -----
381
derive_deftly_adhoc! {
    Flatten expect items:

    impl<'de, T, U> Flatten<T, U>
    where $( $ftype: Deserialize<'de>, )
    {
        /// Assemble a `Flatten` out of the partition of its keys and values
        ///
        /// Uses `Portion`'s `Deserializer` impl and T and U's `Deserialize`
        ///
        /// Takes one `Portion` per field of `Flatten`, in field order (T's, then U's).
        ///
        /// # Errors
        ///
        /// Returns the error from the first child whose deserialization fails;
        /// T is attempted before U.
        fn assemble(
          $(
            $fpatname: Portion,
          )
        ) -> Result<Self, FlattenError> {
            Ok(Flatten(
              $(
                // Each child sees only its own `Portion`, presented as a map.
                $ftype::deserialize($fpatname)?,
              )
            ))
        }
    }
}
404
// `Portion` in the role of data format:
// whatever the child type asks for, it is offered a map of the stored entries.
impl<'de> Deserializer<'de> for Portion {
    type Error = FlattenError;

    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        // `Portion` is its own `MapAccess`, so we hand ourselves to the visitor.
        visitor.visit_map(self)
    }

    // `deserialize_struct`, and every other `deserialize_*`, forwards to `deserialize_any`.
    call_any!(struct, _: &'static str, _: FieldList);
    call_any_for_rest!();
}
418
419impl<'de> MapAccess<'de> for Portion {
420    type Error = FlattenError;
421
422    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
423    where
424        K: DeserializeSeed<'de>,
425    {
426        let Some(entry) = self.0.get_mut(0) else {
427            return Ok(None);
428        };
429        let k = mem::take(&mut entry.0);
430        let k: K::Value = seed.deserialize(Key(k))?;
431        Ok(Some(k))
432    }
433
434    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
435    where
436        V: DeserializeSeed<'de>,
437    {
438        let v = self
439            .0
440            .pop_front()
441            .expect("next_value called inappropriately")
442            .1;
443        let r = seed.deserialize(v)?;
444        Ok(r)
445    }
446}
447
// `Key` in the role of data format:
// whatever the key seed asks for, it is offered the field name as an owned string.
impl<'de> Deserializer<'de> for Key {
    type Error = FlattenError;

    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_string(self.0)
    }

    // `deserialize_struct`, and every other `deserialize_*`, forwards to `deserialize_any`.
    call_any!(struct, _: &'static str, _: FieldList);
    call_any_for_rest!();
}
461
462//========== Field extractor ==========
463
/// List of fields, appears in several APIs here
///
/// This is the type of the `fields` argument received by `deserialize_struct`,
/// and of the return value of [`flattenable_extract_fields`].
type FieldList = &'static [&'static str];
466
/// Stunt "data format" which we use for extracting fields for derived `Flattenable` impls
///
/// The field extraction works as follows:
///  * We ask serde to deserialize `$ttype` from a `FieldExtractor`
///  * We expect the serde-macro-generated `Deserialize` impl to call `deserialize_struct`
///  * We return the list of fields to match up as an error
///
/// It contains no input data at all:
/// every `deserialize_*` method other than `deserialize_struct` panics.
struct FieldExtractor;
474
/// Error resulting from successful operation of a [`FieldExtractor`]
///
/// Existence of this error is a *success*.
/// Unexpected behaviour by the type's serde implementation causes panics, not errors.
///
/// The payload is the field list which was passed to `deserialize_struct`.
#[derive(Error, Debug)]
#[error("Flattenable macro test gave error, so test passed successfully")]
struct FieldExtractorSuccess(FieldList);
482
483/// Extract fields of a struct, as viewed by `serde`
484///
485/// # Performance
486///
487/// In release builds, is very fast - all the serde nonsense boils off.
488/// In debug builds, maybe a hundred instructions, so not ideal,
489/// but it is at least O(1) since it doesn't have any loops.
490///
491/// # STABILITY WARNING
492///
493/// This function is `pub` but it is `#[doc(hidden)]`.
494/// The only legitimate use is via the `Flattenable` macro.
495/// There are **NO SEMVER GUARANTEES**
496///
497/// # Panics
498///
499/// Will panic on types whose serde field list cannot be simply extracted via serde,
500/// which will include things that aren't named fields structs,
501/// might include types decorated with unusual serde annotations.
502pub fn flattenable_extract_fields<'de, T: Deserialize<'de>>() -> FieldList {
503    let notional_input = FieldExtractor;
504    let FieldExtractorSuccess(fields) = T::deserialize(notional_input)
505        .map(|_| ())
506        .expect_err("unexpected success deserializing from FieldExtractor!");
507    fields
508}
509
// `de::Error` is required of any `Deserializer::Error`.
// But the only "error" we ever want out of a `FieldExtractor` is the field list,
// which is constructed directly; so any error made via this trait means extraction failed.
impl de::Error for FieldExtractorSuccess {
    /// Always panics, reporting the message `e`
    fn custom<E>(e: E) -> Self
    where
        E: Display,
    {
        panic!("Flattenable macro test failed - some *other* serde error: {e}");
    }
}
518
519impl<'de> Deserializer<'de> for FieldExtractor {
520    type Error = FieldExtractorSuccess;
521
522    fn deserialize_struct<V>(
523        self,
524        _name: &'static str,
525        fields: FieldList,
526        _visitor: V,
527    ) -> Result<V::Value, Self::Error>
528    where
529        V: Visitor<'de>,
530    {
531        Err(FieldExtractorSuccess(fields))
532    }
533
534    fn deserialize_any<V>(self, _: V) -> Result<V::Value, Self::Error>
535    where
536        V: Visitor<'de>,
537    {
538        panic!("test failed: Flattennable misimplemented by macros!");
539    }
540
541    call_any_for_rest!();
542}
543
544//========== tests ==========
545
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use crate as tor_config; // for the benefit of the macros

    use std::collections::HashMap;

    /// First test struct: a scalar and a map
    #[derive(Serialize, Deserialize, Debug, Deftly, Eq, PartialEq)]
    #[derive_deftly(Flattenable)]
    struct A {
        a: i32,
        m: HashMap<String, String>,
    }

    /// Second test struct: a scalar and a list
    #[derive(Serialize, Deserialize, Debug, Deftly, Eq, PartialEq)]
    #[derive_deftly(Flattenable)]
    struct B {
        b: i32,
        v: Vec<String>,
    }

    /// Third test struct, used for the nested `Flatten` tests
    #[derive(Serialize, Deserialize, Debug, Deftly, Eq, PartialEq)]
    #[derive_deftly(Flattenable)]
    struct C {
        c: HashMap<String, String>,
    }

    /// TOML input with all the fields of `A`, `B` and `C`,
    /// plus `spurious` which belongs to none of them
    const TEST_INPUT: &str = r#"
        a = 42

        m.one = "unum"
        m.two = "bis"

        b = 99
        v = ["hi", "ho"]

        spurious = 66

        c.zed = "final"
    "#;

    /// Parse `TEST_INPUT` into a generic `toml::Value`
    fn test_input() -> toml::Value {
        toml::from_str(TEST_INPUT).unwrap()
    }
    /// Deserialize a `T` from the test input, without tracking ignored fields
    fn simply<'de, T: Deserialize<'de>>() -> T {
        test_input().try_into().unwrap()
    }
    /// Deserialize a `T` from the test input via `serde_ignored`
    ///
    /// Returns the value, and the paths reported as ignored (in the order reported).
    fn with_ignored<'de, T: Deserialize<'de>>() -> (T, Vec<String>) {
        let mut ignored = vec![];
        let f = serde_ignored::deserialize(
            test_input(), //
            |path| ignored.push(path.to_string()),
        )
        .unwrap();
        (f, ignored)
    }

    /// `Flatten<A, B>` yields the same `A` and `B` as deserializing each separately
    #[test]
    fn plain() {
        let f: Flatten<A, B> = test_input().try_into().unwrap();
        assert_eq!(f, Flatten(simply(), simply()));
    }

    /// With `serde_ignored`, the fields in neither `A` nor `B` are reported
    #[test]
    fn ignored() {
        let (f, ignored) = with_ignored::<Flatten<A, B>>();
        assert_eq!(f, simply());
        // NOTE(review): the expected order presumably relies on
        // `toml::Value` iterating its table keys in sorted order - confirm.
        assert_eq!(ignored, ["c", "spurious"]);
    }

    /// A nested `Flatten` claims the fields of all three structs
    #[test]
    fn nested() {
        let (f, ignored) = with_ignored::<Flatten<A, Flatten<B, C>>>();
        assert_eq!(f, simply());
        assert_eq!(ignored, ["spurious"]);
    }

    /// Serialization produces a single flat object with every child's fields
    #[test]
    fn ser() {
        let f: Flatten<A, Flatten<B, C>> = simply();

        assert_eq!(
            serde_json::to_value(f).unwrap(),
            serde_json::json!({
                "a": 42,
                "m": {
                    "one": "unum",
                    "two": "bis"
                },
                "b": 99,
                "v": [
                    "hi",
                    "ho"
                ],
                "c": {
                    "zed": "final"
                }
            }),
        );
    }

    /// This function exists only so we can disassemble it.
    ///
    /// To see what the result looks like in a release build:
    ///
    ///  * `RUSTFLAGS=-g cargo test -p tor-config --all-features --locked --release -- --nocapture flattenable_extract_fields_a_test`
    ///  * Observe the binary that's run, eg `Running unittests src/lib.rs (target/release/deps/tor_config-d4c4f29c45a0a3f9)`
    ///  * Disassemble it `objdump -d target/release/deps/tor_config-d4c4f29c45a0a3f9`
    ///  * Search for this function: `less +/'28flattenable_extract_fields_a.*:'`
    ///
    /// At the time of writing, the result is three instructions:
    /// load the address of the list, load a register with the constant 2 (the length),
    /// return.
    fn flattenable_extract_fields_a() -> FieldList {
        flattenable_extract_fields::<'_, A>()
    }

    /// Call `flattenable_extract_fields_a` via `black_box`, and print the result
    #[test]
    fn flattenable_extract_fields_a_test() {
        use std::hint::black_box;
        let f: fn() -> _ = black_box(flattenable_extract_fields_a);
        eprintln!("{:?}", f());
    }
}