arti_rpc_client_core/msgs/
request.rs

1//! Support for encoding and decoding RPC Requests.
2//!
3//! There are several types in this module:
4//!
5//! - [`Request`] is for requests that are generated from within this crate,
6//!   to implement authentication, negotiation, and other functionality.
7//! - `ParsedRequestFields` (internal) is for a request we've completely validated,
8//!   with all of its fields present.
9//! - [`ValidatedRequest`] is for a string that we have validated as a request.
10
11use std::sync::Arc;
12
13use serde::{Deserialize, Serialize};
14
15/// Alias for a Map as used by the serde_json.
16pub(crate) type JsonMap = serde_json::Map<String, serde_json::Value>;
17
18use crate::conn::ProtoError;
19
20use super::{AnyRequestId, JsonAnyObj, ObjectId};
21
22/// An outbound request that we have generated from within this crate.
23///
24/// It lacks a required `id` field (since we will generate one when sending it),
25/// and it allows any Serialize for its `params`.
26#[derive(Serialize, Debug)]
27// Testing only. Don't implement Deserialize here; this is not the type you should parse into!
28#[cfg_attr(test, derive(Eq, PartialEq, Deserialize))]
29#[allow(clippy::missing_docs_in_private_items)] // Fields are as for ParsedRequest.
30pub(crate) struct Request<T> {
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub(crate) id: Option<AnyRequestId>,
33    pub(crate) obj: ObjectId,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub(crate) meta: Option<RequestMeta>,
36    pub(crate) method: String,
37    pub(crate) params: T,
38}
39
40/// An error that has prevented us from validating an request.
41#[derive(Clone, Debug, thiserror::Error)]
42#[non_exhaustive]
43pub enum InvalidRequestError {
44    /// We failed to turn the request into any kind of json.
45    #[error("Request was not valid Json")]
46    InvalidJson(#[source] Arc<serde_json::Error>),
47    /// We got the request into json, but we couldn't find the fields we wanted.
48    #[error("Request's fields were invalid or missing")]
49    InvalidFormat(#[source] Arc<serde_json::Error>),
50    /// We validated the request, but couldn't re-encode it.
51    #[error("Unable to re-encode or format request")]
52    ReencodeFailed(#[source] Arc<serde_json::Error>),
53}
54
55impl<T: Serialize> Request<T> {
56    /// Construct a new outbound Request.
57    pub(crate) fn new(obj: ObjectId, method: impl Into<String>, params: T) -> Self {
58        Self {
59            id: None,
60            obj,
61            meta: Default::default(),
62            method: method.into(),
63            params,
64        }
65    }
66    /// Try to encode this request as a String.
67    ///
68    /// The string may not yet be a valid request; it might need to get an ID assigned.
69    pub(crate) fn encode(&self) -> Result<String, ProtoError> {
70        serde_json::to_string(self).map_err(|e| ProtoError::CouldNotEncode(Arc::new(e)))
71    }
72}
73
74/// A request in its decoded (or unencoded) format.
75///
76/// We use this type to validate outbound requests from the application.
77#[derive(Deserialize, Debug)]
78// Don't implement Serialize here; this is not for generating requests!
79#[allow(dead_code)] // The fields here are only used for validating serde objects.
80struct ParsedRequestFields {
81    /// The identifier for this request.
82    ///
83    /// Used to match a request with its responses.
84    id: AnyRequestId,
85    /// The ID for the object to which this request is addressed.
86    ///
87    /// (Every request goes to a single object.)
88    obj: ObjectId,
89    /// Additional information for Arti about how to handle the request.
90    #[serde(skip_serializing_if = "Option::is_none")]
91    meta: Option<RequestMeta>,
92    /// The name of the method to invoke.
93    method: String,
94    /// Parameters to pass to the method.
95    params: JsonAnyObj,
96}
97
98/// A known-valid request, encoded as a string (in a single line, with a terminating newline).
99#[derive(derive_more::AsRef, Debug, Clone)]
100pub(crate) struct ValidatedRequest {
101    /// The message itself, as encoded.
102    #[as_ref]
103    msg: String,
104    /// The ID for this request.
105    id: AnyRequestId,
106}
107
108impl ValidatedRequest {
109    /// Return the Id associated with this request.
110    pub(crate) fn id(&self) -> &AnyRequestId {
111        &self.id
112    }
113
114    /// Try to construct a validated request from a `serde_json::Value`.
115    fn from_json_value(val: serde_json::Value) -> Result<Self, InvalidRequestError> {
116        let mut msg = serde_json::to_string(&val)
117            .map_err(|e| InvalidRequestError::ReencodeFailed(Arc::new(e)))?;
118        debug_assert!(!msg.contains('\n'));
119        msg.push('\n');
120
121        let req: ParsedRequestFields = serde_json::from_value(val)
122            .map_err(|e| InvalidRequestError::InvalidFormat(Arc::new(e)))?;
123        let id = req.id;
124
125        Ok(ValidatedRequest { id, msg })
126    }
127
128    /// Try to construct a validated request using `s`.
129    pub(crate) fn from_string_strict(s: &str) -> Result<Self, InvalidRequestError> {
130        let value: serde_json::Value =
131            serde_json::from_str(s).map_err(|e| InvalidRequestError::InvalidJson(Arc::new(e)))?;
132        Self::from_json_value(value)
133    }
134
135    /// Try to construct a ValidatedRequest from the string in `s`.
136    ///
137    /// If it has no `id`, add one using `id_generator`.
138    pub(crate) fn from_string_loose<F>(
139        s: &str,
140        id_generator: F,
141    ) -> Result<Self, InvalidRequestError>
142    where
143        F: FnOnce() -> AnyRequestId,
144    {
145        let mut value: serde_json::Value =
146            serde_json::from_str(s).map_err(|e| InvalidRequestError::InvalidJson(Arc::new(e)))?;
147
148        if let Some(obj) = value.as_object_mut() {
149            obj.entry("id")
150                .or_insert_with(|| id_generator().into_json_value());
151        }
152
153        Self::from_json_value(value)
154    }
155}
156
157/// Crate-internal: The "meta" field in a request.
158#[derive(Deserialize, Serialize, Debug, Default)]
159#[cfg_attr(test, derive(Eq, PartialEq))]
160pub(crate) struct RequestMeta {
161    /// If true, the application wants to receive incremental updates
162    /// about the request that it sent.
163    ///
164    /// (Default: false)
165    #[serde(default)]
166    pub(crate) updates: bool,
167    /// Any unrecognized fields that we received from the user.
168    /// (We re-encode these in case the user knows about fields that we don't.)
169    #[serde(flatten)]
170    pub(crate) unrecognized_fields: JsonMap,
171}
172
/// A helper to return unique Request identifiers.
///
/// All identifiers are prefixed with `"!auto!--"`:
/// if you don't use that string in your own IDs,
/// you won't have any collisions.
#[derive(Debug, Default)]
pub(crate) struct IdGenerator {
    /// The number to embed in the next identifier we generate.
    next_id: u64,
}
183
184impl IdGenerator {
185    /// Return a previously unyielded identifier.
186    pub(crate) fn next_id(&mut self) -> AnyRequestId {
187        let id = self.next_id;
188        self.next_id += 1;
189        format!("!auto!--{id}").into()
190    }
191}
192
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    impl ParsedRequestFields {
        /// Return true if this request is asking for updates.
        fn updates_requested(&self) -> bool {
            self.meta.as_ref().map(|m| m.updates).unwrap_or(false)
        }
    }

    use crate::util::assert_same_json;

    use super::*;
    const REQ1: &str = r#"{"id":7, "obj": "hi", "meta": {"updates": true}, "method":"twiddle", "params":{"stuff": "nonsense"} }"#;
    const REQ2: &str = r#"{"id":"fred", "obj": "hi", "method":"twiddle", "params":{} }"#;
    const REQ3: &str =
        r#"{"id":"fred", "obj": "hi", "method":"twiddle", "params":{},"unrecognized":"waffles"}"#;

    #[test]
    fn parse_requests() {
        let req1: ParsedRequestFields = serde_json::from_str(REQ1).unwrap();
        assert_eq!(req1.id, 7.into());
        assert_eq!(req1.obj.as_ref(), "hi");
        assert_eq!(req1.updates_requested(), true);
        assert_eq!(req1.method, "twiddle");

        let req2: ParsedRequestFields = serde_json::from_str(REQ2).unwrap();
        assert_eq!(req2.id, "fred".to_string().into());
        assert_eq!(req2.obj.as_ref(), "hi");
        assert_eq!(req2.updates_requested(), false);
        assert_eq!(req2.method, "twiddle");

        // Unrecognized top-level fields are allowed.
        let req3: ParsedRequestFields = serde_json::from_str(REQ3).unwrap();
        assert_eq!(req3.id, "fred".to_string().into());
        assert_eq!(req3.obj.as_ref(), "hi");
        assert_eq!(req3.updates_requested(), false);
        assert_eq!(req3.method, "twiddle");
    }

    #[test]
    fn reencode_requests() {
        for r in [REQ1, REQ2, REQ3] {
            let val1 = ValidatedRequest::from_string_strict(r).unwrap();
            let val2 = ValidatedRequest::from_string_loose(r, || panic!()).unwrap();

            assert_same_json!(val1.as_ref(), val2.as_ref());
            assert_same_json!(val1.as_ref(), r);
        }
    }

    #[test]
    fn bad_requests() {
        for text in [
            // not an object.
            "123",
            // missing most parts.
            r#"{"id":12}"#,
            // no id.
            r#"{"obj":"hi", "method":"twiddle", "params":{"stuff":"nonsense"}}"#,
            // no params
            r#"{"obj":"hi", "id": 7, "method":"twiddle"}"#,
            // bad params type
            r#"{"obj":"hi", "id": 7, "method":"twiddle", "params": []}"#,
            // weird obj.
            r#"{"obj":7, "id": 7, "method":"twiddle", "params":{"stuff":"nonsense"}}"#,
            // weird id.
            r#"{"obj":"hi", "id": [], "method":"twiddle", "params":{"stuff":"nonsense"}}"#,
            // weird method
            r#"{"obj":"hi", "id": 7, "method":6, "params":{"stuff":"nonsense"}}"#,
            // not valid JSON at all (trailing comma).
            r#"{"id":12,}"#,
        ] {
            let r: Result<ParsedRequestFields, _> = serde_json::from_str(dbg!(text));
            assert!(r.is_err());
            // Strict validation must reject these as well.
            assert!(ValidatedRequest::from_string_strict(text).is_err());
        }
    }

    #[test]
    fn fix_requests() {
        let no_id = r#"{"obj":"hi", "method":"twiddle", "params":{"stuff":"nonsense"}}"#;
        let validated = ValidatedRequest::from_string_loose(no_id, || 7.into()).unwrap();
        let expected_with_id =
            r#"{"id": 7, "obj":"hi", "method":"twiddle", "params":{"stuff":"nonsense"}}"#;
        assert_same_json!(validated.as_ref(), expected_with_id);
    }

    #[test]
    fn preserve_fields() {
        let orig = r#"
            {"obj":"hi",
             "meta": { "updates": true, "waffles": "yesplz" },
             "method":"twiddle",
             "params":{"stuff":"nonsense"},
             "explosions": -70
            }"#;
        let validated = ValidatedRequest::from_string_loose(orig, || 77.into()).unwrap();
        let expected_with_id = r#"
            {"id":77,
            "obj":"hi",
            "meta": { "updates": true, "waffles": "yesplz" },
            "method":"twiddle",
            "params":{"stuff":"nonsense"},
            "explosions": -70
            }"#;
        assert_same_json!(validated.as_ref(), expected_with_id);
    }

    #[test]
    fn ok_request_encode() {
        let expected_encoded_request =
            r#"{"obj":"connection","method":"arti:get_rpc_proxy_info","params":"123"}"#;
        let obj_id = ObjectId::connection_id();
        let encoded_request = Request::new(obj_id, "arti:get_rpc_proxy_info", "123")
            .encode()
            .unwrap();
        assert_eq!(expected_encoded_request, encoded_request);
    }

    // Encoding failures should not happen in practice, but make sure that
    // one is reported as `ProtoError::CouldNotEncode` if it does.
    #[test]
    fn err_request_encode() {
        struct FailingSerialization;

        impl serde::Serialize for FailingSerialization {
            fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                Err(serde::ser::Error::custom(
                    "Intentional serialization failure",
                ))
            }
        }

        let obj_id = ObjectId::connection_id();
        let failing_request = Request::new(obj_id, "arti:get_rpc_proxy_info", FailingSerialization);

        let err = failing_request.encode().unwrap_err();
        assert!(matches!(err, ProtoError::CouldNotEncode(_)));
    }
}