#![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))]
#![doc = include_str!("../README.md")]
// @@ begin lint list maintained by maint/add_warning @@
#![allow(renamed_and_removed_lints)] // @@REMOVE_WHEN(ci_arti_stable)
#![allow(unknown_lints)] // @@REMOVE_WHEN(ci_arti_nightly)
#![warn(missing_docs)]
#![warn(noop_method_call)]
#![warn(unreachable_pub)]
#![warn(clippy::all)]
#![deny(clippy::await_holding_lock)]
#![deny(clippy::cargo_common_metadata)]
#![deny(clippy::cast_lossless)]
#![deny(clippy::checked_conversions)]
#![warn(clippy::cognitive_complexity)]
#![deny(clippy::debug_assert_with_mut_call)]
#![deny(clippy::exhaustive_enums)]
#![deny(clippy::exhaustive_structs)]
#![deny(clippy::expl_impl_clone_on_copy)]
#![deny(clippy::fallible_impl_from)]
#![deny(clippy::implicit_clone)]
#![deny(clippy::large_stack_arrays)]
#![warn(clippy::manual_ok_or)]
#![deny(clippy::missing_docs_in_private_items)]
#![warn(clippy::needless_borrow)]
#![warn(clippy::needless_pass_by_value)]
#![warn(clippy::option_option)]
#![deny(clippy::print_stderr)]
#![deny(clippy::print_stdout)]
#![warn(clippy::rc_buffer)]
#![deny(clippy::ref_option_ref)]
#![warn(clippy::semicolon_if_nothing_returned)]
#![warn(clippy::trait_duplication_in_bounds)]
#![deny(clippy::unchecked_duration_subtraction)]
#![deny(clippy::unnecessary_wraps)]
#![warn(clippy::unseparated_literal_suffix)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::mod_module_files)]
#![allow(clippy::let_unit_value)] // This can reasonably be done for explicitness
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::significant_drop_in_scrutinee)] // arti/-/merge_requests/588/#note_2812945
#![allow(clippy::result_large_err)] // temporary workaround for arti#587
#![allow(clippy::needless_raw_string_hashes)] // complained-about code is fine, often best
#![allow(clippy::needless_lifetimes)] // See arti#1765
#![allow(mismatched_lifetime_syntaxes)] // temporary workaround for arti#2060
//! <!-- @@ end lint list maintained by maint/add_warning @@ -->

// TODO probably remove this at some point - see tpo/core/arti#1060
#![cfg_attr(
    not(all(feature = "full", feature = "experimental")),
    allow(unused_imports)
)]

mod err;
pub mod request;
mod response;
mod util;

use tor_circmgr::{CircMgr, DirInfo};
use tor_error::bad_api_usage;
use tor_rtcompat::{Runtime, SleepProvider, SleepProviderExt};

// Zlib is required; the others are optional.
#[cfg(feature = "xz")]
use async_compression::futures::bufread::XzDecoder;
use async_compression::futures::bufread::ZlibDecoder;
#[cfg(feature = "zstd")]
use async_compression::futures::bufread::ZstdDecoder;

use futures::FutureExt;
use futures::io::{
    AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader,
};
use memchr::memchr;
use std::sync::Arc;
use std::time::Duration;
use tracing::info;

pub use err::{Error, RequestError, RequestFailedError};
pub use response::{DirResponse, SourceInfo};

/// Type for results returned in this crate.
pub type Result<T> = std::result::Result<T, Error>;

/// Type for internal results containing a RequestError.
pub type RequestResult<T> = std::result::Result<T, RequestError>;

/// Flag to declare whether a request is anonymized or not.
///
/// Some requests (like those to download onion service descriptors) are always
/// anonymized, and should never be sent in a way that leaks information about
/// our settings or configuration.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum AnonymizedRequest {
    /// This request should not leak any information about our configuration.
    Anonymized,
    /// This request is allowed to include information about our capabilities.
    Direct,
}

/// Fetch the resource described by `req` over the Tor network.
///
/// Circuits are built or found using `circ_mgr`, using paths
/// constructed using `dirinfo`.
///
/// For more fine-grained control over the circuit and stream used,
/// construct them yourself, and then call [`send_request`] instead.
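///
/// A rough usage sketch (not compiled as a doctest): `dirinfo`, `runtime`,
/// and `circmgr` are assumed to come from your existing `CircMgr` setup, and
/// `md_digests` stands in for an iterator of microdescriptor digests.
///
/// ```ignore
/// use tor_dirclient::{get_resource, request};
///
/// let req: request::MicrodescRequest = md_digests.into_iter().collect();
/// let response = get_resource(&req, dirinfo, &runtime, circmgr).await?;
/// if response.status_code() == 200 {
///     // The caller is still responsible for parsing and validating the body.
///     let body = response.into_output_unchecked();
/// }
/// ```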
///
/// # TODO
///
/// This is the only function in this crate that knows about CircMgr and
/// DirInfo.  Perhaps this function should move up a level into DirMgr?
pub async fn get_resource<CR, R, SP>(
    req: &CR,
    dirinfo: DirInfo<'_>,
    runtime: &SP,
    circ_mgr: Arc<CircMgr<R>>,
) -> Result<DirResponse>
where
    CR: request::Requestable + ?Sized,
    R: Runtime,
    SP: SleepProvider,
{
    let tunnel = circ_mgr.get_or_launch_dir(dirinfo).await?;

    if req.anonymized() == AnonymizedRequest::Anonymized {
        return Err(bad_api_usage!("Tried to use get_resource for an anonymized request").into());
    }

    // TODO(nickm) This should be an option, and is too long.
    let begin_timeout = Duration::from_secs(5);
    let source = match SourceInfo::from_tunnel(&tunnel) {
        Ok(source) => source,
        Err(e) => {
            return Err(Error::RequestFailed(RequestFailedError {
                source: None,
                error: e.into(),
            }));
        }
    };

    let wrap_err = |error| {
        Error::RequestFailed(RequestFailedError {
            source: source.clone(),
            error,
        })
    };

    req.check_circuit(&tunnel).await.map_err(wrap_err)?;

    // Launch the stream.
    let mut stream = runtime
        .timeout(begin_timeout, tunnel.begin_dir_stream())
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?
        .map_err(RequestError::from)
        .map_err(wrap_err)?; // TODO(nickm) handle fatalities here too

    // TODO: Perhaps we want separate timeouts for each phase of this.
    // For now, we just use higher-level timeouts in `dirmgr`.
    let r = send_request(runtime, req, &mut stream, source.clone()).await;

    if should_retire_circ(&r) {
        retire_circ(&circ_mgr, &tunnel.unique_id(), "Partial response");
    }

    r
}

/// Return true if `result` holds an error indicating that we should retire the
/// circuit used for the corresponding request.
fn should_retire_circ(result: &Result<DirResponse>) -> bool {
    match result {
        Err(e) => e.should_retire_circ(),
        Ok(dr) => dr.error().map(RequestError::should_retire_circ) == Some(true),
    }
}

/// Fetch a Tor directory object from a provided stream.
#[deprecated(since = "0.8.1", note = "Use send_request instead.")]
pub async fn download<R, S, SP>(
    runtime: &SP,
    req: &R,
    stream: &mut S,
    source: Option<SourceInfo>,
) -> Result<DirResponse>
where
    R: request::Requestable + ?Sized,
    S: AsyncRead + AsyncWrite + Send + Unpin,
    SP: SleepProvider,
{
    send_request(runtime, req, stream, source).await
}

/// Fetch or upload a Tor directory object using the provided stream.
///
/// To do this, we send a simple HTTP/1.0 request for the described
/// object in `req` over `stream`, and then wait for a response.  In
/// log messages, we describe the origin of the data as coming from
/// `source`.
///
/// # Notes
///
/// It's kind of bogus to have a 'source' field here at all; we may
/// eventually want to remove it.
///
/// This function doesn't close the stream; you may want to do that
/// yourself.
///
/// The only error variant returned is [`Error::RequestFailed`].
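///
/// A rough usage sketch (not compiled as a doctest): `runtime` is any
/// [`SleepProvider`], `stream` is assumed to be an already-open directory
/// stream (for example, one obtained via `begin_dir_stream` on a suitable
/// circuit), and `md_digests` stands in for an iterator of microdescriptor
/// digests.
///
/// ```ignore
/// let req: request::MicrodescRequest = md_digests.into_iter().collect();
/// let response = send_request(&runtime, &req, &mut stream, None).await?;
/// assert_eq!(response.status_code(), 200);
/// ```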
// TODO: should the error return type change to `RequestFailedError`?
// If so, that would simplify some code in dirmgr::bridgedesc.
pub async fn send_request<R, S, SP>(
    runtime: &SP,
    req: &R,
    stream: &mut S,
    source: Option<SourceInfo>,
) -> Result<DirResponse>
where
    R: request::Requestable + ?Sized,
    S: AsyncRead + AsyncWrite + Send + Unpin,
    SP: SleepProvider,
{
    let wrap_err = |error| {
        Error::RequestFailed(RequestFailedError {
            source: source.clone(),
            error,
        })
    };

    let partial_ok = req.partial_response_body_ok();
    let maxlen = req.max_response_len();
    let anonymized = req.anonymized();
    let req = req.make_request().map_err(wrap_err)?;
    let encoded = util::encode_request(&req);

    // Write the request.
    stream
        .write_all(encoded.as_bytes())
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?;
    stream
        .flush()
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?;

    let mut buffered = BufReader::new(stream);

    // Handle the response
    // TODO: should there be a separate timeout here?
    let header = read_headers(&mut buffered).await.map_err(wrap_err)?;
    if header.status != Some(200) {
        return Ok(DirResponse::new(
            header.status.unwrap_or(0),
            header.status_message,
            None,
            vec![],
            source,
        ));
    }

    let mut decoder =
        get_decoder(buffered, header.encoding.as_deref(), anonymized).map_err(wrap_err)?;

    let mut result = Vec::new();
    let ok = read_and_decompress(runtime, &mut decoder, maxlen, &mut result).await;

    let ok = match (partial_ok, ok, result.len()) {
        (true, Err(e), n) if n > 0 => {
            // Note that we _don't_ return here: we want the partial response.
            Err(e)
        }
        (_, Err(e), _) => {
            return Err(wrap_err(e));
        }
        (_, Ok(()), _) => Ok(()),
    };

    Ok(DirResponse::new(200, None, ok.err(), result, source))
}

/// Maximum length for the HTTP headers in a single request or response.
///
/// Chosen more or less arbitrarily.
const MAX_HEADERS_LEN: usize = 16384;

/// Read and parse HTTP/1 headers from `stream`.
async fn read_headers<S>(stream: &mut S) -> RequestResult<HeaderStatus>
where
    S: AsyncBufRead + Unpin,
{
    let mut buf = Vec::with_capacity(1024);

    loop {
        // TODO: it's inefficient to do this a line at a time; it would
        // probably be better to read until the CRLF CRLF ending of the
        // response.  But this should be fast enough.
        let n = read_until_limited(stream, b'\n', 2048, &mut buf).await?;

        // TODO(nickm): Better maximum and/or let this expand.
        let mut headers = [httparse::EMPTY_HEADER; 32];
        let mut response = httparse::Response::new(&mut headers);

        match response.parse(&buf[..])? {
            httparse::Status::Partial => {
                // We didn't get a whole response; we may need to try again.

                if n == 0 {
                    // We hit an EOF; no more progress can be made.
                    return Err(RequestError::TruncatedHeaders);
                }

                if buf.len() >= MAX_HEADERS_LEN {
                    return Err(RequestError::HeadersTooLong(buf.len()));
                }
            }
            httparse::Status::Complete(n_parsed) => {
                if response.code != Some(200) {
                    return Ok(HeaderStatus {
                        status: response.code,
                        status_message: response.reason.map(str::to_owned),
                        encoding: None,
                    });
                }
                let encoding = if let Some(enc) = response
                    .headers
                    .iter()
                    .find(|h| h.name == "Content-Encoding")
                {
                    Some(String::from_utf8(enc.value.to_vec())?)
                } else {
                    None
                };
                /*
                if let Some(clen) = response.headers.iter().find(|h| h.name == "Content-Length") {
                    let clen = std::str::from_utf8(clen.value)?;
                    length = Some(clen.parse()?);
                }
                 */
                assert!(n_parsed == buf.len());
                return Ok(HeaderStatus {
                    status: Some(200),
                    status_message: None,
                    encoding,
                });
            }
        }
        if n == 0 {
            return Err(RequestError::TruncatedHeaders);
        }
    }
}

/// Return value from read_headers
#[derive(Debug, Clone)]
struct HeaderStatus {
    /// HTTP status code.
    status: Option<u16>,
    /// HTTP status message associated with the status code.
    status_message: Option<String>,
    /// The Content-Encoding header, if any.
    encoding: Option<String>,
}

/// Helper: download directory information from `stream` and
/// decompress it into a result buffer.  Assumes that `result` is empty.
///
/// If we get more than maxlen bytes after decompression, give an error.
///
/// Returns the status of our download attempt, and stores any data that
/// we were able to download into `result`.  Existing contents of
/// `result` are overwritten.
async fn read_and_decompress<S, SP>(
    runtime: &SP,
    mut stream: S,
    maxlen: usize,
    result: &mut Vec<u8>,
) -> RequestResult<()>
where
    S: AsyncRead + Unpin,
    SP: SleepProvider,
{
    let buffer_window_size = 1024;
    let mut written_total: usize = 0;
    // TODO(nickm): This should be an option, and is maybe too long.
    // Though for some users it may be too short?
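    // Note that the timer below is created once, before the read loop, so
    // `read_timeout` bounds the entire download, not each individual read.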
    let read_timeout = Duration::from_secs(10);
    let timer = runtime.sleep(read_timeout).fuse();
    futures::pin_mut!(timer);

    loop {
        // allocate buffer for next read
        result.resize(written_total + buffer_window_size, 0);
        let buf: &mut [u8] = &mut result[written_total..written_total + buffer_window_size];

        let status = futures::select! {
            status = stream.read(buf).fuse() => status,
            _ = timer => {
                result.resize(written_total, 0); // truncate as needed
                return Err(RequestError::DirTimeout);
            }
        };
        let written_in_this_loop = match status {
            Ok(n) => n,
            Err(other) => {
                result.resize(written_total, 0); // truncate as needed
                return Err(other.into());
            }
        };

        written_total += written_in_this_loop;

        // exit conditions below

        if written_in_this_loop == 0 {
            /*
            If the last `read` returned fewer than `buffer_window_size` bytes,
            we need to shrink `result`; otherwise we would return the
            still-zeroed bytes that were never read into.
            */
            if written_total < result.len() {
                result.resize(written_total, 0);
            }
            return Ok(());
        }

        // TODO: It would be good to detect compression bombs, but
        // that would require access to the internal stream, which
        // would in turn require some tricky programming.  For now, we
        // use the maximum length here to prevent an attacker from
        // filling our RAM.
        if written_total > maxlen {
            result.resize(maxlen, 0);
            return Err(RequestError::ResponseTooLong(written_total));
        }
    }
}

/// Retire a directory circuit because of an error we've encountered on it.
fn retire_circ<R>(circ_mgr: &Arc<CircMgr<R>>, id: &tor_proto::circuit::UniqId, error: &str)
where
    R: Runtime,
{
    info!(
        "{}: Retiring circuit because of directory failure: {}",
        &id, &error
    );
    circ_mgr.retire_circ(id);
}

/// As AsyncBufReadExt::read_until, but stops after reading `max` bytes.
///
/// Note that this function might not actually read any byte of value
/// `byte`, since EOF might occur, or we might fill the buffer.
///
/// A return value of 0 indicates an end-of-file.
async fn read_until_limited<S>(
    stream: &mut S,
    byte: u8,
    max: usize,
    buf: &mut Vec<u8>,
) -> std::io::Result<usize>
where
    S: AsyncBufRead + Unpin,
{
    let mut n_added = 0;
    loop {
        let data = stream.fill_buf().await?;
        if data.is_empty() {
            // End-of-file has been reached.
            return Ok(n_added);
        }
        debug_assert!(n_added < max);
        let remaining_space = max - n_added;
        let (available, found_byte) = match memchr(byte, data) {
            Some(idx) => (idx + 1, true),
            None => (data.len(), false),
        };
        debug_assert!(available >= 1);
        let n_to_copy = std::cmp::min(remaining_space, available);
        buf.extend(&data[..n_to_copy]);
        stream.consume_unpin(n_to_copy);
        n_added += n_to_copy;
        if found_byte || n_added == max {
            return Ok(n_added);
        }
    }
}

/// Helper: Return a boxed decoder object that wraps the stream `$s`.
macro_rules! decoder {
    ($dec:ident, $s:expr) => {{
        let mut decoder = $dec::new($s);
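        // Keep decoding if further compressed members are concatenated after
        // the first one, rather than stopping at the first end-of-member marker.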
        decoder.multiple_members(true);
        Ok(Box::new(decoder))
    }};
}

/// Wrap `stream` in an appropriate type to undo the content encoding
/// as described in `encoding`.
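///
/// A missing or `"identity"` encoding passes the stream through unchanged, and
/// `"deflate"` is always accepted; the optional `"x-tor-lzma"` and `"x-zstd"`
/// encodings are honored only for [`AnonymizedRequest::Direct`] requests when
/// the corresponding feature is enabled, and any other encoding is rejected
/// with [`RequestError::ContentEncoding`].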
fn get_decoder<'a, S: AsyncBufRead + Unpin + Send + 'a>(
    stream: S,
    encoding: Option<&str>,
    anonymized: AnonymizedRequest,
) -> RequestResult<Box<dyn AsyncRead + Unpin + Send + 'a>> {
    use AnonymizedRequest::Direct;
    match (encoding, anonymized) {
        (None | Some("identity"), _) => Ok(Box::new(stream)),
        (Some("deflate"), _) => decoder!(ZlibDecoder, stream),
        // We only admit to supporting these on a direct connection; otherwise,
        // a hostile directory could send them back even though we hadn't
        // requested them.
        #[cfg(feature = "xz")]
        (Some("x-tor-lzma"), Direct) => decoder!(XzDecoder, stream),
        #[cfg(feature = "zstd")]
        (Some("x-zstd"), Direct) => decoder!(ZstdDecoder, stream),
        (Some(other), _) => Err(RequestError::ContentEncoding(other.into())),
    }
}

#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use tor_rtmock::io::stream_pair;

    #[allow(deprecated)] // TODO #1885
    use tor_rtmock::time::MockSleepProvider;

    use futures_await_test::async_test;

    #[async_test]
    async fn test_read_until_limited() -> RequestResult<()> {
        let mut out = Vec::new();
        let bytes = b"This line eventually ends\nthen comes another\n";

        // Case 1: find a whole line.
        let mut s = &bytes[..];
        let res = read_until_limited(&mut s, b'\n', 100, &mut out).await;
        assert_eq!(res?, 26);
        assert_eq!(&out[..], b"This line eventually ends\n");

        // Case 2: reach the limit.
        let mut s = &bytes[..];
        out.clear();
        let res = read_until_limited(&mut s, b'\n', 10, &mut out).await;
        assert_eq!(res?, 10);
        assert_eq!(&out[..], b"This line ");

        // Case 3: reach EOF.
        let mut s = &bytes[..];
        out.clear();
        let res = read_until_limited(&mut s, b'Z', 100, &mut out).await;
        assert_eq!(res?, 45);
        assert_eq!(&out[..], &bytes[..]);

        Ok(())
    }

    // Basic decompression wrapper.
    async fn decomp_basic(
        encoding: Option<&str>,
        data: &[u8],
        maxlen: usize,
    ) -> (RequestResult<()>, Vec<u8>) {
        // We don't need to do anything fancy here, since we aren't simulating
        // a timeout.
        #[allow(deprecated)] // TODO #1885
        let mock_time = MockSleepProvider::new(std::time::SystemTime::now());

        let mut output = Vec::new();
        let mut stream = match get_decoder(data, encoding, AnonymizedRequest::Direct) {
            Ok(s) => s,
            Err(e) => return (Err(e), output),
        };

        let r = read_and_decompress(&mock_time, &mut stream, maxlen, &mut output).await;

        (r, output)
    }

    #[async_test]
    async fn decompress_identity() -> RequestResult<()> {
        let mut text = Vec::new();
        for _ in 0..1000 {
            text.extend(b"This is a string with a nontrivial length that we'll use to make sure that the loop is executed more than once.");
        }

        let limit = 10 << 20;
        let (s, r) = decomp_basic(None, &text[..], limit).await;
        s?;
        assert_eq!(r, text);

        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
        s?;
        assert_eq!(r, text);

        // Try truncated result
        let limit = 100;
        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
        assert!(s.is_err());
        assert_eq!(r, &text[..100]);

        Ok(())
    }

    #[async_test]
    async fn decomp_zlib() -> RequestResult<()> {
        let compressed =
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap();

        let limit = 10 << 20;
        let (s, r) = decomp_basic(Some("deflate"), &compressed, limit).await;
        s?;
        assert_eq!(r, b"One fish Two fish Red fish Blue fish");

        Ok(())
    }

    #[cfg(feature = "zstd")]
    #[async_test]
    async fn decomp_zstd() -> RequestResult<()> {
        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
        let limit = 10 << 20;
        let (s, r) = decomp_basic(Some("x-zstd"), &compressed, limit).await;
        s?;
        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");

        Ok(())
    }

    #[cfg(feature = "xz")]
    #[async_test]
    async fn decomp_xz2() -> RequestResult<()> {
        // Not so good at tiny files...
        let compressed = hex::decode("fd377a585a000004e6d6b446020021011c00000010cf58cce00024001d5d00279b88a202ca8612cfb3c19c87c34248a570451e4851d3323d34ab8000000000000901af64854c91f600013925d6ec06651fb6f37d010000000004595a").unwrap();
        let limit = 10 << 20;
        let (s, r) = decomp_basic(Some("x-tor-lzma"), &compressed, limit).await;
        s?;
        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");

        Ok(())
    }

    #[async_test]
    async fn decomp_unknown() {
        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
        let limit = 10 << 20;
        let (s, _r) = decomp_basic(Some("x-proprietary-rle"), &compressed, limit).await;

        assert!(matches!(s, Err(RequestError::ContentEncoding(_))));
    }

    #[async_test]
    async fn decomp_bad_data() {
        let compressed = b"This is not good zlib data";
        let limit = 10 << 20;
        let (s, _r) = decomp_basic(Some("deflate"), compressed, limit).await;

        // This should possibly be a different type in the future.
        assert!(matches!(s, Err(RequestError::IoError(_))));
    }

    #[async_test]
    async fn headers_ok() -> RequestResult<()> {
        let text = b"HTTP/1.0 200 OK\r\nDate: ignored\r\nContent-Encoding: Waffles\r\n\r\n";

        let mut s = &text[..];
        let h = read_headers(&mut s).await?;

        assert_eq!(h.status, Some(200));
        assert_eq!(h.encoding.as_deref(), Some("Waffles"));

        // now try truncated
        let mut s = &text[..15];
        let h = read_headers(&mut s).await;
        assert!(matches!(h, Err(RequestError::TruncatedHeaders)));

        // now try with no encoding.
        let text = b"HTTP/1.0 404 Not found\r\n\r\n";
        let mut s = &text[..];
        let h = read_headers(&mut s).await?;

        assert_eq!(h.status, Some(404));
        assert!(h.encoding.is_none());

        Ok(())
    }

    #[async_test]
    async fn headers_bogus() -> Result<()> {
        let text = b"HTTP/999.0 WHAT EVEN\r\n\r\n";
        let mut s = &text[..];
        let h = read_headers(&mut s).await;

        assert!(h.is_err());
        assert!(matches!(h, Err(RequestError::HttparseError(_))));
        Ok(())
    }

    /// Run a trivial download example with a response provided as a binary
    /// string.
    ///
    /// Return the directory response (if any) and the request as encoded (if
    /// any).
    fn run_download_test<Req: request::Requestable>(
        req: Req,
        response: &[u8],
    ) -> (Result<DirResponse>, RequestResult<Vec<u8>>) {
        let (mut s1, s2) = stream_pair();
        let (mut s2_r, mut s2_w) = s2.split();

        tor_rtcompat::test_with_one_runtime!(|rt| async move {
            let rt2 = rt.clone();
            let (v1, v2, v3): (
                Result<DirResponse>,
                RequestResult<Vec<u8>>,
                RequestResult<()>,
            ) = futures::join!(
                async {
                    // Run the download function.
                    let r = send_request(&rt, &req, &mut s1, None).await;
                    s1.close().await.map_err(|error| {
                        Error::RequestFailed(RequestFailedError {
                            source: None,
                            error: error.into(),
                        })
                    })?;
                    r
                },
                async {
                    // Take the request from the client, and return it in "v2"
                    let mut v = Vec::new();
                    s2_r.read_to_end(&mut v).await?;
                    Ok(v)
                },
                async {
                    // Send back a response.
                    s2_w.write_all(response).await?;
                    // We wait a moment to give the other side time to notice it
                    // has data.
                    //
                    // (Tentative diagnosis: The `async-compression` crate seems
                    // to behave differently depending on whether the "close"
                    // comes right after the incomplete data or whether it comes
                    // after a delay.  If there's a delay, it notices the
                    // truncated data and tells us about it. But when there's
                    // _no_ delay, it treats the data as an error and doesn't
                    // tell our code.)

                    // TODO: sleeping in tests is not great.
                    rt2.sleep(Duration::from_millis(50)).await;
                    s2_w.close().await?;
                    Ok(())
                }
            );

            assert!(v3.is_ok());

            (v1, v2)
        })
    }

    #[test]
    fn test_send_request() -> RequestResult<()> {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();

        let (response, request) = run_download_test(
            req,
            b"HTTP/1.0 200 OK\r\n\r\nThis is where the descs would go.",
        );

        let request = request?;
        assert!(request[..].starts_with(
            b"GET /tor/micro/d/CQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQk HTTP/1.0\r\n"
        ));

        let response = response.unwrap();
        assert_eq!(response.status_code(), 200);
        assert!(!response.is_partial());
        assert!(response.error().is_none());
        assert!(response.source().is_none());
        let out_ref = response.output_unchecked();
        assert_eq!(out_ref, b"This is where the descs would go.");
        let out = response.into_output_unchecked();
        assert_eq!(&out, b"This is where the descs would go.");

        Ok(())
    }

    #[test]
    fn test_download_truncated() {
        // Request only one md, so "partial ok" will not be set.
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let mut response_text: Vec<u8> =
            (*b"HTTP/1.0 200 OK\r\nContent-Encoding: deflate\r\n\r\n").into();
        // "One fish two fish" as above twice, but truncated the second time
        response_text.extend(
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap(),
        );
        response_text.extend(
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5").unwrap(),
        );
        let (response, request) = run_download_test(req, &response_text);
        assert!(request.is_ok());
        assert!(response.is_err()); // The whole download should fail, since partial_ok wasn't set.

        // request two microdescs, so "partial_ok" will be set.
        let req: request::MicrodescRequest = vec![[9; 32]; 2].into_iter().collect();

        let (response, request) = run_download_test(req, &response_text);
        assert!(request.is_ok());

        let response = response.unwrap();
        assert_eq!(response.status_code(), 200);
        assert!(response.error().is_some());
        assert!(response.is_partial());
        assert!(response.output_unchecked().len() < 37 * 2);
        assert!(response.output_unchecked().starts_with(b"One fish"));
    }

    #[test]
    fn test_404() {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let response_text = b"HTTP/1.0 418 I'm a teapot\r\n\r\n";
        let (response, _request) = run_download_test(req, response_text);

        assert_eq!(response.unwrap().status_code(), 418);
    }

    #[test]
    fn test_headers_truncated() {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let response_text = b"HTTP/1.0 404 truncation happens here\r\n";
        let (response, _request) = run_download_test(req, response_text);

        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::TruncatedHeaders,
                ..
            }))
        ));

        // Try a completely empty response.
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let response_text = b"";
        let (response, _request) = run_download_test(req, response_text);

        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::TruncatedHeaders,
                ..
            }))
        ));
    }

    #[test]
    fn test_headers_too_long() {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let mut response_text: Vec<u8> = (*b"HTTP/1.0 418 I'm a teapot\r\nX-Too-Many-As: ").into();
        response_text.resize(16384, b'A');
        let (response, _request) = run_download_test(req, &response_text);

        assert!(response.as_ref().unwrap_err().should_retire_circ());
        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::HeadersTooLong(_),
                ..
            }))
        ));
    }

    // TODO: test with bad utf-8
}