//! Casting objects to trait pointers.
//!
//! Rust supports Any-to-Concrete downcasting via `Any`,
//! and the `downcast_rs` crate supports Trait-to-Concrete downcasting.
//! This module adds Trait-to-Trait downcasting for the `Object` trait.

            
7
use std::{
8
    any::{Any, TypeId},
9
    collections::HashMap,
10
    sync::Arc,
11
};
12

            
13
use once_cell::sync::Lazy;
14

            
15
use crate::Object;
16

            
17
/// A collection of functions to downcast `&dyn Object` references for some
18
/// particular concrete object type `O` into various `&dyn Trait` references.
19
///
20
/// You shouldn't construct this on your own: instead use
21
/// `derive_deftly(Object)`.
22
///
23
/// You shouldn't use this directly; instead use
24
/// [`ObjectArcExt`](super::ObjectArcExt).
25
///
26
/// Note that the concrete object type `O`
27
/// is *not* represented in the type of `CastTable`;
28
/// `CastTable`s are obtained and used at runtime, as part of dynamic dispatch,
29
/// so the type `O` is erased.  We work with `TypeId`s and various `&dyn ...`.
30
#[derive(Default)]
31
pub struct CastTable {
32
    /// A mapping from target TypeId for some trait to a function that can
33
    /// convert this table's type into a trait pointer to that trait.
34
    ///
35
    /// Every entry in this table must contain:
36
    ///
37
    ///   * A key that is `typeid::of::<&'static dyn Tr>()` for some trait `Tr`.
38
    ///   * A [`Caster`] whose functions are suitable for casting objects from this table's
39
    ///     type to `dyn Tr`.
40
    table: HashMap<TypeId, Caster>,
41
}
42

            
43
/// A single entry in a `CastTable`.
///
/// Each `Caster` exists for one concrete object type "`O`", and one trait type "`Tr`".
///
/// Note that we use `Box` here in order to support generic types: you can't
/// get a `&'static` reference to a function that takes a generic type in
/// current rust.
struct Caster {
    /// Actual type: `fn(&dyn Object) -> &(dyn Tr + 'static)`
    ///
    /// Expected to panic if the Object does not have the expected type (`O`).
    cast_to_ref: Box<dyn Any + Send + Sync>,
    /// Actual type: `fn(Arc<dyn Object>) -> Arc<dyn Tr>`
    ///
    /// Expected to panic if the Object does not have the expected type (`O`).
    cast_to_arc: Box<dyn Any + Send + Sync>,
}
60

            
61
impl CastTable {
62
    /// Add a new entry to this `CastTable` for downcasting to TypeId.
63
    ///
64
    /// You should not call this yourself; instead use
65
    /// [`derive_deftly(Object)`](crate::templates::derive_deftly_template_Object)
66
    ///
67
    /// # Requirements
68
    ///
69
    /// `T` must be `dyn Tr` for some trait `Tr`.
70
    /// (Not checked by the compiler.)
71
    ///
72
    /// `cast_to_ref` is a downcaster from `&dyn Object` to `&dyn Tr`.
73
    ///
74
    /// `cast_to_arc` is a downcaster from `Arc<dyn Object>` to `Arc<dyn Tr>`.
75
    ///
76
    /// These functions SHOULD
77
    /// panic if the concrete type of its argument is not the concrete type `O`
78
    /// associated with this `CastTable`.
79
    ///
80
    /// `O` must be `'static`.
81
    /// (Checked by the compiler.)
82
    ///
83
    /// # Panics
84
    ///
85
    /// Panics if called twice on the same `CastTable` with the same `Tr`.
86
    //
87
    // `TypeId::of::<dyn SomeTrait + '_>` exists, but is not the same as
88
    // `TypeId::of::<dyn SomeTrait + 'static>` (unless `SomeTrait: 'static`).
89
    //
90
    // We avoid a consequent bug with non-'static traits as follows:
91
    // We insert and look up by `TypeId::of::<&'static dyn SomeTrait>`,
92
    // which must mean `&'static (dyn SomeTrait + 'static)`
93
    // since a 'static reference to anything non-'static is an ill-formed type.
94
10
    pub fn insert<T: 'static + ?Sized>(
95
10
        &mut self,
96
10
        cast_to_ref: fn(&dyn Object) -> &T,
97
10
        cast_to_arc: fn(Arc<dyn Object>) -> Arc<T>,
98
10
    ) {
99
10
        let type_id = TypeId::of::<&'static T>();
100
10
        let caster = Caster {
101
10
            cast_to_ref: Box::new(cast_to_ref),
102
10
            cast_to_arc: Box::new(cast_to_arc),
103
10
        };
104
10
        self.insert_erased(type_id, caster);
105
10
    }
106

            
107
    /// Implementation for adding an entry to the `CastTable`
108
    ///
109
    /// Broken out for clarity and to reduce monomorphisation.
110
    ///
111
    /// ### Requirements
112
    ///
113
    /// Like `insert`, but less compile-time checking.
114
    /// `type_id` is the identity of `&'static dyn Tr`,
115
    /// and `func` has been boxed and type-erased.
116
10
    fn insert_erased(&mut self, type_id: TypeId, caster: Caster) {
117
10
        let old_val = self.table.insert(type_id, caster);
118
10
        assert!(
119
10
            old_val.is_none(),
120
            "Tried to insert a duplicate entry in a cast table.",
121
        );
122
10
    }
123

            
124
    /// Try to downcast a reference to an object whose concrete type is
125
    /// `O` (the type associated with this `CastTable`)
126
    /// to some target type `T`.
127
    ///
128
    /// `T` should be `dyn Tr`.
129
    /// If `T` is not one of the `dyn Tr` for which `insert` was called,
130
    /// returns `None`.
131
    ///
132
    /// # Panics
133
    ///
134
    /// Panics if the concrete type of `obj` does not match `O`.
135
    ///
136
    /// May panic if any of the Requirements for [`CastTable::insert`] were
137
    /// violated.
138
12
    pub fn cast_object_to<'a, T: 'static + ?Sized>(&self, obj: &'a dyn Object) -> Option<&'a T> {
139
12
        let target_type = TypeId::of::<&'static T>();
140
12
        let caster = self.table.get(&target_type)?;
141
10
        let caster: &fn(&dyn Object) -> &T = caster
142
10
            .cast_to_ref
143
10
            .downcast_ref()
144
10
            .expect("Incorrect cast-function type found in cast table!");
145
10
        Some(caster(obj))
146
12
    }
147

            
148
    /// As [`cast_object_to`](CastTable::cast_object_to), but returns an `Arc<dyn Tr>`.
149
    ///
150
    /// If `T` is not one of the `dyn Tr` types for which `insert_arc` was called,
151
    /// return `Err(obj)`.
152
    ///
153
    /// # Panics
154
    ///
155
    /// Panics if the concrete type of `obj` does not match `O`.
156
    ///
157
    /// May panic if any of the Requirements for [`CastTable::insert`] were
158
    /// violated.
159
8
    pub fn cast_object_to_arc<T: 'static + ?Sized>(
160
8
        &self,
161
8
        obj: Arc<dyn Object>,
162
8
    ) -> Result<Arc<T>, Arc<dyn Object>> {
163
8
        let target_type = TypeId::of::<&'static T>();
164
8
        let caster = match self.table.get(&target_type) {
165
6
            Some(c) => c,
166
2
            None => return Err(obj),
167
        };
168
6
        let caster: &fn(Arc<dyn Object>) -> Arc<T> = caster
169
6
            .cast_to_arc
170
6
            .downcast_ref()
171
6
            .expect("Incorrect cast-function type found in cast table!");
172
6
        Ok(caster(obj))
173
8
    }
174
}
175

            
176
/// Static cast table that doesn't support casting anything to anything.
177
///
178
/// Because this table doesn't support any casting, it is okay to use it with
179
/// any concrete type.
180
pub(super) static EMPTY_CAST_TABLE: Lazy<CastTable> = Lazy::new(|| CastTable {
181
    table: HashMap::new(),
182
});
183

            
184
/// Helper for HasCastTable to work around derive-deftly#36.
///
/// Defines the body for a private make_cast_table() method.
///
/// This macro is not part of `tor-rpcbase`'s public API, and is not covered
/// by semver guarantees.
#[doc(hidden)]
#[macro_export]
macro_rules! cast_table_deftness_helper{
    // Note: We have to use tt here, since $ty can't be used in $(dyn .)
    { $( $traitname:tt ),* } => {
        // `mut` is unused when the list of traits is empty.
        #[allow(unused_mut)]
        let mut table = $crate::CastTable::default();
        $({
            // Absolute path, so that an item named `std` at the expansion
            // site cannot change what this resolves to.
            use ::std::sync::Arc;
            // These are the actual functions that do the downcasting.
            // They work by downcasting with Any to the concrete type `Self`,
            // and then upcasting from the concrete type to the trait object.
            let cast_to_ref: fn(&dyn $crate::Object) -> &(dyn $traitname + 'static) = |self_| {
                let self_: &Self = self_.downcast_ref().expect("used with incorrect type");
                let self_: &dyn $traitname = self_ as _;
                self_
            };
            let cast_to_arc: fn(Arc<dyn $crate::Object>) -> Arc<dyn $traitname> = |self_| {
                let self_: Arc<Self> = self_
                    .downcast_arc()
                    .ok()
                    .expect("used with incorrect type");
                let self_: Arc<dyn $traitname> = self_ as _;
                self_
            };
            table.insert::<dyn $traitname>(cast_to_ref, cast_to_arc);
        })*
        table
    }
}
220

            
221
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use crate::templates::*;
    use derive_deftly::Deftly;

    trait Tr1 {}
    trait Tr2: 'static {}

    #[derive(Deftly)]
    #[derive_deftly(Object)]
    #[deftly(rpc(downcastable_to = "Tr1"))]
    struct Simple;
    impl Tr1 for Simple {}

    /// A type registered as downcastable to `Tr1` can be cast to `Tr1`,
    /// both by reference and by `Arc`.
    #[test]
    fn check_simple() {
        let concrete = Simple;
        let tab = Simple::make_cast_table();
        let obj: &dyn Object = &concrete;
        let _cast: &(dyn Tr1 + '_) = tab.cast_object_to(obj).expect("cast failed");

        let arc_obj: Arc<dyn Object> = Arc::new(Simple);
        let _cast: Arc<dyn Tr1> = tab.cast_object_to_arc(arc_obj).ok().expect("cast failed");
    }

    /// Casting to a trait that was never registered fails gracefully,
    /// both for a derived table and for `EMPTY_CAST_TABLE`.
    #[test]
    fn check_unregistered() {
        let concrete = Simple;
        let tab = Simple::make_cast_table();
        let obj: &dyn Object = &concrete;
        assert!(tab.cast_object_to::<dyn Tr2>(obj).is_none());
        assert!(EMPTY_CAST_TABLE.cast_object_to::<dyn Tr1>(obj).is_none());

        let arc_obj: Arc<dyn Object> = Arc::new(Simple);
        assert!(tab.cast_object_to_arc::<dyn Tr2>(arc_obj).is_err());
    }

    /// Registering the same trait twice in one table is a bug, and must panic.
    #[test]
    #[should_panic(expected = "duplicate entry")]
    fn check_duplicate_insert() {
        let mut tab = Simple::make_cast_table();
        let cast_to_ref: fn(&dyn Object) -> &(dyn Tr1 + 'static) = |obj| {
            let obj: &Simple = obj.downcast_ref().unwrap();
            obj as _
        };
        let cast_to_arc: fn(Arc<dyn Object>) -> Arc<dyn Tr1> = |obj| {
            let obj: Arc<Simple> = obj.downcast_arc().ok().unwrap();
            obj as _
        };
        tab.insert::<dyn Tr1>(cast_to_ref, cast_to_arc);
    }

    #[derive(Deftly)]
    #[derive_deftly(Object)]
    #[deftly(rpc(downcastable_to = "Tr1, Tr2"))]
    struct Generic<T: Send + Sync + 'static>(T);

    impl<T: Send + Sync + 'static> Tr1 for Generic<T> {}
    impl<T: Send + Sync + 'static> Tr2 for Generic<T> {}

    /// Cast tables also work for generic object types, and for traits
    /// with and without a `'static` bound.
    #[test]
    fn check_generic() {
        let gen: Generic<&'static str> = Generic("foo");
        let tab = Generic::<&'static str>::make_cast_table();
        let obj: &dyn Object = &gen;
        let _cast: &(dyn Tr1 + '_) = tab.cast_object_to(obj).expect("cast failed");

        let arc_obj: Arc<dyn Object> = Arc::new(Generic("bar"));
        let _cast: Arc<dyn Tr2> = tab.cast_object_to_arc(arc_obj).ok().expect("cast failed");
    }
}