tor_dirmgr/
shared_ref.rs

1//! Utility module to safely refer to a mutable Arc.
2
3use std::sync::{Arc, RwLock};
4
5use educe::Educe;
6
7use crate::{Error, Result};
8
9/// A shareable mutable-ish optional reference to a an [`Arc`].
10///
11/// Because you can't actually change a shared [`Arc`], this type implements
12/// mutability by replacing the Arc itself with a new value.  It tries
13/// to avoid needless clones by taking advantage of [`Arc::make_mut`].
14///
15// We give this construction its own type to simplify its users, and make
16// sure we don't hold the lock against any async suspend points.
17#[derive(Debug, Educe)]
18#[educe(Default)]
19#[cfg_attr(docsrs, doc(cfg(feature = "experimental-api")))]
20#[cfg_attr(not(feature = "experimental-api"), allow(unreachable_pub))]
21pub struct SharedMutArc<T> {
22    /// Locked reference to the current value.
23    ///
24    /// (It's okay to use RwLock here, because we never suspend
25    /// while holding the lock.)
26    dir: RwLock<Option<Arc<T>>>,
27}
28
29#[cfg_attr(not(feature = "experimental-api"), allow(unreachable_pub))]
30impl<T> SharedMutArc<T> {
31    /// Construct a new empty SharedMutArc.
32    pub fn new() -> Self {
33        SharedMutArc::default()
34    }
35
36    /// Replace the current value with `new_val`.
37    pub fn replace(&self, new_val: T) {
38        let mut w = self
39            .dir
40            .write()
41            .expect("Poisoned lock for directory reference");
42        *w = Some(Arc::new(new_val));
43    }
44
45    /// Remove the current value of this SharedMutArc.
46    #[allow(unused)]
47    pub(crate) fn clear(&self) {
48        let mut w = self
49            .dir
50            .write()
51            .expect("Poisoned lock for directory reference");
52        *w = None;
53    }
54
55    /// Return a new reference to the current value, if there is one.
56    pub fn get(&self) -> Option<Arc<T>> {
57        let r = self
58            .dir
59            .read()
60            .expect("Poisoned lock for directory reference");
61        r.as_ref().map(Arc::clone)
62    }
63
64    /// Replace the contents of this SharedMutArc with the results of applying
65    /// `func` to the inner value.
66    ///
67    /// Gives an error if there is no inner value.
68    ///
69    /// Other threads will not abe able to access the inner value
70    /// while the function is running.
71    ///
72    /// # Limitation: No panic-safety
73    ///
74    /// If `func` panics while it's running, this object will become invalid
75    /// and future attempts to use it will panic. (TODO: Fix this.)
76    // Note: If we decide to make this type public, we'll probably
77    // want to fiddle with how we handle the return type.
78    pub fn mutate<F, U>(&self, func: F) -> Result<U>
79    where
80        F: FnOnce(&mut T) -> Result<U>,
81        T: Clone,
82    {
83        let mut writeable = self
84            .dir
85            .write()
86            .expect("Poisoned lock for directory reference");
87        let dir = writeable.as_mut();
88        match dir {
89            None => Err(Error::DirectoryNotPresent), // Kinda bogus.
90            Some(arc) => func(Arc::make_mut(arc)),
91        }
92    }
93}
94
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;

    #[test]
    fn shared_mut_arc() {
        let val: SharedMutArc<Vec<u32>> = SharedMutArc::new();
        assert_eq!(val.get(), None);

        val.replace(Vec::new());
        assert_eq!(val.get().unwrap().as_ref()[..], Vec::<u32>::new());

        val.mutate(|v| {
            v.push(99);
            Ok(())
        })
        .unwrap();
        assert_eq!(val.get().unwrap().as_ref()[..], [99]);

        val.clear();
        assert_eq!(val.get(), None);

        // Mutating an empty SharedMutArc fails with a specific error, and
        // never invokes the closure.
        let mut called = false;
        let res = val.mutate(|v| {
            called = true;
            v.push(99);
            Ok(())
        });
        assert!(matches!(res, Err(Error::DirectoryNotPresent)));
        assert!(!called);
        assert_eq!(val.get(), None);
    }

    #[test]
    fn mutate_is_clone_on_write() {
        let val: SharedMutArc<Vec<u32>> = SharedMutArc::new();
        val.replace(vec![1]);

        // A reference obtained from `get` is a snapshot: it must not observe
        // later mutation, and `mutate` must hand back the closure's result.
        let before = val.get().unwrap();
        let len = val
            .mutate(|v| {
                v.push(2);
                Ok(v.len())
            })
            .unwrap();
        assert_eq!(len, 2);
        assert_eq!(before.as_ref()[..], [1]);
        assert_eq!(val.get().unwrap().as_ref()[..], [1, 2]);

        // A value is present, so any error here must come from the closure
        // and is passed through to the caller unchanged.
        let res: Result<()> = val.mutate(|_| Err(Error::DirectoryNotPresent));
        assert!(matches!(res, Err(Error::DirectoryNotPresent)));
    }
}