#![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))]
#![doc = include_str!("../README.md")]
// @@ begin lint list maintained by maint/add_warning @@
#![allow(renamed_and_removed_lints)] // @@REMOVE_WHEN(ci_arti_stable)
#![allow(unknown_lints)] // @@REMOVE_WHEN(ci_arti_nightly)
#![warn(missing_docs)]
#![warn(noop_method_call)]
#![warn(unreachable_pub)]
#![warn(clippy::all)]
#![deny(clippy::await_holding_lock)]
#![deny(clippy::cargo_common_metadata)]
#![deny(clippy::cast_lossless)]
#![deny(clippy::checked_conversions)]
#![warn(clippy::cognitive_complexity)]
#![deny(clippy::debug_assert_with_mut_call)]
#![deny(clippy::exhaustive_enums)]
#![deny(clippy::exhaustive_structs)]
#![deny(clippy::expl_impl_clone_on_copy)]
#![deny(clippy::fallible_impl_from)]
#![deny(clippy::implicit_clone)]
#![deny(clippy::large_stack_arrays)]
#![warn(clippy::manual_ok_or)]
#![deny(clippy::missing_docs_in_private_items)]
#![warn(clippy::needless_borrow)]
#![warn(clippy::needless_pass_by_value)]
#![warn(clippy::option_option)]
#![deny(clippy::print_stderr)]
#![deny(clippy::print_stdout)]
#![warn(clippy::rc_buffer)]
#![deny(clippy::ref_option_ref)]
#![warn(clippy::semicolon_if_nothing_returned)]
#![warn(clippy::trait_duplication_in_bounds)]
#![deny(clippy::unchecked_duration_subtraction)]
#![deny(clippy::unnecessary_wraps)]
#![warn(clippy::unseparated_literal_suffix)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::mod_module_files)]
#![allow(clippy::let_unit_value)] // This can reasonably be done for explicitness
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::significant_drop_in_scrutinee)] // arti/-/merge_requests/588/#note_2812945
#![allow(clippy::result_large_err)] // temporary workaround for arti#587
#![allow(clippy::needless_raw_string_hashes)] // complained-about code is fine, often best
#![allow(clippy::needless_lifetimes)] // See arti#1765
//! <!-- @@ end lint list maintained by maint/add_warning @@ -->

// TODO probably remove this at some point - see tpo/core/arti#1060
#![cfg_attr(
    not(all(feature = "full", feature = "experimental")),
    allow(unused_imports)
)]

mod err;
pub mod request;
mod response;
mod util;

use tor_circmgr::{CircMgr, DirInfo};
use tor_error::bad_api_usage;
use tor_rtcompat::{Runtime, SleepProvider, SleepProviderExt};

// Zlib is required; the others are optional.
#[cfg(feature = "xz")]
use async_compression::futures::bufread::XzDecoder;
use async_compression::futures::bufread::ZlibDecoder;
#[cfg(feature = "zstd")]
use async_compression::futures::bufread::ZstdDecoder;

use futures::io::{
    AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader,
};
use futures::FutureExt;
use memchr::memchr;
use std::sync::Arc;
use std::time::Duration;
use tracing::info;

pub use err::{Error, RequestError, RequestFailedError};
pub use response::{DirResponse, SourceInfo};

/// Type for results returned in this crate.
pub type Result<T> = std::result::Result<T, Error>;

/// Type for internal results containing a `RequestError`.
pub type RequestResult<T> = std::result::Result<T, RequestError>;

/// Flag to declare whether a request is anonymized or not.
///
/// Some requests (like those to download onion service descriptors) are always
/// anonymized, and should never be sent in a way that leaks information about
/// our settings or configuration.
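///
/// # Example
///
/// A minimal sketch (not compiled here) of how this flag is consumed; `req`
/// stands for any `request::Requestable` value:
///
/// ```ignore
/// if req.anonymized() == AnonymizedRequest::Anonymized {
///     // This request must not go over a directory circuit we picked directly;
///     // `get_resource` rejects it with a bad_api_usage error.
/// }
/// ```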
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum AnonymizedRequest {
    /// This request should not leak any information about our configuration.
    Anonymized,
    /// This request is allowed to include information about our capabilities.
    Direct,
}

/// Fetch the resource described by `req` over the Tor network.
///
/// Circuits are built or found using `circ_mgr`, using paths
/// constructed using `dirinfo`.
///
/// For more fine-grained control over the circuit and stream used,
/// construct them yourself, and then call [`send_request`] instead.
///
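/// # Example
///
/// A minimal sketch, not compiled here: it assumes you already have a
/// runtime, an `Arc<CircMgr<_>>`, a `DirInfo` describing the directory to
/// use, and `md_digests`, a collection of 32-byte microdescriptor digests.
///
/// ```ignore
/// let req: request::MicrodescRequest = md_digests.into_iter().collect();
/// let response = get_resource(&req, dirinfo, &runtime, circ_mgr).await?;
/// assert_eq!(response.status_code(), 200);
/// let body = response.output_unchecked(); // the decompressed response body
/// ```
///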
/// # TODO
///
/// This is the only function in this crate that knows about CircMgr and
/// DirInfo.  Perhaps this function should move up a level into DirMgr?
pub async fn get_resource<CR, R, SP>(
    req: &CR,
    dirinfo: DirInfo<'_>,
    runtime: &SP,
    circ_mgr: Arc<CircMgr<R>>,
) -> Result<DirResponse>
where
    CR: request::Requestable + ?Sized,
    R: Runtime,
    SP: SleepProvider,
{
    let circuit = circ_mgr.get_or_launch_dir(dirinfo).await?;

    if req.anonymized() == AnonymizedRequest::Anonymized {
        return Err(bad_api_usage!("Tried to use get_resource for an anonymized request").into());
    }

    // TODO(nickm) This should be an option, and is too long.
    let begin_timeout = Duration::from_secs(5);
    let source = match SourceInfo::from_circuit(&circuit) {
        Ok(source) => source,
        Err(e) => {
            return Err(Error::RequestFailed(RequestFailedError {
                source: None,
                error: e.into(),
            }));
        }
    };

    let wrap_err = |error| {
        Error::RequestFailed(RequestFailedError {
            source: Some(source.clone()),
            error,
        })
    };

    req.check_circuit(&circuit).await.map_err(wrap_err)?;

    // Launch the stream.
    let mut stream = runtime
        .timeout(begin_timeout, circuit.begin_dir_stream())
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?
        .map_err(RequestError::from)
        .map_err(wrap_err)?; // TODO(nickm) handle fatalities here too

    // TODO: Perhaps we want separate timeouts for each phase of this.
    // For now, we just use higher-level timeouts in `dirmgr`.
    let r = send_request(runtime, req, &mut stream, Some(source.clone())).await;

    if should_retire_circ(&r) {
        retire_circ(&circ_mgr, &source, "Partial response");
    }

    r
}

/// Return true if `result` holds an error indicating that we should retire the
/// circuit used for the corresponding request.
fn should_retire_circ(result: &Result<DirResponse>) -> bool {
    match result {
        Err(e) => e.should_retire_circ(),
        Ok(dr) => dr.error().map(RequestError::should_retire_circ) == Some(true),
    }
}

/// Fetch a Tor directory object from a provided stream.
#[deprecated(since = "0.8.1", note = "Use send_request instead.")]
pub async fn download<R, S, SP>(
    runtime: &SP,
    req: &R,
    stream: &mut S,
    source: Option<SourceInfo>,
) -> Result<DirResponse>
where
    R: request::Requestable + ?Sized,
    S: AsyncRead + AsyncWrite + Send + Unpin,
    SP: SleepProvider,
{
    send_request(runtime, req, stream, source).await
}

/// Fetch or upload a Tor directory object using the provided stream.
///
/// To do this, we send a simple HTTP/1.0 request for the described
/// object in `req` over `stream`, and then wait for a response.  In
/// log messages, we describe the origin of the data as coming from
/// `source`.
///
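/// # Example
///
/// A minimal sketch, not compiled here: `runtime` is any `SleepProvider`, and
/// `stream` is a directory stream you opened yourself (for example with
/// `begin_dir_stream` on a circuit) that implements `AsyncRead + AsyncWrite`.
///
/// ```ignore
/// let req: request::MicrodescRequest = md_digests.into_iter().collect();
/// let response = send_request(&runtime, &req, &mut stream, None).await?;
/// if response.is_partial() {
///     // Only part of the body arrived before an error; it may still be usable.
/// }
/// ```
///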
/// # Notes
///
/// It's kind of bogus to have a 'source' field here at all; we may
/// eventually want to remove it.
///
/// This function doesn't close the stream; you may want to do that
/// yourself.
///
/// The only error variant returned is [`Error::RequestFailed`].
// TODO: should the error return type change to `RequestFailedError`?
// If so, that would simplify some code in `dirmgr::bridgedesc`.
pub async fn send_request<R, S, SP>(
    runtime: &SP,
    req: &R,
    stream: &mut S,
    source: Option<SourceInfo>,
) -> Result<DirResponse>
where
    R: request::Requestable + ?Sized,
    S: AsyncRead + AsyncWrite + Send + Unpin,
    SP: SleepProvider,
{
    let wrap_err = |error| {
        Error::RequestFailed(RequestFailedError {
            source: source.clone(),
            error,
        })
    };

    let partial_ok = req.partial_response_body_ok();
    let maxlen = req.max_response_len();
    let anonymized = req.anonymized();
    let req = req.make_request().map_err(wrap_err)?;
    let encoded = util::encode_request(&req);

    // Write the request.
    stream
        .write_all(encoded.as_bytes())
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?;
    stream
        .flush()
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?;

    let mut buffered = BufReader::new(stream);

    // Handle the response
    // TODO: should there be a separate timeout here?
    let header = read_headers(&mut buffered).await.map_err(wrap_err)?;
    if header.status != Some(200) {
        return Ok(DirResponse::new(
            header.status.unwrap_or(0),
            header.status_message,
            None,
            vec![],
            source,
        ));
    }

    let mut decoder =
        get_decoder(buffered, header.encoding.as_deref(), anonymized).map_err(wrap_err)?;

    let mut result = Vec::new();
    let ok = read_and_decompress(runtime, &mut decoder, maxlen, &mut result).await;

    let ok = match (partial_ok, ok, result.len()) {
        (true, Err(e), n) if n > 0 => {
            // Note that we _don't_ return here: we want the partial response.
            Err(e)
        }
        (_, Err(e), _) => {
            return Err(wrap_err(e));
        }
        (_, Ok(()), _) => Ok(()),
    };

    Ok(DirResponse::new(200, None, ok.err(), result, source))
}

/// Read and parse HTTP/1 headers from `stream`.
async fn read_headers<S>(stream: &mut S) -> RequestResult<HeaderStatus>
where
    S: AsyncBufRead + Unpin,
{
    let mut buf = Vec::with_capacity(1024);

    loop {
        // TODO: it's inefficient to do this a line at a time; it would
        // probably be better to read until the CRLF CRLF ending of the
        // response.  But this should be fast enough.
        let n = read_until_limited(stream, b'\n', 2048, &mut buf).await?;

        // TODO(nickm): Better maximum and/or let this expand.
        let mut headers = [httparse::EMPTY_HEADER; 32];
        let mut response = httparse::Response::new(&mut headers);

        match response.parse(&buf[..])? {
            httparse::Status::Partial => {
                // We didn't get a whole response; we may need to try again.

                if n == 0 {
                    // We hit an EOF; no more progress can be made.
                    return Err(RequestError::TruncatedHeaders);
                }

                // TODO(nickm): Pick a better maximum
                if buf.len() >= 16384 {
                    return Err(httparse::Error::TooManyHeaders.into());
                }
            }
            httparse::Status::Complete(n_parsed) => {
                if response.code != Some(200) {
                    return Ok(HeaderStatus {
                        status: response.code,
                        status_message: response.reason.map(str::to_owned),
                        encoding: None,
                    });
                }
                let encoding = if let Some(enc) = response
                    .headers
                    .iter()
                    .find(|h| h.name == "Content-Encoding")
                {
                    Some(String::from_utf8(enc.value.to_vec())?)
                } else {
                    None
                };
                /*
                if let Some(clen) = response.headers.iter().find(|h| h.name == "Content-Length") {
                    let clen = std::str::from_utf8(clen.value)?;
                    length = Some(clen.parse()?);
                }
                 */
                assert!(n_parsed == buf.len());
                return Ok(HeaderStatus {
                    status: Some(200),
                    status_message: None,
                    encoding,
                });
            }
        }
        if n == 0 {
            return Err(RequestError::TruncatedHeaders);
        }
    }
}

/// Return value from read_headers
#[derive(Debug, Clone)]
struct HeaderStatus {
    /// HTTP status code.
    status: Option<u16>,
    /// HTTP status message associated with the status code.
    status_message: Option<String>,
    /// The Content-Encoding header, if any.
    encoding: Option<String>,
}

/// Helper: download directory information from `stream` and
/// decompress it into a result buffer.  Assumes that `result` is empty.
///
/// If we get more than `maxlen` bytes after decompression, give an error.
///
/// Returns the status of our download attempt, stores any data that
/// we were able to download into `result`.  Existing contents of
/// `result` are overwritten.
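///
/// A minimal sketch, not compiled here: `runtime` is any `SleepProvider`, and
/// `reader` is typically the boxed decoder returned by `get_decoder`:
///
/// ```ignore
/// let mut body = Vec::new();
/// match read_and_decompress(&runtime, reader, maxlen, &mut body).await {
///     Ok(()) => { /* `body` holds the whole (decompressed) response body */ }
///     Err(e) => { /* `body` holds whatever prefix was read before the error */ }
/// }
/// ```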
async fn read_and_decompress<S, SP>(
    runtime: &SP,
    mut stream: S,
    maxlen: usize,
    result: &mut Vec<u8>,
) -> RequestResult<()>
where
    S: AsyncRead + Unpin,
    SP: SleepProvider,
{
    let buffer_window_size = 1024;
    let mut written_total: usize = 0;
    // TODO(nickm): This should be an option, and is maybe too long.
    // Though for some users it may be too short?
    let read_timeout = Duration::from_secs(10);
    let timer = runtime.sleep(read_timeout).fuse();
    futures::pin_mut!(timer);

    loop {
        // allocate buffer for next read
        result.resize(written_total + buffer_window_size, 0);
        let buf: &mut [u8] = &mut result[written_total..written_total + buffer_window_size];

        let status = futures::select! {
            status = stream.read(buf).fuse() => status,
            _ = timer => {
                result.resize(written_total, 0); // truncate as needed
                return Err(RequestError::DirTimeout);
            }
        };
        let written_in_this_loop = match status {
            Ok(n) => n,
            Err(other) => {
                result.resize(written_total, 0); // truncate as needed
                return Err(other.into());
            }
        };

        written_total += written_in_this_loop;

        // exit conditions below

        if written_in_this_loop == 0 {
            /*
            in case we read less than `buffer_window_size` in last `read`
            we need to shrink result because otherwise we'll return those
            un-read 0s
            */
            if written_total < result.len() {
                result.resize(written_total, 0);
            }
            return Ok(());
        }

        // TODO: It would be good to detect compression bombs, but
        // that would require access to the internal stream, which
        // would in turn require some tricky programming.  For now, we
        // use the maximum length here to prevent an attacker from
        // filling our RAM.
        if written_total > maxlen {
            result.resize(maxlen, 0);
            return Err(RequestError::ResponseTooLong(written_total));
        }
    }
}

/// Retire a directory circuit because of an error we've encountered on it.
fn retire_circ<R>(circ_mgr: &Arc<CircMgr<R>>, source_info: &SourceInfo, error: &str)
where
    R: Runtime,
{
    let id = source_info.unique_circ_id();
    info!(
        "{}: Retiring circuit because of directory failure: {}",
        &id, &error
    );
    circ_mgr.retire_circ(id);
}

/// As AsyncBufReadExt::read_until, but stops after reading `max` bytes.
///
/// Note that this function might not actually read any byte of value
/// `byte`, since EOF might occur, or we might fill the buffer.
///
/// A return value of 0 indicates an end-of-file.
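///
/// A minimal sketch, not compiled here, mirroring how `read_headers` uses it:
///
/// ```ignore
/// let mut line = Vec::new();
/// // Read up to and including the first b'\n', but never more than 2048 bytes.
/// let n = read_until_limited(&mut stream, b'\n', 2048, &mut line).await?;
/// ```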
async fn read_until_limited<S>(
    stream: &mut S,
    byte: u8,
    max: usize,
    buf: &mut Vec<u8>,
) -> std::io::Result<usize>
where
    S: AsyncBufRead + Unpin,
{
    let mut n_added = 0;
    loop {
        let data = stream.fill_buf().await?;
        if data.is_empty() {
            // End-of-file has been reached.
            return Ok(n_added);
        }
        debug_assert!(n_added < max);
        let remaining_space = max - n_added;
        let (available, found_byte) = match memchr(byte, data) {
            Some(idx) => (idx + 1, true),
            None => (data.len(), false),
        };
        debug_assert!(available >= 1);
        let n_to_copy = std::cmp::min(remaining_space, available);
        buf.extend(&data[..n_to_copy]);
        stream.consume_unpin(n_to_copy);
        n_added += n_to_copy;
        if found_byte || n_added == max {
            return Ok(n_added);
        }
    }
}

/// Helper: Return a boxed decoder object that wraps the stream `$s`.
macro_rules! decoder {
    ($dec:ident, $s:expr) => {{
        let mut decoder = $dec::new($s);
        decoder.multiple_members(true);
        Ok(Box::new(decoder))
    }};
}

/// Wrap `stream` in an appropriate type to undo the content encoding
/// as described in `encoding`.
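///
/// A minimal sketch, not compiled here: `buffered` is any buffered reader over
/// the response body (for example the `BufReader` built in `send_request`):
///
/// ```ignore
/// let reader = get_decoder(buffered, Some("deflate"), AnonymizedRequest::Direct)?;
/// // Reading from `reader` now yields the decompressed bytes.
/// ```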
fn get_decoder<'a, S: AsyncBufRead + Unpin + Send + 'a>(
    stream: S,
    encoding: Option<&str>,
    anonymized: AnonymizedRequest,
) -> RequestResult<Box<dyn AsyncRead + Unpin + Send + 'a>> {
    use AnonymizedRequest::Direct;
    match (encoding, anonymized) {
        (None | Some("identity"), _) => Ok(Box::new(stream)),
        (Some("deflate"), _) => decoder!(ZlibDecoder, stream),
        // We only admit to supporting these on a direct connection; otherwise,
        // a hostile directory could send them back even though we hadn't
        // requested them.
        #[cfg(feature = "xz")]
        (Some("x-tor-lzma"), Direct) => decoder!(XzDecoder, stream),
        #[cfg(feature = "zstd")]
        (Some("x-zstd"), Direct) => decoder!(ZstdDecoder, stream),
        (Some(other), _) => Err(RequestError::ContentEncoding(other.into())),
    }
}

#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use tor_rtmock::io::stream_pair;

    #[allow(deprecated)] // TODO #1885
    use tor_rtmock::time::MockSleepProvider;

    use futures_await_test::async_test;

    #[async_test]
    async fn test_read_until_limited() -> RequestResult<()> {
        let mut out = Vec::new();
        let bytes = b"This line eventually ends\nthen comes another\n";

        // Case 1: find a whole line.
        let mut s = &bytes[..];
        let res = read_until_limited(&mut s, b'\n', 100, &mut out).await;
        assert_eq!(res?, 26);
        assert_eq!(&out[..], b"This line eventually ends\n");

        // Case 2: reach the limit.
        let mut s = &bytes[..];
        out.clear();
        let res = read_until_limited(&mut s, b'\n', 10, &mut out).await;
        assert_eq!(res?, 10);
        assert_eq!(&out[..], b"This line ");

        // Case 3: reach EOF.
        let mut s = &bytes[..];
        out.clear();
        let res = read_until_limited(&mut s, b'Z', 100, &mut out).await;
        assert_eq!(res?, 45);
        assert_eq!(&out[..], &bytes[..]);

        Ok(())
    }

    // Basic decompression wrapper.
    async fn decomp_basic(
        encoding: Option<&str>,
        data: &[u8],
        maxlen: usize,
    ) -> (RequestResult<()>, Vec<u8>) {
        // We don't need to do anything fancy here, since we aren't simulating
        // a timeout.
        #[allow(deprecated)] // TODO #1885
        let mock_time = MockSleepProvider::new(std::time::SystemTime::now());

        let mut output = Vec::new();
        let mut stream = match get_decoder(data, encoding, AnonymizedRequest::Direct) {
            Ok(s) => s,
            Err(e) => return (Err(e), output),
        };

        let r = read_and_decompress(&mock_time, &mut stream, maxlen, &mut output).await;

        (r, output)
    }

    #[async_test]
    async fn decompress_identity() -> RequestResult<()> {
        let mut text = Vec::new();
        for _ in 0..1000 {
            text.extend(b"This is a string with a nontrivial length that we'll use to make sure that the loop is executed more than once.");
        }

        let limit = 10 << 20;
        let (s, r) = decomp_basic(None, &text[..], limit).await;
        s?;
        assert_eq!(r, text);

        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
        s?;
        assert_eq!(r, text);

        // Try truncated result
        let limit = 100;
        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
        assert!(s.is_err());
        assert_eq!(r, &text[..100]);

        Ok(())
    }

    #[async_test]
    async fn decomp_zlib() -> RequestResult<()> {
        let compressed =
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap();

        let limit = 10 << 20;
        let (s, r) = decomp_basic(Some("deflate"), &compressed, limit).await;
        s?;
        assert_eq!(r, b"One fish Two fish Red fish Blue fish");

        Ok(())
    }

    #[cfg(feature = "zstd")]
    #[async_test]
    async fn decomp_zstd() -> RequestResult<()> {
        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
        let limit = 10 << 20;
        let (s, r) = decomp_basic(Some("x-zstd"), &compressed, limit).await;
        s?;
        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");

        Ok(())
    }

    #[cfg(feature = "xz")]
    #[async_test]
    async fn decomp_xz2() -> RequestResult<()> {
        // Not so good at tiny files...
        let compressed = hex::decode("fd377a585a000004e6d6b446020021011c00000010cf58cce00024001d5d00279b88a202ca8612cfb3c19c87c34248a570451e4851d3323d34ab8000000000000901af64854c91f600013925d6ec06651fb6f37d010000000004595a").unwrap();
        let limit = 10 << 20;
        let (s, r) = decomp_basic(Some("x-tor-lzma"), &compressed, limit).await;
        s?;
        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");

        Ok(())
    }

    #[async_test]
    async fn decomp_unknown() {
        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
        let limit = 10 << 20;
        let (s, _r) = decomp_basic(Some("x-proprietary-rle"), &compressed, limit).await;

        assert!(matches!(s, Err(RequestError::ContentEncoding(_))));
    }

    #[async_test]
    async fn decomp_bad_data() {
        let compressed = b"This is not good zlib data";
        let limit = 10 << 20;
        let (s, _r) = decomp_basic(Some("deflate"), compressed, limit).await;

        // This should possibly be a different type in the future.
        assert!(matches!(s, Err(RequestError::IoError(_))));
    }

    #[async_test]
    async fn headers_ok() -> RequestResult<()> {
        let text = b"HTTP/1.0 200 OK\r\nDate: ignored\r\nContent-Encoding: Waffles\r\n\r\n";

        let mut s = &text[..];
        let h = read_headers(&mut s).await?;

        assert_eq!(h.status, Some(200));
        assert_eq!(h.encoding.as_deref(), Some("Waffles"));

        // now try truncated
        let mut s = &text[..15];
        let h = read_headers(&mut s).await;
        assert!(matches!(h, Err(RequestError::TruncatedHeaders)));

        // now try with no encoding.
        let text = b"HTTP/1.0 404 Not found\r\n\r\n";
        let mut s = &text[..];
        let h = read_headers(&mut s).await?;

        assert_eq!(h.status, Some(404));
        assert!(h.encoding.is_none());

        Ok(())
    }

    #[async_test]
    async fn headers_bogus() -> Result<()> {
        let text = b"HTTP/999.0 WHAT EVEN\r\n\r\n";
        let mut s = &text[..];
        let h = read_headers(&mut s).await;

        assert!(h.is_err());
        assert!(matches!(h, Err(RequestError::HttparseError(_))));
        Ok(())
    }

    /// Run a trivial download example with a response provided as a binary
    /// string.
    ///
    /// Return the directory response (if any) and the request as encoded (if
    /// any).
    fn run_download_test<Req: request::Requestable>(
        req: Req,
        response: &[u8],
    ) -> (Result<DirResponse>, RequestResult<Vec<u8>>) {
        let (mut s1, s2) = stream_pair();
        let (mut s2_r, mut s2_w) = s2.split();

        tor_rtcompat::test_with_one_runtime!(|rt| async move {
            let rt2 = rt.clone();
            let (v1, v2, v3): (
                Result<DirResponse>,
                RequestResult<Vec<u8>>,
                RequestResult<()>,
            ) = futures::join!(
                async {
                    // Run the download function.
                    let r = send_request(&rt, &req, &mut s1, None).await;
                    s1.close().await.map_err(|error| {
                        Error::RequestFailed(RequestFailedError {
                            source: None,
                            error: error.into(),
                        })
                    })?;
                    r
                },
                async {
                    // Take the request from the client, and return it in "v2"
                    let mut v = Vec::new();
                    s2_r.read_to_end(&mut v).await?;
                    Ok(v)
                },
                async {
                    // Send back a response.
                    s2_w.write_all(response).await?;
                    // We wait a moment to give the other side time to notice it
                    // has data.
                    //
                    // (Tentative diagnosis: The `async-compression` crate seems to
                    // behave differently depending on whether the "close"
                    // comes right after the incomplete data or whether it comes
                    // after a delay.  If there's a delay, it notices the
                    // truncated data and tells us about it. But when there's
                    // _no_ delay, it treats the data as an error and doesn't
                    // tell our code.)

                    // TODO: sleeping in tests is not great.
                    rt2.sleep(Duration::from_millis(50)).await;
                    s2_w.close().await?;
                    Ok(())
                }
            );

            assert!(v3.is_ok());

            (v1, v2)
        })
    }

    #[test]
    fn test_send_request() -> RequestResult<()> {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();

        let (response, request) = run_download_test(
            req,
            b"HTTP/1.0 200 OK\r\n\r\nThis is where the descs would go.",
        );

        let request = request?;
        assert!(request[..].starts_with(
            b"GET /tor/micro/d/CQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQk HTTP/1.0\r\n"
        ));

        let response = response.unwrap();
        assert_eq!(response.status_code(), 200);
        assert!(!response.is_partial());
        assert!(response.error().is_none());
        assert!(response.source().is_none());
        let out_ref = response.output_unchecked();
        assert_eq!(out_ref, b"This is where the descs would go.");
        let out = response.into_output_unchecked();
        assert_eq!(&out, b"This is where the descs would go.");

        Ok(())
    }

    #[test]
    fn test_download_truncated() {
        // Request only one md, so "partial ok" will not be set.
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let mut response_text: Vec<u8> =
            (*b"HTTP/1.0 200 OK\r\nContent-Encoding: deflate\r\n\r\n").into();
        // "One fish two fish" as above twice, but truncated the second time
        response_text.extend(
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap(),
        );
        response_text.extend(
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5").unwrap(),
        );
        let (response, request) = run_download_test(req, &response_text);
        assert!(request.is_ok());
        assert!(response.is_err()); // The whole download should fail, since partial_ok wasn't set.

        // request two microdescs, so "partial_ok" will be set.
        let req: request::MicrodescRequest = vec![[9; 32]; 2].into_iter().collect();

        let (response, request) = run_download_test(req, &response_text);
        assert!(request.is_ok());

        let response = response.unwrap();
        assert_eq!(response.status_code(), 200);
        assert!(response.error().is_some());
        assert!(response.is_partial());
        assert!(response.output_unchecked().len() < 37 * 2);
        assert!(response.output_unchecked().starts_with(b"One fish"));
    }

    #[test]
    fn test_404() {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let response_text = b"HTTP/1.0 418 I'm a teapot\r\n\r\n";
        let (response, _request) = run_download_test(req, response_text);

        assert_eq!(response.unwrap().status_code(), 418);
    }

    #[test]
    fn test_headers_truncated() {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let response_text = b"HTTP/1.0 404 truncation happens here\r\n";
        let (response, _request) = run_download_test(req, response_text);

        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::TruncatedHeaders,
                ..
            }))
        ));

        // Try a completely empty response.
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let response_text = b"";
        let (response, _request) = run_download_test(req, response_text);

        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::TruncatedHeaders,
                ..
            }))
        ));
    }

    #[test]
    fn test_headers_too_long() {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let mut response_text: Vec<u8> = (*b"HTTP/1.0 418 I'm a teapot\r\nX-Too-Many-As: ").into();
        response_text.resize(16384, b'A');
        let (response, _request) = run_download_test(req, &response_text);

        assert!(response.as_ref().unwrap_err().should_retire_circ());
        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::HttparseError(_),
                ..
            }))
        ));
    }

    // TODO: test with bad utf-8
}