1
//! Middle-level API for RPC connections
2
//!
3
//! This module is built around the `RpcConn` type, which supports sending RPC requests
4
//! and matching them with their responses.
5

            
6
use std::{
7
    io::{self},
8
    sync::{Arc, Mutex},
9
};
10

            
11
use crate::msgs::{
12
    request::InvalidRequestError,
13
    response::{ResponseKind, RpcError, ValidatedResponse},
14
    AnyRequestId, ObjectId,
15
};
16

            
17
mod auth;
18
mod builder;
19
mod connimpl;
20
mod stream;
21

            
22
use crate::util::Utf8CString;
23
pub use builder::{BuilderError, ConnPtDescription, RpcConnBuilder};
24
pub use connimpl::RpcConn;
25
use serde::{de::DeserializeOwned, Deserialize};
26
pub use stream::StreamError;
27
use tor_rpc_connect::{auth::cookie::CookieAccessError, HasClientErrorAction};
28

            
29
/// A handle to an open request.
30
///
31
/// These handles are created with [`RpcConn::execute_with_handle`].
32
///
33
/// Note that dropping a RequestHandle does not cancel the associated request:
34
/// it will continue running, but you won't have a way to receive updates from it.
35
/// To cancel a request, use [`RpcConn::cancel`].
36
#[derive(educe::Educe)]
37
#[educe(Debug)]
38
pub struct RequestHandle {
39
    /// The underlying `Receiver` that we'll use to get updates for this request
40
    ///
41
    /// It's wrapped in a `Mutex` to prevent concurrent calls to `Receiver::wait_on_message_for`.
42
    //
43
    // NOTE: As an alternative to using a Mutex here, we _could_ remove
44
    // the restriction from `wait_on_message_for` that says that only one thread
45
    // may be waiting on a given request ID at once.  But that would introduce
46
    // complexity to the implementation,
47
    // and it's not clear that the benefit would be worth it.
48
    #[educe(Debug(ignore))]
49
    conn: Mutex<Arc<connimpl::Receiver>>,
50
    /// The ID of this request.
51
    id: AnyRequestId,
52
}
53

            
54
// TODO RPC: Possibly abolish these types.
55
//
56
// I am keeping this for now because it makes it more clear that we can never reinterpret
57
// a success as an update or similar.
58
//
59
// I am not at all pleased with these types; we should revise them.
60
//
61
// TODO RPC: Possibly, all of these should be reconstructed
62
// from their serde_json::Values rather than forwarded verbatim.
63
// (But why would we want our JSON to be more canonical than Arti's? See #1491.)
64
//
65
// DODGY TYPES BEGIN: TODO RPC
66

            
67
/// A Success Response from Arti, indicating that a request was successful.
68
///
69
/// This is the complete message, including `id` and `result` fields.
70
//
71
// Invariant: it is valid JSON and contains no NUL bytes or newlines.
72
// TODO RPC: check that the newline invariant is enforced in constructors.
73
#[derive(Clone, Debug, derive_more::AsRef, derive_more::Into)]
74
#[as_ref(forward)]
75
pub struct SuccessResponse(Utf8CString);
76

            
77
impl SuccessResponse {
78
    /// Helper: Decode the `result` field of this response as an instance of D.
79
2010
    fn decode<D: DeserializeOwned>(&self) -> Result<D, serde_json::Error> {
80
        /// Helper object for decoding the "result" field.
81
        #[derive(Deserialize)]
82
        struct Response<R> {
83
            /// The decoded value.
84
            result: R,
85
        }
86
2010
        let response: Response<D> = serde_json::from_str(self.as_ref())?;
87
2010
        Ok(response.result)
88
2010
    }
89
}
90

            
91
/// An Update Response from Arti, with information about the progress of a request.
92
///
93
/// This is the complete message, including `id` and `update` fields.
94
//
95
// Invariant: it is valid JSON and contains no NUL bytes or newlines.
96
// TODO RPC: check that the newline invariant is enforced in constructors.
97
// TODO RPC consider changing this to CString.
98
#[derive(Clone, Debug, derive_more::AsRef, derive_more::Into)]
99
#[as_ref(forward)]
100
pub struct UpdateResponse(Utf8CString);
101

            
102
/// A Error Response from Arti, indicating that an error occurred.
103
///
104
/// (This is the complete message, including the `error` field.
105
/// It also an `id` if it
106
/// is in response to a request; but not if it is a fatal protocol error.)
107
//
108
// Invariant: Does not contain a NUL. (Safe to convert to CString.)
109
//
110
// Invariant: This field MUST encode a response whose body is an RPC error.
111
//
112
// Otherwise the `decode` method may panic.
113
//
114
// TODO RPC: check that the newline invariant is enforced in constructors.
115
#[derive(Clone, Debug, derive_more::AsRef, derive_more::Into)]
116
#[as_ref(forward)]
117
// TODO: If we keep this, it should implement Error.
118
pub struct ErrorResponse(Utf8CString);
119
impl ErrorResponse {
120
    /// Construct an ErrorResponse from the Error reply.
121
    ///
122
    /// This not a From impl because we want it to be crate-internal.
123
2092
    pub(crate) fn from_validated_string(s: Utf8CString) -> Self {
124
2092
        ErrorResponse(s)
125
2092
    }
126

            
127
    /// Convert this response into an internal error in response to `cmd`.
128
    ///
129
    /// This is only appropriate when the error cannot be caused because of user behavior.
130
    pub(crate) fn internal_error(&self, cmd: &str) -> ProtoError {
131
        ProtoError::InternalRequestFailed(UnexpectedReply {
132
            request: cmd.to_string(),
133
            reply: self.to_string(),
134
            problem: UnexpectedReplyProblem::ErrorNotExpected,
135
        })
136
    }
137

            
138
    /// Try to interpret this response as an [`RpcError`].
139
2088
    pub fn decode(&self) -> RpcError {
140
2088
        crate::msgs::response::try_decode_response_as_err(self.0.as_ref())
141
2088
            .expect("Could not decode response that was already decoded as an error?")
142
2088
            .expect("Could not extract error from response that was already decoded as an error?")
143
2088
    }
144
}
145

            
146
impl std::fmt::Display for ErrorResponse {
147
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148
        let e = self.decode();
149
        write!(f, "Peer said {:?}", e.message())
150
    }
151
}
152

            
153
/// A final response -- that is, the last one that we expect to receive for a request.
154
///
155
type FinalResponse = Result<SuccessResponse, ErrorResponse>;
156

            
157
/// Any of the three types of Arti responses.
158
#[derive(Clone, Debug)]
159
#[allow(clippy::exhaustive_structs)]
160
pub enum AnyResponse {
161
    /// The request has succeeded; no more response will be given.
162
    Success(SuccessResponse),
163
    /// The request has failed; no more response will be given.
164
    Error(ErrorResponse),
165
    /// An incremental update; more messages may arrive.
166
    Update(UpdateResponse),
167
}
168
// TODO RPC: DODGY TYPES END.
169

            
170
impl AnyResponse {
171
    /// Convert `v` into `AnyResponse`.
172
8080
    fn from_validated(v: ValidatedResponse) -> Self {
173
8080
        // TODO RPC, Perhaps unify AnyResponse with ValidatedResponse, once we are sure what
174
8080
        // AnyResponse should look like.
175
8080
        match v.meta.kind {
176
2088
            ResponseKind::Error => AnyResponse::Error(ErrorResponse::from_validated_string(v.msg)),
177
2010
            ResponseKind::Success => AnyResponse::Success(SuccessResponse(v.msg)),
178
3982
            ResponseKind::Update => AnyResponse::Update(UpdateResponse(v.msg)),
179
        }
180
8080
    }
181

            
182
    /// Consume this `AnyResponse`, and return its internal string.
183
    #[cfg(feature = "ffi")]
184
    pub(crate) fn into_string(self) -> Utf8CString {
185
        match self {
186
            AnyResponse::Success(m) => m.into(),
187
            AnyResponse::Error(m) => m.into(),
188
            AnyResponse::Update(m) => m.into(),
189
        }
190
    }
191
}
192

            
193
impl RpcConn {
194
    /// Return the ObjectId for the negotiated Session.
195
    ///
196
    /// Nearly all RPC methods require a Session, or some other object
197
    /// accessed via the session.
198
    ///
199
    /// (This function will only return None if no authentication has been performed.
200
    /// TODO RPC: It is not currently possible to make an unauthenticated connection.)
201
    pub fn session(&self) -> Option<&ObjectId> {
202
        self.session.as_ref()
203
    }
204

            
205
    /// Run a command, and wait for success or failure.
206
    ///
207
    /// Note that this function will return `Err(.)` only if sending the command or getting a
208
    /// response failed.  If the command was sent successfully, and Arti reported an error in response,
209
    /// this function returns `Ok(Err(.))`.
210
    ///
211
    /// Note that the command does not need to include an `id` field.  If you omit it,
212
    /// one will be generated.
213
98
    pub fn execute(&self, cmd: &str) -> Result<FinalResponse, ProtoError> {
214
98
        let hnd = self.execute_with_handle(cmd)?;
215
46
        hnd.wait()
216
98
    }
217

            
218
    /// Helper for executing internally-generated requests and decoding their results.
219
    ///
220
    /// Behaves like `execute`, except on success, where it tries to decode the `result` field
221
    /// of the response as a `T`.
222
    ///
223
    /// Use this method in cases where it's reasonable for Arti to sometimes return an RPC error:
224
    /// in other words, where it's not necessarily a programming error or version mismatch.
225
    ///
226
    /// Don't use this for user-generated requests: it will misreport unexpected replies
227
    /// as internal errors.
228
2
    pub(crate) fn execute_internal<T: DeserializeOwned>(
229
2
        &self,
230
2
        cmd: &str,
231
2
    ) -> Result<Result<T, ErrorResponse>, ProtoError> {
232
2
        match self.execute(cmd)? {
233
2
            Ok(success) => match success.decode::<T>() {
234
2
                Ok(result) => Ok(Ok(result)),
235
                Err(json_error) => Err(ProtoError::InternalRequestFailed(UnexpectedReply {
236
                    request: cmd.to_string(),
237
                    reply: Utf8CString::from(success).to_string(),
238
                    problem: UnexpectedReplyProblem::CannotDecode(Arc::new(json_error)),
239
                })),
240
            },
241
            Err(error) => Ok(Err(error)),
242
        }
243
2
    }
244

            
245
    /// Helper for executing internally-generated requests and decoding their results.
246
    ///
247
    /// Behaves like `execute_internal`, except that it treats any RPC error reply
248
    /// as an internal error or version mismatch.
249
    ///
250
    /// Don't use this for user-generated requests, or for requests that can fail because of
251
    /// incorrect user inputs: it will misreport failures in those requests as internal errors.
252
2
    pub(crate) fn execute_internal_ok<T: DeserializeOwned>(
253
2
        &self,
254
2
        cmd: &str,
255
2
    ) -> Result<T, ProtoError> {
256
2
        match self.execute_internal(cmd)? {
257
2
            Ok(v) => Ok(v),
258
            Err(err_response) => Err(err_response.internal_error(cmd)),
259
        }
260
2
    }
261

            
262
    /// Cancel a request by ID.
263
    pub fn cancel(&self, request_id: &AnyRequestId) -> Result<(), ProtoError> {
264
        /// Arguments to an `rpc::cancel` request.
265
        #[derive(serde::Serialize, Debug)]
266
        struct CancelParams<'a> {
267
            /// The request to cancel.
268
            request_id: &'a AnyRequestId,
269
        }
270

            
271
        let request = crate::msgs::request::Request::new(
272
            ObjectId::connection_id(),
273
            "rpc:cancel",
274
            CancelParams { request_id },
275
        );
276
        match self.execute_internal::<EmptyReply>(&request.encode()?)? {
277
            Ok(EmptyReply {}) => Ok(()),
278
            Err(_) => Err(ProtoError::RequestCompleted),
279
        }
280
    }
281

            
282
    /// Like `execute`, but don't wait.  This lets the caller see the
283
    /// request ID and  maybe cancel it.
284
4194
    pub fn execute_with_handle(&self, cmd: &str) -> Result<RequestHandle, ProtoError> {
285
4194
        self.send_request(cmd)
286
4194
    }
287
    /// As execute(), but run update_cb for every update we receive.
288
4096
    pub fn execute_with_updates<F>(
289
4096
        &self,
290
4096
        cmd: &str,
291
4096
        mut update_cb: F,
292
4096
    ) -> Result<FinalResponse, ProtoError>
293
4096
    where
294
4096
        F: FnMut(UpdateResponse) + Send + Sync,
295
4096
    {
296
4096
        let hnd = self.execute_with_handle(cmd)?;
297
        loop {
298
8078
            match hnd.wait_with_updates()? {
299
2008
                AnyResponse::Success(s) => return Ok(Ok(s)),
300
2088
                AnyResponse::Error(e) => return Ok(Err(e)),
301
3982
                AnyResponse::Update(u) => update_cb(u),
302
            }
303
        }
304
4096
    }
305

            
306
    /// Helper: Tell Arti to release `obj`.
307
    ///
308
    /// Do not use this method for a user-provided object ID:
309
    /// It gives an internal error if the object does not exist.
310
    pub(crate) fn release_obj(&self, obj: ObjectId) -> Result<(), ProtoError> {
311
        let release_request = crate::msgs::request::Request::new(obj, "rpc:release", NoParams {});
312
        let _empty_response: EmptyReply = self.execute_internal_ok(&release_request.encode()?)?;
313
        Ok(())
314
    }
315

            
316
    // TODO RPC: shutdown() on the socket on Drop.
317
}
318

            
319
impl RequestHandle {
320
    /// Return the ID of this request, to help cancelling it.
321
    pub fn id(&self) -> &AnyRequestId {
322
        &self.id
323
    }
324
    /// Wait for success or failure, and return what happened.
325
    ///
326
    /// (Ignores any update messages that are received.)
327
    ///
328
    /// Note that this function will return `Err(.)` only if sending the command or getting a
329
    /// response failed.  If the command was sent successfully, and Arti reported an error in response,
330
    /// this function returns `Ok(Err(.))`.
331
46
    pub fn wait(self) -> Result<FinalResponse, ProtoError> {
332
        loop {
333
46
            match self.wait_with_updates()? {
334
2
                AnyResponse::Success(s) => return Ok(Ok(s)),
335
                AnyResponse::Error(e) => return Ok(Err(e)),
336
                AnyResponse::Update(_) => {}
337
            }
338
        }
339
46
    }
340
    /// Wait for the next success, failure, or update from this handle.
341
    ///
342
    /// Note that this function will return `Err(.)` only if sending the command or getting a
343
    /// response failed.  If the command was sent successfully, and Arti reported an error in response,
344
    /// this function returns `Ok(AnyResponse::Error(.))`.
345
    ///
346
    /// You may call this method on the same `RequestHandle` from multiple threads.
347
    /// If you do so, those calls will receive responses (or errors) in an unspecified order.
348
    ///
349
    /// If this function returns Success or Error, then you shouldn't call it again.
350
    /// All future calls to this function will fail with `CmdError::RequestCancelled`.
351
    /// (TODO RPC: Maybe rename that error.)
352
8124
    pub fn wait_with_updates(&self) -> Result<AnyResponse, ProtoError> {
353
8124
        let conn = self.conn.lock().expect("Poisoned lock");
354
8124
        let validated = conn.wait_on_message_for(&self.id)?;
355

            
356
8080
        Ok(AnyResponse::from_validated(validated))
357
8124
    }
358

            
359
    // TODO RPC: Sketch out how we would want to do this in an async world,
360
    // or with poll
361
}
362

            
363
/// An error (or other condition) that has caused an RPC connection to shut down.
364
#[derive(Clone, Debug, thiserror::Error)]
365
#[non_exhaustive]
366
pub enum ShutdownError {
367
    /// Io error occurred while reading.
368
    #[error("Unable to read response")]
369
    Read(#[source] Arc<io::Error>),
370
    /// Io error occurred while writing.
371
    #[error("Unable to write request")]
372
    Write(#[source] Arc<io::Error>),
373
    /// Something was wrong with Arti's responses; this is a protocol violation.
374
    #[error("Arti sent a message that didn't conform to the RPC protocol: {0:?}")]
375
    ProtocolViolated(String),
376
    /// Arti has told us that we violated the protocol somehow.
377
    #[error("Arti reported a fatal error: {0:?}")]
378
    ProtocolViolationReport(ErrorResponse),
379
    /// The underlying connection closed.
380
    ///
381
    /// This probably means that Arti has shut down.
382
    #[error("Connection closed")]
383
    ConnectionClosed,
384
}
385

            
386
impl From<crate::msgs::response::DecodeResponseError> for ShutdownError {
    /// Map a response-decoding failure to the shutdown condition it implies.
    ///
    /// Both kinds of protocol violation become `ProtocolViolated` with a text
    /// description; a fatal error report from Arti is kept as-is.
    fn from(err: crate::msgs::response::DecodeResponseError) -> Self {
        use crate::msgs::response::DecodeResponseError as D;
        match err {
            D::JsonProtocolViolation(e) => Self::ProtocolViolated(e.to_string()),
            D::ProtocolViolation(s) => Self::ProtocolViolated(s.to_string()),
            D::Fatal(rpc_err) => Self::ProtocolViolationReport(rpc_err),
        }
    }
}
397

            
398
/// An error that has occurred while launching an RPC command.
399
#[derive(Clone, Debug, thiserror::Error)]
400
#[non_exhaustive]
401
pub enum ProtoError {
402
    /// The RPC connection failed, or was closed by the other side.
403
    #[error("RPC connection is shut down")]
404
    Shutdown(#[from] ShutdownError),
405

            
406
    /// There was a problem in the request we tried to send.
407
    #[error("Invalid request")]
408
    InvalidRequest(#[from] InvalidRequestError),
409

            
410
    /// We tried to send a request with an ID that was already pending.
411
    #[error("Request ID already in use.")]
412
    RequestIdInUse,
413

            
414
    /// We tried to wait for or inspect a request that had already succeeded or failed.
415
    #[error("Request has already completed (or failed)")]
416
    RequestCompleted,
417

            
418
    /// We tried to wait for the same request more than once.
419
    ///
420
    /// (This should be impossible.)
421
    #[error("Internal error: waiting on the same request more than once at a time.")]
422
    DuplicateWait,
423

            
424
    /// We got an internal error while trying to encode an RPC request.
425
    ///
426
    /// (This should be impossible.)
427
    #[error("Internal error while encoding request")]
428
    CouldNotEncode(#[source] Arc<serde_json::Error>),
429

            
430
    /// We got a response to some internally generated request that wasn't what we expected.
431
    #[error("{0}")]
432
    InternalRequestFailed(#[source] UnexpectedReply),
433
}
434

            
435
/// A set of errors encountered while trying to connect to the Arti process
436
#[derive(Clone, Debug, thiserror::Error)]
437
pub struct ConnectFailure {
438
    /// A list of all the declined connect points we encountered, and how they failed.
439
    declined: Vec<(builder::ConnPtDescription, ConnectError)>,
440
    /// A description of where we found the final error (if it's an abort.)
441
    final_desc: Option<builder::ConnPtDescription>,
442
    /// The final error explaining why we couldn't connect.
443
    ///
444
    /// This is either an abort, an AllAttemptsDeclined, or an error that prevented the
445
    /// search process from even beginning.
446
    #[source]
447
    pub(crate) final_error: ConnectError,
448
}
449

            
450
impl std::fmt::Display for ConnectFailure {
451
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
452
        write!(f, "Unable to connect")?;
453
        if !self.declined.is_empty() {
454
            write!(
455
                f,
456
                " ({} attempts failed{})",
457
                self.declined.len(),
458
                if matches!(self.final_error, ConnectError::AllAttemptsDeclined) {
459
                    ""
460
                } else {
461
                    " before fatal error"
462
                }
463
            )?;
464
        }
465
        Ok(())
466
    }
467
}
468

            
469
impl ConnectFailure {
470
    /// If this attempt failed because of a fatal error that made a connect point attempt abort,
471
    /// return a description of the origin of that connect point.
472
    pub fn fatal_error_origin(&self) -> Option<&builder::ConnPtDescription> {
473
        self.final_desc.as_ref()
474
    }
475

            
476
    /// For each connect attempt that failed nonfatally, return a description of the
477
    /// origin of that connect point, and the error that caused it to fail.
478
    pub fn declined_attempt_outcomes(
479
        &self,
480
    ) -> impl Iterator<Item = (&builder::ConnPtDescription, &ConnectError)> {
481
        // Note: this map looks like a no-op, but isn't.
482
        self.declined.iter().map(|(a, b)| (a, b))
483
    }
484

            
485
    /// Return a helper type to format this error, and all of its internal errors recursively.
486
    ///
487
    /// Unlike [`tor_error::Report`], this method includes not only fatal errors, but also
488
    /// information about connect attempts that failed nonfatally.
489
    pub fn display_verbose(&self) -> ConnectFailureVerboseFmt<'_> {
490
        ConnectFailureVerboseFmt(self)
491
    }
492
}
493

            
494
/// Helper type to format a ConnectFailure along with all of its internal errors,
495
/// including non-fatal errors.
496
#[derive(Debug, Clone)]
497
pub struct ConnectFailureVerboseFmt<'a>(&'a ConnectFailure);
498

            
499
impl<'a> std::fmt::Display for ConnectFailureVerboseFmt<'a> {
500
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
501
        use tor_error::ErrorReport as _;
502
        writeln!(f, "{}:", self.0)?;
503
        for (idx, (origin, error)) in self.0.declined_attempt_outcomes().enumerate() {
504
            writeln!(f, "  {}. {}: {}", idx + 1, origin, error.report())?;
505
        }
506
        if let Some(origin) = self.0.fatal_error_origin() {
507
            writeln!(
508
                f,
509
                "  {}. [FATAL] {}: {}",
510
                self.0.declined.len() + 1,
511
                origin,
512
                self.0.final_error.report()
513
            )?;
514
        } else {
515
            writeln!(f, "  - {}", self.0.final_error.report())?;
516
        }
517
        Ok(())
518
    }
519
}
520

            
521
/// An error while trying to connect to the Arti process.
522
#[derive(Clone, Debug, thiserror::Error)]
523
#[non_exhaustive]
524
pub enum ConnectError {
525
    /// Unable to parse connect points from an environment variable.
526
    #[error("Cannot parse connect points from environment variable")]
527
    BadEnvironment,
528
    /// We were unable to load and/or parse a given connect point.
529
    #[error("Unable to load and parse connect point")]
530
    CannotParse(#[from] tor_rpc_connect::load::LoadError),
531
    /// The path used to specify a connect file couldn't be resolved.
532
    #[error("Unable to resolve connect point path")]
533
    CannotResolvePath(#[source] tor_config_path::CfgPathError),
534
    /// A parsed connect point couldn't be resolved.
535
    #[error("Unable to resolve connect point")]
536
    CannotResolveConnectPoint(#[from] tor_rpc_connect::ResolveError),
537
    /// IO error while connecting to Arti.
538
    #[error("Unable to make a connection")]
539
    CannotConnect(#[from] tor_rpc_connect::ConnectError),
540
    /// Opened a connection, but didn't get a banner message.
541
    ///
542
    /// (This isn't a `BadMessage`, since it is likelier to represent something that isn't
543
    /// pretending to be Arti at all than it is to be a malfunctioning Arti.)
544
    #[error("Did not receive expected banner message upon connecting")]
545
    InvalidBanner,
546
    /// All attempted connect points were declined, and none were aborted.
547
    #[error("All connect points were declined (or there were none)")]
548
    AllAttemptsDeclined,
549
    /// A connect file or directory was given as a relative path.
550
    /// (Only absolute paths are supported).
551
    #[error("Connect file was given as a relative path.")]
552
    RelativeConnectFile,
553
    /// One of our authentication messages received an error.
554
    #[error("Received an error while trying to authenticate: {0}")]
555
    AuthenticationFailed(ErrorResponse),
556
    /// The connect point uses an RPC authentication type we don't support.
557
    #[error("Authentication type is not supported")]
558
    AuthenticationNotSupported,
559
    /// We couldn't decode one of the responses we got.
560
    #[error("Message not in expected format")]
561
    BadMessage(#[source] Arc<serde_json::Error>),
562
    /// A protocol error occurred during negotiations.
563
    #[error("Error while negotiating with Arti")]
564
    ProtoError(#[from] ProtoError),
565
    /// The server thinks it is listening on an address where we don't expect to find it.
566
    /// This can be misconfiguration or an attempted MITM attack.
567
    #[error("We connected to the server at {ours}, but it thinks it's listening at {theirs}")]
568
    ServerAddressMismatch {
569
        /// The address we think the server has
570
        ours: String,
571
        /// The address that the server says it has.
572
        theirs: String,
573
    },
574
    /// The server tried to prove knowledge of a cookie file, but its proof was incorrect.
575
    #[error("Server's cookie MAC was not as expected.")]
576
    CookieMismatch,
577
    /// We were unable to access the configured cookie file.
578
    #[error("Unable to load secret cookie value")]
579
    LoadCookie(#[from] CookieAccessError),
580
}
581

            
582
impl HasClientErrorAction for ConnectError {
583
    fn client_action(&self) -> tor_rpc_connect::ClientErrorAction {
584
        use tor_rpc_connect::ClientErrorAction as A;
585
        use ConnectError as E;
586
        match self {
587
            E::BadEnvironment => A::Abort,
588
            E::CannotParse(e) => e.client_action(),
589
            E::CannotResolvePath(_) => A::Abort,
590
            E::CannotResolveConnectPoint(e) => e.client_action(),
591
            E::CannotConnect(e) => e.client_action(),
592
            E::InvalidBanner => A::Decline,
593
            E::RelativeConnectFile => A::Abort,
594
            E::AuthenticationFailed(_) => A::Decline,
595
            // TODO RPC: Is this correct?  This error can also occur when
596
            // we are talking to something other than an RPC server.
597
            E::BadMessage(_) => A::Abort,
598
            E::ProtoError(e) => e.client_action(),
599
            E::AllAttemptsDeclined => A::Abort,
600
            E::AuthenticationNotSupported => A::Decline,
601
            E::ServerAddressMismatch { .. } => A::Abort,
602
            E::CookieMismatch => A::Abort,
603
            E::LoadCookie(e) => e.client_action(),
604
        }
605
    }
606
}
607

            
608
impl HasClientErrorAction for ProtoError {
609
    fn client_action(&self) -> tor_rpc_connect::ClientErrorAction {
610
        use tor_rpc_connect::ClientErrorAction as A;
611
        use ProtoError as E;
612
        match self {
613
            E::Shutdown(_) => A::Decline,
614
            E::InternalRequestFailed(_) => A::Decline,
615
            // These are always internal errors if they occur while negotiating a connection to RPC,
616
            // which is the context we care about for `HasClientErrorAction`.
617
            E::InvalidRequest(_)
618
            | E::RequestIdInUse
619
            | E::RequestCompleted
620
            | E::DuplicateWait
621
            | E::CouldNotEncode(_) => A::Abort,
622
        }
623
    }
624
}
625

            
626
/// In response to a request that we generated internally,
627
/// Arti gave a reply that we did not understand.
628
///
629
/// This could be due to a bug in this library, a bug in Arti,
630
/// or a compatibility issue between the two.
631
#[derive(Clone, Debug, thiserror::Error)]
632
#[error("In response to our request {request:?}, Arti gave the unexpected reply {reply:?}")]
633
pub struct UnexpectedReply {
634
    /// The request we sent.
635
    request: String,
636
    /// The response we got.
637
    reply: String,
638
    /// What was wrong with the response.
639
    #[source]
640
    problem: UnexpectedReplyProblem,
641
}
642

            
643
/// Underlying reason for an UnexpectedReply
644
#[derive(Clone, Debug, thiserror::Error)]
645
enum UnexpectedReplyProblem {
646
    /// There was a json failure while trying to decode the response:
647
    /// the result type was not what we expected.
648
    #[error("Cannot decode as correct JSON type")]
649
    CannotDecode(Arc<serde_json::Error>),
650
    /// Arti replied with an RPC error in a context no error should have been possible.
651
    #[error("Unexpected error")]
652
    ErrorNotExpected,
653
}
654

            
655
/// Arguments to a request that takes no parameters.
656
#[derive(serde::Serialize, Debug)]
657
struct NoParams {}
658

            
659
/// A reply with no data.
660
#[derive(serde::Deserialize, Debug)]
661
struct EmptyReply {}
662

            
663
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::{sync::atomic::AtomicUsize, thread, time::Duration};

    use io::{BufRead as _, BufReader, Write as _};
    use rand::{seq::SliceRandom as _, Rng as _, SeedableRng as _};
    use tor_basic_utils::{test_rng::testing_rng, RngExt as _};

    use crate::{
        llconn,
        msgs::request::{JsonMap, Request, ValidatedRequest},
    };

    use super::*;

    /// helper: Return a dummy RpcConn, along with a socketpair for it to talk to.
    fn dummy_connected() -> (RpcConn, crate::testing::SocketpairStream) {
        let (s1, s2) = crate::testing::construct_socketpair().unwrap();
        let s1_w = s1.try_clone().unwrap();
        let s1_r = io::BufReader::new(s1);
        let conn = RpcConn::new(llconn::Reader::new(s1_r), llconn::Writer::new(s1_w));

        (conn, s2)
    }

    /// helper: Write `v` to `w` as one line of JSON, followed by a newline.
    fn write_val(w: &mut impl io::Write, v: &serde_json::Value) {
        let mut enc = serde_json::to_string(v).unwrap();
        enc.push('\n');
        w.write_all(enc.as_bytes()).unwrap();
    }

    /// Send a single request, answer it from a fake Arti, and check the decoded result.
    #[test]
    fn simple() {
        let (conn, sock) = dummy_connected();

        let user_thread = thread::spawn(move || {
            let response1 = conn
                .execute_internal_ok::<JsonMap>(
                    r#"{"obj":"fred","method":"arti:x-frob","params":{}}"#,
                )
                .unwrap();
            (response1, conn)
        });

        let fake_arti_thread = thread::spawn(move || {
            let mut sock = BufReader::new(sock);
            let mut s = String::new();
            let _len = sock.read_line(&mut s).unwrap();
            let request = ValidatedRequest::from_string_strict(s.as_ref()).unwrap();
            let response = serde_json::json!({
                "id": request.id().clone(),
                "result": { "xyz" : 3 }
            });
            write_val(sock.get_mut(), &response);
            sock // prevent close
        });

        let _sock = fake_arti_thread.join().unwrap();
        let (map, _conn) = user_thread.join().unwrap();
        assert_eq!(map.get("xyz"), Some(&serde_json::Value::Number(3.into())));
    }

    /// Issue many requests from many threads against a fake Arti that answers them
    /// out of order (sometimes with updates, sometimes with errors), and check that
    /// every request is matched with its own response.
    #[test]
    fn complex() {
        use std::sync::atomic::Ordering::SeqCst;
        let n_threads = 16;
        let n_commands_per_thread = 128;
        let n_commands_total = n_threads * n_commands_per_thread;
        let n_completed = Arc::new(AtomicUsize::new(0));

        let (conn, sock) = dummy_connected();
        let conn = Arc::new(conn);
        let mut user_threads = Vec::new();
        let mut rng = testing_rng();

        // -------
        // User threads: Make a bunch of requests.
        for th_idx in 0..n_threads {
            let conn = Arc::clone(&conn);
            let n_completed = Arc::clone(&n_completed);
            let mut rng = rand_chacha::ChaCha12Rng::from_seed(rng.gen());
            let th = thread::spawn(move || {
                for cmd_idx in 0..n_commands_per_thread {
                    // We are spawning a bunch of worker threads, each of which will run a number of
                    // commands in sequence.  Each command will be a request that gets optional
                    // updates, and an error or a success.
                    // We will double-check that each request gets the response it asked for.
                    let s = format!("{}:{}", th_idx, cmd_idx);
                    let want_updates: bool = rng.gen();
                    let want_failure: bool = rng.gen();
                    let req = serde_json::json!({
                        "obj":"fred",
                        "method":"arti:x-echo",
                        "meta": {
                            "updates": want_updates,
                        },
                        "params": {
                            "val": &s,
                            "fail": want_failure,
                        },
                    });
                    let req = serde_json::to_string(&req).unwrap();

                    // Wait for a final response, processing updates if we asked for them.
                    let mut n_updates = 0;
                    let outcome = conn
                        .execute_with_updates(&req, |_update| {
                            n_updates += 1;
                        })
                        .unwrap();
                    assert_eq!(n_updates > 0, want_updates);

                    // See if we liked the final response.
                    if want_failure {
                        let e = outcome.unwrap_err().decode();
                        assert_eq!(e.message(), "You asked me to fail");
                        assert_eq!(i32::from(e.code()), 33);
                        assert_eq!(
                            e.kinds_iter().collect::<Vec<_>>(),
                            vec!["Example".to_string()]
                        );
                    } else {
                        let success = outcome.unwrap();
                        let map = success.decode::<JsonMap>().unwrap();
                        assert_eq!(map.get("echo"), Some(&serde_json::Value::String(s)));
                    }
                    n_completed.fetch_add(1, SeqCst);
                    if rng.gen::<f32>() < 0.02 {
                        thread::sleep(Duration::from_millis(3));
                    }
                }
            });
            user_threads.push(th);
        }

        #[derive(serde::Deserialize, Debug)]
        struct Echo {
            val: String,
            fail: bool,
        }

        // -----
        // Worker thread: handles user requests.
        let worker_rng = rand_chacha::ChaCha12Rng::from_seed(rng.gen());
        let worker_thread = thread::spawn(move || {
            let mut rng = worker_rng;
            let mut sock = BufReader::new(sock);
            let mut pending: Vec<Request<Echo>> = Vec::new();
            let mut n_received = 0;

            // How many requests do we buffer before we shuffle them and answer them out-of-order?
            let scramble_factor = 7;
            // After receiving how many requests do we stop shuffling requests?
            //
            // (Our shuffling algorithm can deadlock us otherwise: each user thread waits for
            // the final response to one request before sending the next, so once fewer than
            // `scramble_factor` user threads still have work to do, a full batch can never
            // arrive.)
            let scramble_threshold =
                n_commands_total - (n_commands_per_thread + 1) * scramble_factor;

            'outer: loop {
                let flush_pending_at = if n_received >= scramble_threshold {
                    1
                } else {
                    scramble_factor
                };

                // Queue a handful of requests in "pending"
                while pending.len() < flush_pending_at {
                    let mut buf = String::new();
                    if sock.read_line(&mut buf).unwrap() == 0 {
                        break 'outer;
                    }
                    n_received += 1;
                    let req: Request<Echo> = serde_json::from_str(&buf).unwrap();
                    pending.push(req);
                }

                // Handle the requests in "pending" in random order.
                let mut handling = std::mem::take(&mut pending);
                handling.shuffle(&mut rng);

                for req in handling {
                    if req.meta.unwrap_or_default().updates {
                        let n_updates = rng.gen_range_checked(1..4).unwrap();
                        for _ in 0..n_updates {
                            let up = serde_json::json!({
                                "id": req.id.clone(),
                                "update": {
                                    "hello": req.params.val.clone(),
                                }
                            });
                            write_val(sock.get_mut(), &up);
                        }
                    }

                    let response = if req.params.fail {
                        serde_json::json!({
                            "id": req.id.clone(),
                            "error": { "message": "You asked me to fail", "code": 33, "kinds": ["Example"], "data": req.params.val },
                        })
                    } else {
                        serde_json::json!({
                            "id": req.id.clone(),
                            "result": {
                                "echo": req.params.val
                            }
                        })
                    };
                    write_val(sock.get_mut(), &response);
                }
            }
        });
        drop(conn);
        for t in user_threads {
            t.join().unwrap();
        }

        worker_thread.join().unwrap();

        assert_eq!(n_completed.load(SeqCst), n_commands_total);
    }

    #[test]
    fn arti_socket_closed() {
        // Here we send a bunch of requests and then close the socket without answering them.
        //
        // Every request should get a ProtoError::Shutdown.
        let n_threads = 16;

        let (conn, sock) = dummy_connected();
        let conn = Arc::new(conn);
        let mut user_threads = Vec::new();
        for _ in 0..n_threads {
            let conn = Arc::clone(&conn);
            let th = thread::spawn(move || {
                // Each thread sends a single request, which should fail
                // once the other end of the socket is closed.
                let req = serde_json::json!({
                    "obj":"fred",
                    "method":"arti:x-echo",
                    "params":{}
                });
                let req = serde_json::to_string(&req).unwrap();
                let outcome = conn.execute(&req);

                // Any of these shutdown variants is acceptable; which one a given
                // thread sees presumably depends on timing.
                assert!(
                    matches!(
                        &outcome,
                        Err(ProtoError::Shutdown(ShutdownError::Write(_)))
                            | Err(ProtoError::Shutdown(ShutdownError::Read(_)))
                            | Err(ProtoError::Shutdown(ShutdownError::ConnectionClosed))
                    ),
                    "unexpected outcome: {outcome:?}"
                );
            });
            user_threads.push(th);
        }

        drop(sock);

        for t in user_threads {
            t.join().unwrap();
        }
    }

    /// Send a bunch of requests and then send back a single reply.
    ///
    /// That reply should cause every request to get closed: each request must fail
    /// with a `ProtoError` for which `outcome_ok` returns true.
    fn proto_err_with_msg<F>(msg: &str, outcome_ok: F)
    where
        F: Fn(ProtoError) -> bool,
    {
        let n_threads = 16;

        let (conn, mut sock) = dummy_connected();
        let conn = Arc::new(conn);
        let mut user_threads = Vec::new();
        for _ in 0..n_threads {
            let conn = Arc::clone(&conn);
            let th = thread::spawn(move || {
                // Each thread sends a single request and hands its outcome
                // back to the main thread for checking.
                let req = serde_json::json!({
                    "obj":"fred",
                    "method":"arti:x-echo",
                    "params":{}
                });
                let req = serde_json::to_string(&req).unwrap();
                conn.execute(&req)
            });
            user_threads.push(th);
        }

        sock.write_all(msg.as_bytes()).unwrap();

        for t in user_threads {
            let err = t.join().unwrap().unwrap_err();
            // Capture a description before `outcome_ok` consumes the error.
            let description = format!("{err:?}");
            assert!(outcome_ok(err), "unexpected error: {description}");
        }
    }

    /// A line that is not JSON at all should shut down the connection
    /// with a protocol violation.
    #[test]
    fn syntax_error() {
        proto_err_with_msg("this is not json\n", |outcome| {
            matches!(
                outcome,
                ProtoError::Shutdown(ShutdownError::ProtocolViolated(_))
            )
        });
    }

    /// An error reply that carries no request ID should shut down the connection
    /// with a `ProtocolViolationReport`.
    #[test]
    fn fatal_error() {
        let j = serde_json::json!({
            "error":{ "message": "This test is doomed", "code": 413, "kinds": ["Example"], "data": {} },
        });
        let mut s = serde_json::to_string(&j).unwrap();
        s.push('\n');

        proto_err_with_msg(&s, |outcome| {
            matches!(
                outcome,
                ProtoError::Shutdown(ShutdownError::ProtocolViolationReport(_))
            )
        });
    }
}