1
//! Functionality to connect to an RPC server.
2

            
3
use std::{
4
    collections::HashMap,
5
    io::{self},
6
    path::PathBuf,
7
    str::FromStr as _,
8
};
9

            
10
use fs_mistrust::Mistrust;
11
use tor_config_path::{CfgPath, CfgPathResolver};
12
use tor_rpc_connect::{
13
    auth::RpcAuth,
14
    load::{LoadError, LoadOptions},
15
    ClientErrorAction, HasClientErrorAction, ParsedConnectPoint,
16
};
17

            
18
use crate::{conn::ConnectError, llconn, msgs::response::UnparsedResponse, RpcConn};
19

            
20
use super::ConnectFailure;
21

            
22
/// An error occurred while trying to construct or manipulate an [`RpcConnBuilder`].
23
#[derive(Clone, Debug, thiserror::Error)]
24
#[non_exhaustive]
25
pub enum BuilderError {
26
    /// We couldn't decode a provided connect string.
27
    #[error("Invalid connect string.")]
28
    InvalidConnectString,
29
}
30

            
31
/// Information about how to construct a connection to an Arti instance.
32
//
33
// TODO RPC: Once we have our formats more settled, add a link to a piece of documentation
34
// explaining what a connect point is and how to make one.
35
#[derive(Default, Clone, Debug)]
36
pub struct RpcConnBuilder {
37
    /// Path entries provided programmatically.
38
    ///
39
    /// These are considered after entries in
40
    /// the `$ARTI_RPC_CONNECT_PATH_OVERRIDE` environment variable,
41
    /// but before any other entries.
42
    /// (See `RPCConnBuilder::new` for details.)
43
    ///
44
    /// These entries are stored in reverse order.
45
    prepend_path_reversed: Vec<SearchEntry>,
46
}
47

            
48
/// A single entry in the search path used to find connect points.
49
///
50
/// Includes information on where we got this entry
51
/// (environment variable, application, or default).
52
#[derive(Clone, Debug)]
53
struct SearchEntry {
54
    /// The source telling us this entry.
55
    source: ConnPtOrigin,
56
    /// The location to search.
57
    location: SearchLocation,
58
}
59

            
60
/// A single location in the search path used to find connect points.
61
#[derive(Clone, Debug)]
62
enum SearchLocation {
63
    /// A literal connect point entry to parse.
64
    Literal(String),
65
    /// A path to a connect file, or a directory full of connect files.
66
    Path {
67
        /// The path to load.
68
        path: CfgPath,
69

            
70
        /// If true, then this entry comes from a builtin default,
71
        /// and relative paths should cause the connect attempt to be declined.
72
        ///
73
        /// Otherwise, this entry comes from the user or application,
74
        /// and relative paths should cause the connect attempt to abort.
75
        is_default_entry: bool,
76
    },
77
}
78

            
79
/// Diagnostic: An explanation of where we found a connect point,
80
/// and why we looked there.
81
#[derive(Debug, Clone)]
82
pub struct ConnPtDescription {
83
    /// What told us to look in this location
84
    source: ConnPtOrigin,
85
    /// Where we found the connect point.
86
    location: ConnPtLocation,
87
}
88

            
89
impl std::fmt::Display for ConnPtDescription {
90
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91
        write!(
92
            f,
93
            "connect point in {}, from {}",
94
            &self.location, &self.source
95
        )
96
    }
97
}
98

            
99
/// Diagnostic: identifies what told us to search a given location for a connect point.
#[derive(Debug, Clone, Copy)]
enum ConnPtOrigin {
    /// The search entry came from the environment variable with this name.
    EnvVar(&'static str),
    /// The search entry was inserted explicitly by the application.
    Application,
    /// The search entry is one of the built-in defaults.
    Default,
}
109

            
110
impl std::fmt::Display for ConnPtOrigin {
111
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112
        match self {
113
            ConnPtOrigin::EnvVar(varname) => write!(f, "${}", varname),
114
            ConnPtOrigin::Application => write!(f, "application"),
115
            ConnPtOrigin::Default => write!(f, "default list"),
116
        }
117
    }
118
}
119

            
120
/// Diagnostic: Where we found a connect point.
121
#[derive(Clone, Debug)]
122
enum ConnPtLocation {
123
    /// The connect point was given as a literal string.
124
    Literal(String),
125
    /// We expanded a CfgPath to find the location of a connect file on disk.
126
    File {
127
        /// The path as configured
128
        path: CfgPath,
129
        /// The expanded path.
130
        expanded: Option<PathBuf>,
131
    },
132
    /// We expanded a CfgPath to find a directory, and found the connect file
133
    /// within that directory
134
    WithinDir {
135
        /// The path of the directory as configured.
136
        path: CfgPath,
137
        /// The location of the file.
138
        file: PathBuf,
139
    },
140
}
141

            
142
impl std::fmt::Display for ConnPtLocation {
143
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144
        // Note: here we use Path::display(), which in other crates we forbid
145
        // and use tor_basic_utils::PathExt::display_lossy().
146
        //
147
        // Here we make an exception, since arti-rpc-client-core is meant to have
148
        // minimal dependencies on our other crates.
149
        #[allow(clippy::disallowed_methods)]
150
        match self {
151
            ConnPtLocation::Literal(s) => write!(f, "literal string {:?}", s),
152
            ConnPtLocation::File {
153
                path,
154
                expanded: Some(ex),
155
            } => {
156
                write!(f, "file {} [{}]", path, ex.display())
157
            }
158
            ConnPtLocation::File {
159
                path,
160
                expanded: None,
161
            } => {
162
                write!(f, "file {} [cannot expand]", path)
163
            }
164

            
165
            ConnPtLocation::WithinDir {
166
                path,
167
                file: expanded,
168
            } => {
169
                write!(f, "file {} in directory {}", expanded.display(), path)
170
            }
171
        }
172
    }
173
}
174

            
175
impl RpcConnBuilder {
176
    /// Create a new `RpcConnBuilder` to try connecting to an Arti instance.
177
    ///
178
    /// By default, we search:
179
    ///   - Any connect points listed in the environment variable `$ARTI_RPC_CONNECT_PATH_OVERRIDE`
180
    ///   - Any connect points passed to `RpcConnBuilder::prepend_*`
181
    ///     (Since these variables are _prepended_,
182
    ///     the ones that are prepended _last_ will be considered _first_.)
183
    ///   - Any connect points listed in the environment variable `$ARTI_RPC_CONNECT_PATH`
184
    ///   - Any connect files in `${ARTI_LOCAL_DATA}/rpc/connect.d`
185
    ///   - Any connect files in `/etc/arti-rpc/connect.d` (unix only)
186
    ///   - [`tor_rpc_connect::USER_DEFAULT_CONNECT_POINT`]
187
    ///   - [`tor_rpc_connect::SYSTEM_DEFAULT_CONNECT_POINT`] if present
188
    //
189
    // TODO RPC: Once we have our formats more settled, add a link to a piece of documentation
190
    // explaining what a connect point is and how to make one.
191
    pub fn new() -> Self {
192
        Self::default()
193
    }
194

            
195
    /// Prepend a single literal connect point to the search path in this RpcConnBuilder.
196
    ///
197
    /// This entry will be considered before any entries in
198
    /// the `$ARTI_RPC_CONNECT_PATH` environment variable
199
    /// but after any entry in
200
    /// the `$ARTI_RPC_CONNECT_PATH_OVERRIDE` environment variable.
201
    ///
202
    /// This entry must be a literal connect point, expressed as a TOML table.
203
    pub fn prepend_literal_entry(&mut self, s: String) {
204
        self.prepend_internal(SearchLocation::Literal(s));
205
    }
206

            
207
    /// Prepend a single path entry to the search path in this RpcConnBuilder.
208
    ///
209
    /// This entry will be considered before any entries in
210
    /// the `$ARTI_RPC_CONNECT_PATH` environment variable,
211
    /// but after any entry in
212
    /// the `$ARTI_RPC_CONNECT_PATH_OVERRIDE` environment variable.
213
    ///
214
    /// This entry must be a path to a file or directory.
215
    /// It may contain variables to expand;
216
    /// they will be expanded according to the rules of [`CfgPath`],
217
    /// using the variables of [`tor_config_path::arti_client_base_resolver`].
218
    pub fn prepend_path(&mut self, p: String) {
219
        self.prepend_internal(SearchLocation::Path {
220
            path: CfgPath::new(p),
221
            is_default_entry: false,
222
        });
223
    }
224

            
225
    /// Prepend a single literal path entry to the search path in this RpcConnBuilder.
226
    ///
227
    /// This entry will be considered before any entries in
228
    /// the `$ARTI_RPC_CONNECT_PATH` environment variable,
229
    /// but after any entry in
230
    /// the `$ARTI_RPC_CONNECT_PATH_OVERRIDE` environment variable.
231
    ///
232
    /// Variables in this entry will not be expanded.
233
    pub fn prepend_literal_path(&mut self, p: PathBuf) {
234
        self.prepend_internal(SearchLocation::Path {
235
            path: CfgPath::new_literal(p),
236
            is_default_entry: false,
237
        });
238
    }
239

            
240
    /// Prepend the application-provided [`SearchLocation`] to the path.
241
    fn prepend_internal(&mut self, location: SearchLocation) {
242
        self.prepend_path_reversed.push(SearchEntry {
243
            source: ConnPtOrigin::Application,
244
            location,
245
        });
246
    }
247

            
248
    /// Return the list of default path entries that we search _after_
249
    /// all user-provided entries.
250
    fn default_path_entries() -> Vec<SearchEntry> {
251
        use SearchLocation::*;
252
        let dflt = |location| SearchEntry {
253
            source: ConnPtOrigin::Default,
254
            location,
255
        };
256
        let mut result = vec![
257
            dflt(Path {
258
                path: CfgPath::new("${ARTI_LOCAL_DATA}/rpc/connect.d/".to_owned()),
259
                is_default_entry: true,
260
            }),
261
            #[cfg(unix)]
262
            dflt(Path {
263
                path: CfgPath::new_literal("/etc/arti-rpc/connect.d/"),
264
                is_default_entry: true,
265
            }),
266
            dflt(Literal(
267
                tor_rpc_connect::USER_DEFAULT_CONNECT_POINT.to_owned(),
268
            )),
269
        ];
270
        if let Some(p) = tor_rpc_connect::SYSTEM_DEFAULT_CONNECT_POINT {
271
            result.push(dflt(Literal(p.to_owned())));
272
        }
273
        result
274
    }
275

            
276
    /// Return a vector of every PathEntry that we should try to connect to.
277
    fn all_entries(&self) -> Result<Vec<SearchEntry>, ConnectError> {
278
        let mut entries = SearchEntry::from_env_var("ARTI_RPC_CONNECT_PATH_OVERRIDE")?;
279
        entries.extend(self.prepend_path_reversed.iter().rev().cloned());
280
        entries.extend(SearchEntry::from_env_var("ARTI_RPC_CONNECT_PATH")?);
281
        entries.extend(Self::default_path_entries());
282
        Ok(entries)
283
    }
284

            
285
    /// Try to connect to an Arti process as specified by this Builder.
286
    pub fn connect(&self) -> Result<RpcConn, ConnectFailure> {
287
        let resolver = tor_config_path::arti_client_base_resolver();
288
        // TODO RPC: Make this configurable.  (Currently, you can override it with
289
        // the environment variable FS_MISTRUST_DISABLE_PERMISSIONS_CHECKS.)
290
        let mistrust = Mistrust::default();
291
        let options = HashMap::new();
292
        let all_entries = self.all_entries().map_err(|e| ConnectFailure {
293
            declined: vec![],
294
            final_desc: None,
295
            final_error: e,
296
        })?;
297
        let mut declined = Vec::new();
298
        for (description, load_result) in all_entries
299
            .into_iter()
300
            .flat_map(|ent| ent.load(&resolver, &mistrust, &options))
301
        {
302
            match load_result.and_then(|e| try_connect(&e, &resolver, &mistrust)) {
303
                Ok(conn) => return Ok(conn),
304
                Err(e) => match e.client_action() {
305
                    ClientErrorAction::Abort => {
306
                        return Err(ConnectFailure {
307
                            declined,
308
                            final_desc: Some(description),
309
                            final_error: e,
310
                        });
311
                    }
312
                    ClientErrorAction::Decline => {
313
                        declined.push((description, e));
314
                    }
315
                },
316
            }
317
        }
318
        Err(ConnectFailure {
319
            declined,
320
            final_desc: None,
321
            final_error: ConnectError::AllAttemptsDeclined,
322
        })
323
    }
324
}
325

            
326
/// Helper: Try to resolve any variables in parsed,
327
/// and open and authenticate an RPC connection to it.
328
///
329
/// This is a separate function from `RpcConnBuilder::connect` to make error handling easier to read.
330
fn try_connect(
331
    parsed: &ParsedConnectPoint,
332
    resolver: &CfgPathResolver,
333
    mistrust: &Mistrust,
334
) -> Result<RpcConn, ConnectError> {
335
    let tor_rpc_connect::client::Connection {
336
        reader,
337
        writer,
338
        auth,
339
        ..
340
    } = parsed.resolve(resolver)?.connect(mistrust)?;
341
    let mut reader = llconn::Reader::new(io::BufReader::new(reader));
342
    let banner = reader
343
        .read_msg()
344
        .map_err(|e| ConnectError::CannotConnect(e.into()))?
345
        .ok_or(ConnectError::InvalidBanner)?;
346
    check_banner(&banner)?;
347

            
348
    let mut conn = RpcConn::new(reader, llconn::Writer::new(writer));
349

            
350
    // TODO RPC: remove this "scheme name" from the protocol?
351
    let session_id = match auth {
352
        RpcAuth::Inherent => conn.authenticate_inherent("auth:inherent")?,
353
        RpcAuth::Cookie {
354
            secret,
355
            server_address,
356
        } => conn.authenticate_cookie(secret.load()?.as_ref(), &server_address)?,
357
        _ => return Err(ConnectError::AuthenticationNotSupported),
358
    };
359
    conn.session = Some(session_id);
360

            
361
    Ok(conn)
362
}
363

            
364
/// Return Ok if `msg` is a banner indicating the correct protocol.
365
fn check_banner(msg: &UnparsedResponse) -> Result<(), ConnectError> {
366
    /// Structure to indicate that this is indeed an Arti RPC connection.
367
    #[derive(serde::Deserialize)]
368
    struct BannerMsg {
369
        /// Ignored value
370
        #[allow(dead_code)]
371
        arti_rpc: serde_json::Value,
372
    }
373
    let _: BannerMsg =
374
        serde_json::from_str(msg.as_str()).map_err(|_| ConnectError::InvalidBanner)?;
375
    Ok(())
376
}
377

            
378
impl SearchEntry {
379
    /// Return an iterator over ParsedConnPoints from this `SearchEntry`.
380
    fn load<'a>(
381
        &self,
382
        resolver: &CfgPathResolver,
383
        mistrust: &Mistrust,
384
        options: &'a HashMap<PathBuf, LoadOptions>,
385
    ) -> ConnPtIterator<'a> {
386
        // Create a ConnPtDescription given a connect point's location, so we can describe
387
        // an error origin.
388
        let descr = |location| ConnPtDescription {
389
            source: self.source,
390
            location,
391
        };
392

            
393
        match &self.location {
394
            SearchLocation::Literal(s) => ConnPtIterator::Singleton(
395
                descr(ConnPtLocation::Literal(s.clone())),
396
                // It's a literal entry, so we just try to parse it.
397
                ParsedConnectPoint::from_str(s).map_err(|e| ConnectError::from(LoadError::from(e))),
398
            ),
399
            SearchLocation::Path {
400
                path: cfgpath,
401
                is_default_entry,
402
            } => {
403
                // Create a ConnPtDescription given an optional expanded path.
404
                let descr_file = |expanded| {
405
                    descr(ConnPtLocation::File {
406
                        path: cfgpath.clone(),
407
                        expanded,
408
                    })
409
                };
410

            
411
                // It's a path, so we need to expand it...
412
                let path = match cfgpath.path(resolver) {
413
                    Ok(p) => p,
414
                    Err(e) => {
415
                        return ConnPtIterator::Singleton(
416
                            descr_file(None),
417
                            Err(ConnectError::CannotResolvePath(e)),
418
                        )
419
                    }
420
                };
421
                if !path.is_absolute() {
422
                    if *is_default_entry {
423
                        return ConnPtIterator::Done;
424
                    } else {
425
                        return ConnPtIterator::Singleton(
426
                            descr_file(Some(path)),
427
                            Err(ConnectError::RelativeConnectFile),
428
                        );
429
                    }
430
                }
431
                // ..then try to load it as a directory...
432
                match ParsedConnectPoint::load_dir(&path, mistrust, options) {
433
                    Ok(iter) => ConnPtIterator::Dir(self.source, cfgpath.clone(), iter),
434
                    Err(LoadError::NotADirectory) => {
435
                        // ... and if that fails, try to load it as a file.
436
                        let loaded =
437
                            ParsedConnectPoint::load_file(&path, mistrust).map_err(|e| e.into());
438
                        ConnPtIterator::Singleton(descr_file(Some(path)), loaded)
439
                    }
440
                    Err(other) => {
441
                        ConnPtIterator::Singleton(descr_file(Some(path)), Err(other.into()))
442
                    }
443
                }
444
            }
445
        }
446
    }
447

            
448
    /// Return a list of `SearchEntry` as specified in an environment variable with a given name.
449
    fn from_env_var(varname: &'static str) -> Result<Vec<Self>, ConnectError> {
450
        match std::env::var(varname) {
451
            Ok(s) if s.is_empty() => Ok(vec![]),
452
            Ok(s) => Self::from_env_string(varname, &s),
453
            Err(std::env::VarError::NotPresent) => Ok(vec![]),
454
            Err(_) => Err(ConnectError::BadEnvironment), // TODO RPC: Preserve more information?
455
        }
456
    }
457

            
458
    /// Return a list of `SearchEntry` as specified in the value `s` from an envvar called `varname`.
459
    fn from_env_string(varname: &'static str, s: &str) -> Result<Vec<Self>, ConnectError> {
460
        // TODO RPC: Possibly we should be using std::env::split_paths, if it behaves correctly
461
        // with our url-escaped entries.
462
        s.split(PATH_SEP_CHAR)
463
            .map(|s| {
464
                Ok(SearchEntry {
465
                    source: ConnPtOrigin::EnvVar(varname),
466
                    location: SearchLocation::from_env_string_elt(s)?,
467
                })
468
            })
469
            .collect()
470
    }
471
}
472

            
473
impl SearchLocation {
474
    /// Return a `SearchLocation` from a single entry within an environment variable.
475
    fn from_env_string_elt(s: &str) -> Result<SearchLocation, ConnectError> {
476
        match s.bytes().next() {
477
            Some(b'%') | Some(b'[') => Ok(Self::Literal(
478
                percent_encoding::percent_decode_str(s)
479
                    .decode_utf8()
480
                    .map_err(|_| ConnectError::BadEnvironment)?
481
                    .into_owned(),
482
            )),
483
            _ => Ok(Self::Path {
484
                path: CfgPath::new(s.to_owned()),
485
                is_default_entry: false,
486
            }),
487
        }
488
    }
489
}
490

            
491
/// Character that separates the entries of a path-style environment variable:
/// `;` on Windows, and `:` everywhere else.
//
// `cfg!` is evaluated at compile time and is usable in a const initializer,
// so this simple two-way choice needs no macro from an external crate.
const PATH_SEP_CHAR: char = if cfg!(windows) { ';' } else { ':' };
497

            
498
/// Iterator over connect points returned by PathEntry::load().
499
enum ConnPtIterator<'a> {
500
    /// Iterator over a directory
501
    Dir(
502
        /// Origin of the directory
503
        ConnPtOrigin,
504
        /// The directory as configured
505
        CfgPath,
506
        /// Iterator over the elements loaded from the directory
507
        tor_rpc_connect::load::ConnPointIterator<'a>,
508
    ),
509
    /// A single connect point or error
510
    Singleton(ConnPtDescription, Result<ParsedConnectPoint, ConnectError>),
511
    /// An exhausted iterator
512
    Done,
513
}
514

            
515
impl<'a> Iterator for ConnPtIterator<'a> {
516
    // TODO RPC yield the pathbuf too, for better errors.
517
    type Item = (ConnPtDescription, Result<ParsedConnectPoint, ConnectError>);
518

            
519
    fn next(&mut self) -> Option<Self::Item> {
520
        let mut t = ConnPtIterator::Done;
521
        std::mem::swap(self, &mut t);
522
        match t {
523
            ConnPtIterator::Dir(source, cfgpath, mut iter) => {
524
                let next = iter
525
                    .next()
526
                    .map(|(path, res)| (path, res.map_err(|e| e.into())));
527
                let Some((expanded, result)) = next else {
528
                    *self = ConnPtIterator::Done;
529
                    return None;
530
                };
531
                let description = ConnPtDescription {
532
                    source,
533
                    location: ConnPtLocation::WithinDir {
534
                        path: cfgpath.clone(),
535
                        file: expanded,
536
                    },
537
                };
538
                *self = ConnPtIterator::Dir(source, cfgpath, iter);
539
                Some((description, result))
540
            }
541
            ConnPtIterator::Singleton(desc, res) => Some((desc, res)),
542
            ConnPtIterator::Done => None,
543
        }
544
    }
545
}