1
//! Helper utilities
2
//!
3

            
4
// TODO RPC: Consider replacing this with a derive-deftly template.
5
//
6
/// Define an `impl From<fromty> for toty` that wraps its input as
/// `toty::variant(Arc::new(e))`.
///
/// Invoke as `define_from_for_arc!( FromType => ToType [Variant] );`, where `Variant` is a
/// tuple variant of `ToType` holding an `Arc<FromType>`.
macro_rules! define_from_for_arc {
    { $fromty:ty => $toty:ty [$variant:ident] } => {
        impl From<$fromty> for $toty {
            fn from(e: $fromty) -> $toty {
                // Wrap in an `Arc` so the resulting value stays cheap to clone.
                Self::$variant(std::sync::Arc::new(e))
            }
        }
    };
}
17
use std::ffi::{CStr, CString, NulError};
18

            
19
pub(crate) use define_from_for_arc;
20

            
21
/// An owned string that is always both valid UTF-8 and NUL-terminated,
/// so it can be viewed cheaply as either kind of string.
//
// TODO RPC: Rename so we can expose it more sensibly.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Utf8CString {
    /// The owned, NUL-terminated contents of this string.
    ///
    /// # Safety
    ///
    /// INVARIANT: These bytes must always be valid UTF-8.
    ///
    /// Our Rust code does not _yet_ rely on this invariant for safety, but our C FFI
    /// promises callers that it holds, so it must never be violated.
    string: Box<CStr>,
}
37

            
38
impl AsRef<CStr> for Utf8CString {
39
    fn as_ref(&self) -> &CStr {
40
        &self.string
41
    }
42
}
43

            
44
impl AsRef<str> for Utf8CString {
45
4108
    fn as_ref(&self) -> &str {
46
4108
        // TODO: We might someday decide to implement this using unsafe methods, to avoid walking
47
4108
        // over the string to enforce properties that are already there.
48
4108
        self.string.to_str().expect("Utf8CString was not UTF-8‽")
49
4108
    }
50
}
51

            
52
// TODO: In theory we could have an unchecked version of this function, if we are 100%
53
// sure that serde_json will reject every string that contains a NUL.  But let's not do
54
// that unless the NUL check shows up in profiles.
55
impl TryFrom<String> for Utf8CString {
56
    type Error = NulError;
57

            
58
16496
    fn try_from(value: String) -> Result<Self, Self::Error> {
59
16496
        // Safety: Since `value` is a `String`, it is guaranteed to be UTF-8.
60
16496
        Ok(Utf8CString {
61
16496
            string: CString::new(value)?.into_boxed_c_str(),
62
        })
63
16496
    }
64
}
65

            
66
impl std::fmt::Display for Utf8CString {
67
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68
        let s: &str = self.as_ref();
69
        std::fmt::Display::fmt(s, f)
70
    }
71
}
72

            
73
/// An error from trying to convert a byte-slice to a Utf8CString.
74
#[derive(Clone, Debug, thiserror::Error)]
75
enum Utf8CStringFromBytesError {
76
    /// The bytes contained a nul, so we can't convert into a nul-terminated string.
77
    #[error("Bytes contained 0")]
78
    Nul(#[from] NulError),
79
    /// The bytes were not value UTF-8
80
    #[error("Bytes were not utf-8.")]
81
    Utf8(#[from] std::str::Utf8Error),
82
}
83

            
84
impl Utf8CString {
85
    /// Try to construct a new `Utf8CString` from a given byte slice.
86
    fn try_from_bytes(bytes: &[u8]) -> Result<Self, Utf8CStringFromBytesError> {
87
        let s: &str = std::str::from_utf8(bytes)?;
88
        Ok(s.to_owned().try_into()?)
89
    }
90
}
91

            
92
impl serde::Serialize for Utf8CString {
    /// Serialize as a plain string; the trailing NUL is not part of the output.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let text: &str = self.as_ref();
        serializer.serialize_str(text)
    }
}
100
impl<'de> serde::Deserialize<'de> for Utf8CString {
101
8302
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
102
8302
    where
103
8302
        D: serde::Deserializer<'de>,
104
8302
    {
105
        /// Visitor to implement Deserialize for Utf8CString
106
        struct Visitor;
107
        impl<'de> serde::de::Visitor<'de> for Visitor {
108
            type Value = Utf8CString;
109

            
110
2
            fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
111
2
                fmt.write_str("a UTF-8 string with no internal NULs")
112
2
            }
113
            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
114
            where
115
                E: serde::de::Error,
116
            {
117
                Utf8CString::try_from_bytes(v).map_err(|e| E::custom(e))
118
            }
119
8300
            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
120
8300
            where
121
8300
                E: serde::de::Error,
122
8300
            {
123
8300
                Utf8CString::try_from(v.to_owned()).map_err(|e| E::custom(e))
124
8300
            }
125
        }
126
8302
        deserializer.deserialize_str(Visitor)
127
8302
    }
128
}
129

            
130
/// FFI-related functionality for `Utf8CString`.
#[cfg(feature = "ffi")]
pub(crate) mod ffi {
    use std::ffi::c_char;

    impl super::Utf8CString {
        /// Return a pointer to this string's NUL-terminated bytes, for handing to C.
        ///
        /// The pointer borrows from `self` and is only valid while this `Utf8CString` is alive.
        pub(crate) fn as_ptr(&self) -> *const c_char {
            let body: &std::ffi::CStr = &self.string;
            body.as_ptr()
        }
    }
}
142

            
143
#[cfg(test)]
/// Assert that `$s1` and `$s2` are both valid JSON, and parse to the same `serde_json::Value`.
///
/// The expansion is wrapped in its own block. The macro can therefore be used in expression
/// position (e.g. as a `match` arm or closure body) as well as a statement.
macro_rules! assert_same_json {
    { $s1:expr, $s2:expr } => {{
        let v1: serde_json::Value = serde_json::from_str($s1)
            .expect("first argument to assert_same_json! is not valid JSON");
        let v2: serde_json::Value = serde_json::from_str($s2)
            .expect("second argument to assert_same_json! is not valid JSON");
        assert_eq!(v1, v2);
    }};
}
152
#[cfg(test)]
153
pub(crate) use assert_same_json;