arti_rpc_client_core/
conn.rs

1//! Middle-level API for RPC connections
2//!
3//! This module focuses around the `RpcConn` type, which supports sending RPC requests
4//! and matching them with their responses.
5
6use std::{
7    io::{self},
8    sync::{Arc, Mutex},
9};
10
11use crate::msgs::{
12    request::InvalidRequestError,
13    response::{ResponseKind, RpcError, ValidatedResponse},
14    AnyRequestId, ObjectId,
15};
16
17mod auth;
18mod builder;
19mod connimpl;
20mod stream;
21
22use crate::util::Utf8CString;
23pub use builder::{BuilderError, ConnPtDescription, RpcConnBuilder};
24pub use connimpl::RpcConn;
25use serde::{de::DeserializeOwned, Deserialize};
26pub use stream::StreamError;
27use tor_rpc_connect::{auth::cookie::CookieAccessError, HasClientErrorAction};
28
/// A handle to an open request.
///
/// These handles are created with [`RpcConn::execute_with_handle`].
/// A handle lets you wait for the responses to a request (see the `wait*` methods)
/// and find out the request's ID (see `id`).
///
/// Note that dropping a `RequestHandle` does not cancel the associated request:
/// it will continue running, but you won't have a way to receive updates from it.
/// To cancel a request, use [`RpcConn::cancel`].
#[derive(educe::Educe)]
#[educe(Debug)]
pub struct RequestHandle {
    /// The underlying `Receiver` that we'll use to get updates for this request.
    ///
    /// It's wrapped in a `Mutex` to prevent concurrent calls to `Receiver::wait_on_message_for`.
    ///
    /// (This field is omitted from the `Debug` output.)
    //
    // NOTE: As an alternative to using a Mutex here, we _could_ remove
    // the restriction from `wait_on_message_for` that says that only one thread
    // may be waiting on a given request ID at once.  But that would introduce
    // complexity to the implementation,
    // and it's not clear that the benefit would be worth it.
    #[educe(Debug(ignore))]
    conn: Mutex<Arc<connimpl::Receiver>>,
    /// The ID of this request.
    ///
    /// This is the ID we pass to the receiver when waiting for responses.
    id: AnyRequestId,
}
53
54// TODO RPC: Possibly abolish these types.
55//
56// I am keeping this for now because it makes it more clear that we can never reinterpret
57// a success as an update or similar.
58//
59// I am not at all pleased with these types; we should revise them.
60//
61// TODO RPC: Possibly, all of these should be reconstructed
62// from their serde_json::Values rather than forwarded verbatim.
// (But why would we want our JSON to be more canonical than Arti's? See #1491.)
64//
65// DODGY TYPES BEGIN: TODO RPC
66
/// A Success Response from Arti, indicating that a request was successful.
///
/// This is the complete message, including `id` and `result` fields.
///
/// The message text can be borrowed via `AsRef`,
/// or extracted by converting this object with `Into`.
//
// Invariant: it is valid JSON and contains no NUL bytes or newlines.
// TODO RPC: check that the newline invariant is enforced in constructors.
#[derive(Clone, Debug, derive_more::AsRef, derive_more::Into)]
#[as_ref(forward)]
pub struct SuccessResponse(Utf8CString);
76
77impl SuccessResponse {
78    /// Helper: Decode the `result` field of this response as an instance of D.
79    fn decode<D: DeserializeOwned>(&self) -> Result<D, serde_json::Error> {
80        /// Helper object for decoding the "result" field.
81        #[derive(Deserialize)]
82        struct Response<R> {
83            /// The decoded value.
84            result: R,
85        }
86        let response: Response<D> = serde_json::from_str(self.as_ref())?;
87        Ok(response.result)
88    }
89}
90
/// An Update Response from Arti, with information about the progress of a request.
///
/// This is the complete message, including `id` and `update` fields.
///
/// The message text can be borrowed via `AsRef`,
/// or extracted by converting this object with `Into`.
//
// Invariant: it is valid JSON and contains no NUL bytes or newlines.
// TODO RPC: check that the newline invariant is enforced in constructors.
// TODO RPC consider changing this to CString.
#[derive(Clone, Debug, derive_more::AsRef, derive_more::Into)]
#[as_ref(forward)]
pub struct UpdateResponse(Utf8CString);
101
/// An Error Response from Arti, indicating that an error occurred.
///
/// (This is the complete message, including the `error` field.
/// It also contains an `id` if it is in response to a request;
/// but not if it is a fatal protocol error.)
//
// Invariant: Does not contain a NUL. (Safe to convert to CString.)
//
// Invariant: This field MUST encode a response whose body is an RPC error.
//
// Otherwise the `decode` method may panic.
//
// TODO RPC: check that the newline invariant is enforced in constructors.
#[derive(Clone, Debug, derive_more::AsRef, derive_more::Into)]
#[as_ref(forward)]
// TODO: If we keep this, it should implement Error.
pub struct ErrorResponse(Utf8CString);
119impl ErrorResponse {
120    /// Construct an ErrorResponse from the Error reply.
121    ///
122    /// This not a From impl because we want it to be crate-internal.
123    pub(crate) fn from_validated_string(s: Utf8CString) -> Self {
124        ErrorResponse(s)
125    }
126
127    /// Convert this response into an internal error in response to `cmd`.
128    ///
129    /// This is only appropriate when the error cannot be caused because of user behavior.
130    pub(crate) fn internal_error(&self, cmd: &str) -> ProtoError {
131        ProtoError::InternalRequestFailed(UnexpectedReply {
132            request: cmd.to_string(),
133            reply: self.to_string(),
134            problem: UnexpectedReplyProblem::ErrorNotExpected,
135        })
136    }
137
138    /// Try to interpret this response as an [`RpcError`].
139    pub fn decode(&self) -> RpcError {
140        crate::msgs::response::try_decode_response_as_err(self.0.as_ref())
141            .expect("Could not decode response that was already decoded as an error?")
142            .expect("Could not extract error from response that was already decoded as an error?")
143    }
144}
145
146impl std::fmt::Display for ErrorResponse {
147    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148        let e = self.decode();
149        write!(f, "Peer said {:?}", e.message())
150    }
151}
152
/// A final response -- that is, the last one that we expect to receive for a request.
///
/// `Ok` holds a [`SuccessResponse`]; `Err` holds an [`ErrorResponse`].
type FinalResponse = Result<SuccessResponse, ErrorResponse>;
156
157/// Any of the three types of Arti responses.
158#[derive(Clone, Debug)]
159#[allow(clippy::exhaustive_structs)]
160pub enum AnyResponse {
161    /// The request has succeeded; no more response will be given.
162    Success(SuccessResponse),
163    /// The request has failed; no more response will be given.
164    Error(ErrorResponse),
165    /// An incremental update; more messages may arrive.
166    Update(UpdateResponse),
167}
168// TODO RPC: DODGY TYPES END.
169
170impl AnyResponse {
171    /// Convert `v` into `AnyResponse`.
172    fn from_validated(v: ValidatedResponse) -> Self {
173        // TODO RPC, Perhaps unify AnyResponse with ValidatedResponse, once we are sure what
174        // AnyResponse should look like.
175        match v.meta.kind {
176            ResponseKind::Error => AnyResponse::Error(ErrorResponse::from_validated_string(v.msg)),
177            ResponseKind::Success => AnyResponse::Success(SuccessResponse(v.msg)),
178            ResponseKind::Update => AnyResponse::Update(UpdateResponse(v.msg)),
179        }
180    }
181
182    /// Consume this `AnyResponse`, and return its internal string.
183    #[cfg(feature = "ffi")]
184    pub(crate) fn into_string(self) -> Utf8CString {
185        match self {
186            AnyResponse::Success(m) => m.into(),
187            AnyResponse::Error(m) => m.into(),
188            AnyResponse::Update(m) => m.into(),
189        }
190    }
191}
192
193impl RpcConn {
194    /// Return the ObjectId for the negotiated Session.
195    ///
196    /// Nearly all RPC methods require a Session, or some other object
197    /// accessed via the session.
198    ///
199    /// (This function will only return None if no authentication has been performed.
200    /// TODO RPC: It is not currently possible to make an unauthenticated connection.)
201    pub fn session(&self) -> Option<&ObjectId> {
202        self.session.as_ref()
203    }
204
205    /// Run a command, and wait for success or failure.
206    ///
207    /// Note that this function will return `Err(.)` only if sending the command or getting a
208    /// response failed.  If the command was sent successfully, and Arti reported an error in response,
209    /// this function returns `Ok(Err(.))`.
210    ///
211    /// Note that the command does not need to include an `id` field.  If you omit it,
212    /// one will be generated.
213    pub fn execute(&self, cmd: &str) -> Result<FinalResponse, ProtoError> {
214        let hnd = self.execute_with_handle(cmd)?;
215        hnd.wait()
216    }
217
218    /// Helper for executing internally-generated requests and decoding their results.
219    ///
220    /// Behaves like `execute`, except on success, where it tries to decode the `result` field
221    /// of the response as a `T`.
222    ///
223    /// Use this method in cases where it's reasonable for Arti to sometimes return an RPC error:
224    /// in other words, where it's not necessarily a programming error or version mismatch.
225    ///
226    /// Don't use this for user-generated requests: it will misreport unexpected replies
227    /// as internal errors.
228    pub(crate) fn execute_internal<T: DeserializeOwned>(
229        &self,
230        cmd: &str,
231    ) -> Result<Result<T, ErrorResponse>, ProtoError> {
232        match self.execute(cmd)? {
233            Ok(success) => match success.decode::<T>() {
234                Ok(result) => Ok(Ok(result)),
235                Err(json_error) => Err(ProtoError::InternalRequestFailed(UnexpectedReply {
236                    request: cmd.to_string(),
237                    reply: Utf8CString::from(success).to_string(),
238                    problem: UnexpectedReplyProblem::CannotDecode(Arc::new(json_error)),
239                })),
240            },
241            Err(error) => Ok(Err(error)),
242        }
243    }
244
245    /// Helper for executing internally-generated requests and decoding their results.
246    ///
247    /// Behaves like `execute_internal`, except that it treats any RPC error reply
248    /// as an internal error or version mismatch.
249    ///
250    /// Don't use this for user-generated requests, or for requests that can fail because of
251    /// incorrect user inputs: it will misreport failures in those requests as internal errors.
252    pub(crate) fn execute_internal_ok<T: DeserializeOwned>(
253        &self,
254        cmd: &str,
255    ) -> Result<T, ProtoError> {
256        match self.execute_internal(cmd)? {
257            Ok(v) => Ok(v),
258            Err(err_response) => Err(err_response.internal_error(cmd)),
259        }
260    }
261
262    /// Cancel a request by ID.
263    pub fn cancel(&self, request_id: &AnyRequestId) -> Result<(), ProtoError> {
264        /// Arguments to an `rpc::cancel` request.
265        #[derive(serde::Serialize, Debug)]
266        struct CancelParams<'a> {
267            /// The request to cancel.
268            request_id: &'a AnyRequestId,
269        }
270
271        let request = crate::msgs::request::Request::new(
272            ObjectId::connection_id(),
273            "rpc:cancel",
274            CancelParams { request_id },
275        );
276        match self.execute_internal::<EmptyReply>(&request.encode()?)? {
277            Ok(EmptyReply {}) => Ok(()),
278            Err(_) => Err(ProtoError::RequestCompleted),
279        }
280    }
281
282    /// Like `execute`, but don't wait.  This lets the caller see the
283    /// request ID and  maybe cancel it.
284    pub fn execute_with_handle(&self, cmd: &str) -> Result<RequestHandle, ProtoError> {
285        self.send_request(cmd)
286    }
287    /// As execute(), but run update_cb for every update we receive.
288    pub fn execute_with_updates<F>(
289        &self,
290        cmd: &str,
291        mut update_cb: F,
292    ) -> Result<FinalResponse, ProtoError>
293    where
294        F: FnMut(UpdateResponse) + Send + Sync,
295    {
296        let hnd = self.execute_with_handle(cmd)?;
297        loop {
298            match hnd.wait_with_updates()? {
299                AnyResponse::Success(s) => return Ok(Ok(s)),
300                AnyResponse::Error(e) => return Ok(Err(e)),
301                AnyResponse::Update(u) => update_cb(u),
302            }
303        }
304    }
305
306    /// Helper: Tell Arti to release `obj`.
307    ///
308    /// Do not use this method for a user-provided object ID:
309    /// It gives an internal error if the object does not exist.
310    pub(crate) fn release_obj(&self, obj: ObjectId) -> Result<(), ProtoError> {
311        let release_request = crate::msgs::request::Request::new(obj, "rpc:release", NoParams {});
312        let _empty_response: EmptyReply = self.execute_internal_ok(&release_request.encode()?)?;
313        Ok(())
314    }
315
316    // TODO RPC: shutdown() on the socket on Drop.
317}
318
319impl RequestHandle {
320    /// Return the ID of this request, to help cancelling it.
321    pub fn id(&self) -> &AnyRequestId {
322        &self.id
323    }
324    /// Wait for success or failure, and return what happened.
325    ///
326    /// (Ignores any update messages that are received.)
327    ///
328    /// Note that this function will return `Err(.)` only if sending the command or getting a
329    /// response failed.  If the command was sent successfully, and Arti reported an error in response,
330    /// this function returns `Ok(Err(.))`.
331    pub fn wait(self) -> Result<FinalResponse, ProtoError> {
332        loop {
333            match self.wait_with_updates()? {
334                AnyResponse::Success(s) => return Ok(Ok(s)),
335                AnyResponse::Error(e) => return Ok(Err(e)),
336                AnyResponse::Update(_) => {}
337            }
338        }
339    }
340    /// Wait for the next success, failure, or update from this handle.
341    ///
342    /// Note that this function will return `Err(.)` only if sending the command or getting a
343    /// response failed.  If the command was sent successfully, and Arti reported an error in response,
344    /// this function returns `Ok(AnyResponse::Error(.))`.
345    ///
346    /// You may call this method on the same `RequestHandle` from multiple threads.
347    /// If you do so, those calls will receive responses (or errors) in an unspecified order.
348    ///
349    /// If this function returns Success or Error, then you shouldn't call it again.
350    /// All future calls to this function will fail with `CmdError::RequestCancelled`.
351    /// (TODO RPC: Maybe rename that error.)
352    pub fn wait_with_updates(&self) -> Result<AnyResponse, ProtoError> {
353        let conn = self.conn.lock().expect("Poisoned lock");
354        let validated = conn.wait_on_message_for(&self.id)?;
355
356        Ok(AnyResponse::from_validated(validated))
357    }
358
359    // TODO RPC: Sketch out how we would want to do this in an async world,
360    // or with poll
361}
362
/// An error (or other condition) that has caused an RPC connection to shut down.
///
/// The I/O errors are held in `Arc`s, which lets this type derive `Clone`.
#[derive(Clone, Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ShutdownError {
    /// An I/O error occurred while reading.
    #[error("Unable to read response")]
    Read(#[source] Arc<io::Error>),
    /// An I/O error occurred while writing.
    #[error("Unable to write request")]
    Write(#[source] Arc<io::Error>),
    /// Something was wrong with Arti's responses; this is a protocol violation.
    ///
    /// The string describes the problem.
    #[error("Arti sent a message that didn't conform to the RPC protocol: {0:?}")]
    ProtocolViolated(String),
    /// Arti has told us that we violated the protocol somehow.
    #[error("Arti reported a fatal error: {0:?}")]
    ProtocolViolationReport(ErrorResponse),
    /// The underlying connection closed.
    ///
    /// This probably means that Arti has shut down.
    #[error("Connection closed")]
    ConnectionClosed,
}
385
386impl From<crate::msgs::response::DecodeResponseError> for ShutdownError {
387    fn from(value: crate::msgs::response::DecodeResponseError) -> Self {
388        use crate::msgs::response::DecodeResponseError::*;
389        use ShutdownError as E;
390        match value {
391            JsonProtocolViolation(e) => E::ProtocolViolated(e.to_string()),
392            ProtocolViolation(s) => E::ProtocolViolated(s.to_string()),
393            Fatal(rpc_err) => E::ProtocolViolationReport(rpc_err),
394        }
395    }
396}
397
/// An error that has occurred while launching an RPC command.
///
/// This is the error type returned by the request-related methods of [`RpcConn`]
/// and [`RequestHandle`].
#[derive(Clone, Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ProtoError {
    /// The RPC connection failed, or was closed by the other side.
    #[error("RPC connection is shut down")]
    Shutdown(#[from] ShutdownError),

    /// There was a problem in the request we tried to send.
    #[error("Invalid request")]
    InvalidRequest(#[from] InvalidRequestError),

    /// We tried to send a request with an ID that was already pending.
    #[error("Request ID already in use.")]
    RequestIdInUse,

    /// We tried to wait for or inspect a request that had already succeeded or failed.
    ///
    /// ([`RpcConn::cancel`] also reports this when Arti replies to the
    /// cancellation request with an error.)
    #[error("Request has already completed (or failed)")]
    RequestCompleted,

    /// We tried to wait for the same request more than once.
    ///
    /// (This should be impossible.)
    #[error("Internal error: waiting on the same request more than once at a time.")]
    DuplicateWait,

    /// We got an internal error while trying to encode an RPC request.
    ///
    /// (This should be impossible.)
    #[error("Internal error while encoding request")]
    CouldNotEncode(#[source] Arc<serde_json::Error>),

    /// We got a response to some internally generated request that wasn't what we expected.
    #[error("{0}")]
    InternalRequestFailed(#[source] UnexpectedReply),
}
434
/// A set of errors encountered while trying to connect to the Arti process.
///
/// The `Display` implementation for this type is written by hand (it is not derived),
/// and gives only a summary; see [`ConnectFailure::display_verbose`]
/// for a formatter that lists every failed attempt.
#[derive(Clone, Debug, thiserror::Error)]
pub struct ConnectFailure {
    /// A list of all the declined connect points we encountered, and how they failed.
    declined: Vec<(builder::ConnPtDescription, ConnectError)>,
    /// A description of where we found the final error (if it's an abort.)
    final_desc: Option<builder::ConnPtDescription>,
    /// The final error explaining why we couldn't connect.
    ///
    /// This is either an abort, an AllAttemptsDeclined, or an error that prevented the
    /// search process from even beginning.
    #[source]
    pub(crate) final_error: ConnectError,
}
449
450impl std::fmt::Display for ConnectFailure {
451    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
452        write!(f, "Unable to connect")?;
453        if !self.declined.is_empty() {
454            write!(
455                f,
456                " ({} attempts failed{})",
457                self.declined.len(),
458                if matches!(self.final_error, ConnectError::AllAttemptsDeclined) {
459                    ""
460                } else {
461                    " before fatal error"
462                }
463            )?;
464        }
465        Ok(())
466    }
467}
468
469impl ConnectFailure {
470    /// If this attempt failed because of a fatal error that made a connect point attempt abort,
471    /// return a description of the origin of that connect point.
472    pub fn fatal_error_origin(&self) -> Option<&builder::ConnPtDescription> {
473        self.final_desc.as_ref()
474    }
475
476    /// For each connect attempt that failed nonfatally, return a description of the
477    /// origin of that connect point, and the error that caused it to fail.
478    pub fn declined_attempt_outcomes(
479        &self,
480    ) -> impl Iterator<Item = (&builder::ConnPtDescription, &ConnectError)> {
481        // Note: this map looks like a no-op, but isn't.
482        self.declined.iter().map(|(a, b)| (a, b))
483    }
484
485    /// Return a helper type to format this error, and all of its internal errors recursively.
486    ///
487    /// Unlike [`tor_error::Report`], this method includes not only fatal errors, but also
488    /// information about connect attempts that failed nonfatally.
489    pub fn display_verbose(&self) -> ConnectFailureVerboseFmt<'_> {
490        ConnectFailureVerboseFmt(self)
491    }
492}
493
/// Helper type to format a ConnectFailure along with all of its internal errors,
/// including non-fatal errors.
///
/// Returned by [`ConnectFailure::display_verbose`]; use it via its `Display` implementation.
#[derive(Debug, Clone)]
pub struct ConnectFailureVerboseFmt<'a>(&'a ConnectFailure);
498
499impl<'a> std::fmt::Display for ConnectFailureVerboseFmt<'a> {
500    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
501        use tor_error::ErrorReport as _;
502        writeln!(f, "{}:", self.0)?;
503        for (idx, (origin, error)) in self.0.declined_attempt_outcomes().enumerate() {
504            writeln!(f, "  {}. {}: {}", idx + 1, origin, error.report())?;
505        }
506        if let Some(origin) = self.0.fatal_error_origin() {
507            writeln!(
508                f,
509                "  {}. [FATAL] {}: {}",
510                self.0.declined.len() + 1,
511                origin,
512                self.0.final_error.report()
513            )?;
514        } else {
515            writeln!(f, "  - {}", self.0.final_error.report())?;
516        }
517        Ok(())
518    }
519}
520
/// An error while trying to connect to the Arti process.
///
/// Each variant is classified as "decline" or "abort" by this type's
/// [`HasClientErrorAction`] implementation.
#[derive(Clone, Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ConnectError {
    /// Unable to parse connect points from an environment variable.
    #[error("Cannot parse connect points from environment variable")]
    BadEnvironment,
    /// We were unable to load and/or parse a given connect point.
    #[error("Unable to load and parse connect point")]
    CannotParse(#[from] tor_rpc_connect::load::LoadError),
    /// The path used to specify a connect file couldn't be resolved.
    #[error("Unable to resolve connect point path")]
    CannotResolvePath(#[source] tor_config_path::CfgPathError),
    /// A parsed connect point couldn't be resolved.
    #[error("Unable to resolve connect point")]
    CannotResolveConnectPoint(#[from] tor_rpc_connect::ResolveError),
    /// IO error while connecting to Arti.
    #[error("Unable to make a connection")]
    CannotConnect(#[from] tor_rpc_connect::ConnectError),
    /// Opened a connection, but didn't get a banner message.
    ///
    /// (This isn't a `BadMessage`, since it is likelier to represent something that isn't
    /// pretending to be Arti at all than it is to be a malfunctioning Arti.)
    #[error("Did not receive expected banner message upon connecting")]
    InvalidBanner,
    /// All attempted connect points were declined, and none were aborted.
    #[error("All connect points were declined (or there were none)")]
    AllAttemptsDeclined,
    /// A connect file or directory was given as a relative path.
    /// (Only absolute paths are supported).
    #[error("Connect file was given as a relative path.")]
    RelativeConnectFile,
    /// One of our authentication messages received an error.
    #[error("Received an error while trying to authenticate: {0}")]
    AuthenticationFailed(ErrorResponse),
    /// The connect point uses an RPC authentication type we don't support.
    #[error("Authentication type is not supported")]
    AuthenticationNotSupported,
    /// We couldn't decode one of the responses we got.
    #[error("Message not in expected format")]
    BadMessage(#[source] Arc<serde_json::Error>),
    /// A protocol error occurred during negotiations.
    #[error("Error while negotiating with Arti")]
    ProtoError(#[from] ProtoError),
    /// The server thinks it is listening on an address where we don't expect to find it.
    /// This can be misconfiguration or an attempted MITM attack.
    #[error("We connected to the server at {ours}, but it thinks it's listening at {theirs}")]
    ServerAddressMismatch {
        /// The address we think the server has.
        ours: String,
        /// The address that the server says it has.
        theirs: String,
    },
    /// The server tried to prove knowledge of a cookie file, but its proof was incorrect.
    #[error("Server's cookie MAC was not as expected.")]
    CookieMismatch,
    /// We were unable to access the configured cookie file.
    #[error("Unable to load secret cookie value")]
    LoadCookie(#[from] CookieAccessError),
}
581
582impl HasClientErrorAction for ConnectError {
583    fn client_action(&self) -> tor_rpc_connect::ClientErrorAction {
584        use tor_rpc_connect::ClientErrorAction as A;
585        use ConnectError as E;
586        match self {
587            E::BadEnvironment => A::Abort,
588            E::CannotParse(e) => e.client_action(),
589            E::CannotResolvePath(_) => A::Abort,
590            E::CannotResolveConnectPoint(e) => e.client_action(),
591            E::CannotConnect(e) => e.client_action(),
592            E::InvalidBanner => A::Decline,
593            E::RelativeConnectFile => A::Abort,
594            E::AuthenticationFailed(_) => A::Decline,
595            // TODO RPC: Is this correct?  This error can also occur when
596            // we are talking to something other than an RPC server.
597            E::BadMessage(_) => A::Abort,
598            E::ProtoError(e) => e.client_action(),
599            E::AllAttemptsDeclined => A::Abort,
600            E::AuthenticationNotSupported => A::Decline,
601            E::ServerAddressMismatch { .. } => A::Abort,
602            E::CookieMismatch => A::Abort,
603            E::LoadCookie(e) => e.client_action(),
604        }
605    }
606}
607
608impl HasClientErrorAction for ProtoError {
609    fn client_action(&self) -> tor_rpc_connect::ClientErrorAction {
610        use tor_rpc_connect::ClientErrorAction as A;
611        use ProtoError as E;
612        match self {
613            E::Shutdown(_) => A::Decline,
614            E::InternalRequestFailed(_) => A::Decline,
615            // These are always internal errors if they occur while negotiating a connection to RPC,
616            // which is the context we care about for `HasClientErrorAction`.
617            E::InvalidRequest(_)
618            | E::RequestIdInUse
619            | E::RequestCompleted
620            | E::DuplicateWait
621            | E::CouldNotEncode(_) => A::Abort,
622        }
623    }
624}
625
/// In response to a request that we generated internally,
/// Arti gave a reply that we did not understand.
///
/// This could be due to a bug in this library, a bug in Arti,
/// or a compatibility issue between the two.
///
/// The `Display` output includes both the request and the reply;
/// the underlying `problem` is exposed as the error's source.
#[derive(Clone, Debug, thiserror::Error)]
#[error("In response to our request {request:?}, Arti gave the unexpected reply {reply:?}")]
pub struct UnexpectedReply {
    /// The request we sent.
    request: String,
    /// The response we got.
    reply: String,
    /// What was wrong with the response.
    #[source]
    problem: UnexpectedReplyProblem,
}
642
643/// Underlying reason for an UnexpectedReply
644#[derive(Clone, Debug, thiserror::Error)]
645enum UnexpectedReplyProblem {
646    /// There was a json failure while trying to decode the response:
647    /// the result type was not what we expected.
648    #[error("Cannot decode as correct JSON type")]
649    CannotDecode(Arc<serde_json::Error>),
650    /// Arti replied with an RPC error in a context no error should have been possible.
651    #[error("Unexpected error")]
652    ErrorNotExpected,
653}
654
/// Arguments to a request that takes no parameters.
///
/// (Used as the `params` of requests such as `rpc:release`.)
#[derive(serde::Serialize, Debug)]
struct NoParams {}
658
/// A reply with no data.
///
/// (Used to decode the `result` of requests such as `rpc:cancel` and `rpc:release`.)
#[derive(serde::Deserialize, Debug)]
struct EmptyReply {}
662
663#[cfg(test)]
664mod test {
665    // @@ begin test lint list maintained by maint/add_warning @@
666    #![allow(clippy::bool_assert_comparison)]
667    #![allow(clippy::clone_on_copy)]
668    #![allow(clippy::dbg_macro)]
669    #![allow(clippy::mixed_attributes_style)]
670    #![allow(clippy::print_stderr)]
671    #![allow(clippy::print_stdout)]
672    #![allow(clippy::single_char_pattern)]
673    #![allow(clippy::unwrap_used)]
674    #![allow(clippy::unchecked_duration_subtraction)]
675    #![allow(clippy::useless_vec)]
676    #![allow(clippy::needless_pass_by_value)]
677    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
678
679    use std::{sync::atomic::AtomicUsize, thread, time::Duration};
680
681    use io::{BufRead as _, BufReader, Write as _};
682    use rand::{seq::SliceRandom as _, Rng as _, SeedableRng as _};
683    use tor_basic_utils::{test_rng::testing_rng, RngExt as _};
684
685    use crate::{
686        llconn,
687        msgs::request::{JsonMap, Request, ValidatedRequest},
688    };
689
690    use super::*;
691
692    /// helper: Return a dummy RpcConn, along with a socketpair for it to talk to.
693    fn dummy_connected() -> (RpcConn, crate::testing::SocketpairStream) {
694        let (s1, s2) = crate::testing::construct_socketpair().unwrap();
695        let s1_w = s1.try_clone().unwrap();
696        let s1_r = io::BufReader::new(s1);
697        let conn = RpcConn::new(llconn::Reader::new(s1_r), llconn::Writer::new(s1_w));
698
699        (conn, s2)
700    }
701
702    fn write_val(w: &mut impl io::Write, v: &serde_json::Value) {
703        let mut enc = serde_json::to_string(v).unwrap();
704        enc.push('\n');
705        w.write_all(enc.as_bytes()).unwrap();
706    }
707
    // Smoke test: send a single request over a dummy connection, have a fake
    // "Arti" thread answer it, and check that the decoded result is what was sent.
    #[test]
    fn simple() {
        let (conn, sock) = dummy_connected();

        // The "user" side: send one request and block until its result is decoded.
        let user_thread = thread::spawn(move || {
            let response1 = conn
                .execute_internal_ok::<JsonMap>(
                    r#"{"obj":"fred","method":"arti:x-frob","params":{}}"#,
                )
                .unwrap();
            // Return the connection too, so that it isn't dropped inside this thread.
            (response1, conn)
        });

        // The fake "Arti" side: read one request line, and reply with a fixed
        // result that carries the same request ID.
        let fake_arti_thread = thread::spawn(move || {
            let mut sock = BufReader::new(sock);
            let mut s = String::new();
            let _len = sock.read_line(&mut s).unwrap();
            let request = ValidatedRequest::from_string_strict(s.as_ref()).unwrap();
            let response = serde_json::json!({
                "id": request.id().clone(),
                "result": { "xyz" : 3 }
            });
            write_val(sock.get_mut(), &response);
            sock // prevent close
        });

        // Keep both ends alive until both threads have finished.
        let _sock = fake_arti_thread.join().unwrap();
        let (map, _conn) = user_thread.join().unwrap();
        assert_eq!(map.get("xyz"), Some(&serde_json::Value::Number(3.into())));
    }
738
#[test]
fn complex() {
    // Stress test: many user threads issue echo requests over one connection, while a
    // single fake-Arti worker answers them in shuffled order.  Every request must get
    // exactly the updates and the final response that it asked for.
    use std::sync::atomic::Ordering::SeqCst;

    let n_threads = 16;
    let n_commands_per_thread = 128;
    let n_commands_total = n_threads * n_commands_per_thread;
    let n_completed = Arc::new(AtomicUsize::new(0));

    let (conn, sock) = dummy_connected();
    let conn = Arc::new(conn);
    let mut rng = testing_rng();

    // -------
    // User threads: each one runs its commands strictly in sequence.
    let user_threads: Vec<_> = (0..n_threads)
        .map(|th_idx| {
            let conn = Arc::clone(&conn);
            let n_completed = Arc::clone(&n_completed);
            // Every thread gets its own generator, seeded from the test-wide one.
            let mut thread_rng = rand_chacha::ChaCha12Rng::from_seed(rng.gen());
            thread::spawn(move || {
                for cmd_idx in 0..n_commands_per_thread {
                    // The payload is unique per (thread, command), so a reply that was
                    // routed to the wrong request is detected below.
                    let tag = format!("{}:{}", th_idx, cmd_idx);
                    let want_updates: bool = thread_rng.gen();
                    let want_failure: bool = thread_rng.gen();
                    let req = serde_json::to_string(&serde_json::json!({
                        "obj":"fred",
                        "method":"arti:x-echo",
                        "meta": {
                            "updates": want_updates,
                        },
                        "params": {
                            "val": &tag,
                            "fail": want_failure,
                        },
                    }))
                    .unwrap();

                    // Block until the final response arrives, counting updates on the way.
                    let mut n_updates = 0;
                    let outcome = conn
                        .execute_with_updates(&req, |_update| {
                            n_updates += 1;
                        })
                        .unwrap();
                    // We must see updates if and only if we asked for them.
                    assert_eq!(n_updates > 0, want_updates);

                    // Check that the final response is the kind we requested.
                    if want_failure {
                        let e = outcome.unwrap_err().decode();
                        assert_eq!(e.message(), "You asked me to fail");
                        assert_eq!(i32::from(e.code()), 33);
                        assert_eq!(
                            e.kinds_iter().collect::<Vec<_>>(),
                            vec!["Example".to_string()]
                        );
                    } else {
                        let reply = outcome.unwrap();
                        let fields = reply.decode::<JsonMap>().unwrap();
                        assert_eq!(
                            fields.get("echo"),
                            Some(&serde_json::Value::String(tag))
                        );
                    }
                    n_completed.fetch_add(1, SeqCst);

                    // Pause now and then, to vary how the threads interleave.
                    if thread_rng.gen::<f32>() < 0.02 {
                        thread::sleep(Duration::from_millis(3));
                    }
                }
            })
        })
        .collect();

    /// Parameters of the `arti:x-echo` requests sent by the user threads.
    #[derive(serde::Deserialize, Debug)]
    struct Echo {
        /// The string to echo back.
        val: String,
        /// If true, answer with an error rather than a result.
        fail: bool,
    }

    // -----
    // Worker thread: plays the role of Arti and answers the user requests.
    let worker_rng = rand_chacha::ChaCha12Rng::from_seed(rng.gen());
    let worker_thread = thread::spawn(move || {
        let mut rng = worker_rng;
        let mut sock = BufReader::new(sock);
        let mut pending: Vec<Request<Echo>> = Vec::new();
        let mut n_received = 0;

        // Number of requests to buffer before shuffling them and answering out of order.
        let scramble_factor = 7;
        // Once this many requests have been received, stop batching and answer each
        // request as it arrives.
        //
        // (Our shuffling algorithm can deadlock us otherwise.)
        let scramble_threshold =
            n_commands_total - (n_commands_per_thread + 1) * scramble_factor;

        'outer: loop {
            let flush_pending_at = if n_received < scramble_threshold {
                scramble_factor
            } else {
                1
            };

            // Fill "pending" up to the current batch size; a zero-length read ends the test.
            while pending.len() < flush_pending_at {
                let mut line = String::new();
                let n_read = sock.read_line(&mut line).unwrap();
                if n_read == 0 {
                    break 'outer;
                }
                n_received += 1;
                pending.push(serde_json::from_str::<Request<Echo>>(&line).unwrap());
            }

            // Answer the batch in random order, leaving "pending" empty for the next round.
            pending.shuffle(&mut rng);
            for req in pending.drain(..) {
                let wants_updates = req.meta.unwrap_or_default().updates;
                if wants_updates {
                    let n_updates = rng.gen_range_checked(1..4).unwrap();
                    for _ in 0..n_updates {
                        let update = serde_json::json!({
                            "id": req.id.clone(),
                            "update": {
                                "hello": req.params.val.clone(),
                            }
                        });
                        write_val(sock.get_mut(), &update);
                    }
                }

                let reply = if req.params.fail {
                    serde_json::json!({
                        "id": req.id.clone(),
                        "error": { "message": "You asked me to fail", "code": 33, "kinds": ["Example"], "data": req.params.val },
                    })
                } else {
                    serde_json::json!({
                        "id": req.id.clone(),
                        "result": {
                            "echo": req.params.val
                        }
                    })
                };
                write_val(sock.get_mut(), &reply);
            }
        }
    });

    // Release our own handle on the connection; the user threads still hold theirs.
    drop(conn);
    user_threads
        .into_iter()
        .for_each(|handle| handle.join().unwrap());

    worker_thread.join().unwrap();

    assert_eq!(n_completed.load(SeqCst), n_commands_total);
}
897
#[test]
fn arti_socket_closed() {
    // Here we send a bunch of requests and then close the socket without answering them.
    //
    // Every request should get a ProtoError::Shutdown.
    let n_threads = 16;

    let (conn, sock) = dummy_connected();
    let conn = Arc::new(conn);
    let mut user_threads = Vec::new();
    for _ in 0..n_threads {
        let conn = Arc::clone(&conn);
        let th = thread::spawn(move || {
            // Each user thread sends a single request.  Nothing ever answers it:
            // the peer's end of the socket is simply dropped below.
            let req = serde_json::json!({
                "obj":"fred",
                "method":"arti:x-echo",
                "params":{}
            });
            let req = serde_json::to_string(&req).unwrap();
            let outcome = conn.execute(&req);

            // Any of these three shutdown reasons is acceptable; this test does not pin
            // down which one a given request observes.  On failure, report the actual
            // outcome so that the cause is visible in the test output.
            assert!(
                matches!(
                    &outcome,
                    Err(ProtoError::Shutdown(
                        ShutdownError::Write(_)
                            | ShutdownError::Read(_)
                            | ShutdownError::ConnectionClosed
                    ))
                ),
                "expected a shutdown error, got {:?}",
                outcome
            );
        });
        user_threads.push(th);
    }

    // Close the peer's end without replying to anything.
    drop(sock);

    for t in user_threads {
        t.join().unwrap();
    }
}
944
945    /// Send a bunch of requests and then send back a single reply.
946    ///
947    /// That reply should cause every request to get closed.
948    fn proto_err_with_msg<F>(msg: &str, outcome_ok: F)
949    where
950        F: Fn(ProtoError) -> bool,
951    {
952        let n_threads = 16;
953
954        let (conn, mut sock) = dummy_connected();
955        let conn = Arc::new(conn);
956        let mut user_threads = Vec::new();
957        for _ in 0..n_threads {
958            let conn = Arc::clone(&conn);
959            let th = thread::spawn(move || {
960                // We are spawning a bunch of worker threads, each of which will run a number of
961                // We will double-check that each request gets the response it asked for.
962                let req = serde_json::json!({
963                    "obj":"fred",
964                    "method":"arti:x-echo",
965                    "params":{}
966                });
967                let req = serde_json::to_string(&req).unwrap();
968                conn.execute(&req)
969            });
970            user_threads.push(th);
971        }
972
973        sock.write_all(msg.as_bytes()).unwrap();
974
975        for t in user_threads {
976            let outcome = t.join().unwrap();
977            assert!(outcome_ok(outcome.unwrap_err()));
978        }
979    }
980
#[test]
fn syntax_error() {
    // A line that is not JSON at all must fail every request with `ProtocolViolated`.
    let is_protocol_violation = |err: ProtoError| {
        matches!(
            err,
            ProtoError::Shutdown(ShutdownError::ProtocolViolated(_))
        )
    };
    proto_err_with_msg("this is not json\n", is_protocol_violation);
}
990
#[test]
fn fatal_error() {
    // An error object that carries no request "id" must fail every request with
    // `ProtocolViolationReport`.
    let report = serde_json::json!({
        "error":{ "message": "This test is doomed", "code": 413, "kinds": ["Example"], "data": {} },
    });
    // Messages are newline-terminated on the wire.
    let line = format!("{}\n", serde_json::to_string(&report).unwrap());

    proto_err_with_msg(&line, |err: ProtoError| {
        matches!(
            err,
            ProtoError::Shutdown(ShutdownError::ProtocolViolationReport(_))
        )
    });
}
1006}