tor_dirclient/
lib.rs

1#![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))]
2#![doc = include_str!("../README.md")]
3// @@ begin lint list maintained by maint/add_warning @@
4#![allow(renamed_and_removed_lints)] // @@REMOVE_WHEN(ci_arti_stable)
5#![allow(unknown_lints)] // @@REMOVE_WHEN(ci_arti_nightly)
6#![warn(missing_docs)]
7#![warn(noop_method_call)]
8#![warn(unreachable_pub)]
9#![warn(clippy::all)]
10#![deny(clippy::await_holding_lock)]
11#![deny(clippy::cargo_common_metadata)]
12#![deny(clippy::cast_lossless)]
13#![deny(clippy::checked_conversions)]
14#![warn(clippy::cognitive_complexity)]
15#![deny(clippy::debug_assert_with_mut_call)]
16#![deny(clippy::exhaustive_enums)]
17#![deny(clippy::exhaustive_structs)]
18#![deny(clippy::expl_impl_clone_on_copy)]
19#![deny(clippy::fallible_impl_from)]
20#![deny(clippy::implicit_clone)]
21#![deny(clippy::large_stack_arrays)]
22#![warn(clippy::manual_ok_or)]
23#![deny(clippy::missing_docs_in_private_items)]
24#![warn(clippy::needless_borrow)]
25#![warn(clippy::needless_pass_by_value)]
26#![warn(clippy::option_option)]
27#![deny(clippy::print_stderr)]
28#![deny(clippy::print_stdout)]
29#![warn(clippy::rc_buffer)]
30#![deny(clippy::ref_option_ref)]
31#![warn(clippy::semicolon_if_nothing_returned)]
32#![warn(clippy::trait_duplication_in_bounds)]
33#![deny(clippy::unchecked_duration_subtraction)]
34#![deny(clippy::unnecessary_wraps)]
35#![warn(clippy::unseparated_literal_suffix)]
36#![deny(clippy::unwrap_used)]
37#![deny(clippy::mod_module_files)]
38#![allow(clippy::let_unit_value)] // This can reasonably be done for explicitness
39#![allow(clippy::uninlined_format_args)]
40#![allow(clippy::significant_drop_in_scrutinee)] // arti/-/merge_requests/588/#note_2812945
41#![allow(clippy::result_large_err)] // temporary workaround for arti#587
42#![allow(clippy::needless_raw_string_hashes)] // complained-about code is fine, often best
43#![allow(clippy::needless_lifetimes)] // See arti#1765
44//! <!-- @@ end lint list maintained by maint/add_warning @@ -->
45
46// TODO probably remove this at some point - see tpo/core/arti#1060
47#![cfg_attr(
48    not(all(feature = "full", feature = "experimental")),
49    allow(unused_imports)
50)]
51
52mod err;
53pub mod request;
54mod response;
55mod util;
56
57use tor_circmgr::{CircMgr, DirInfo};
58use tor_error::bad_api_usage;
59use tor_rtcompat::{Runtime, SleepProvider, SleepProviderExt};
60
61// Zlib is required; the others are optional.
62#[cfg(feature = "xz")]
63use async_compression::futures::bufread::XzDecoder;
64use async_compression::futures::bufread::ZlibDecoder;
65#[cfg(feature = "zstd")]
66use async_compression::futures::bufread::ZstdDecoder;
67
68use futures::io::{
69    AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader,
70};
71use futures::FutureExt;
72use memchr::memchr;
73use std::sync::Arc;
74use std::time::Duration;
75use tracing::info;
76
77pub use err::{Error, RequestError, RequestFailedError};
78pub use response::{DirResponse, SourceInfo};
79
/// Type for results returned in this crate.
pub type Result<T> = std::result::Result<T, Error>;

/// Type for internal results containing a [`RequestError`].
pub type RequestResult<T> = std::result::Result<T, RequestError>;
85
/// Flag to declare whether a request is anonymized or not.
///
/// Some requests (like those to download onion service descriptors) are always
/// anonymized, and should never be sent in a way that leaks information about
/// our settings or configuration.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum AnonymizedRequest {
    /// This request should not leak any information about our configuration.
    ///
    /// (Such requests cannot be used with [`get_resource`], and only the
    /// baseline content encodings are accepted for them; see `get_decoder`.)
    Anonymized,
    /// This request is allowed to include information about our capabilities.
    Direct,
}
99
100/// Fetch the resource described by `req` over the Tor network.
101///
102/// Circuits are built or found using `circ_mgr`, using paths
103/// constructed using `dirinfo`.
104///
105/// For more fine-grained control over the circuit and stream used,
106/// construct them yourself, and then call [`send_request`] instead.
107///
108/// # TODO
109///
110/// This is the only function in this crate that knows about CircMgr and
111/// DirInfo.  Perhaps this function should move up a level into DirMgr?
112pub async fn get_resource<CR, R, SP>(
113    req: &CR,
114    dirinfo: DirInfo<'_>,
115    runtime: &SP,
116    circ_mgr: Arc<CircMgr<R>>,
117) -> Result<DirResponse>
118where
119    CR: request::Requestable + ?Sized,
120    R: Runtime,
121    SP: SleepProvider,
122{
123    let circuit = circ_mgr.get_or_launch_dir(dirinfo).await?;
124
125    if req.anonymized() == AnonymizedRequest::Anonymized {
126        return Err(bad_api_usage!("Tried to use get_resource for an anonymized request").into());
127    }
128
129    // TODO(nickm) This should be an option, and is too long.
130    let begin_timeout = Duration::from_secs(5);
131    let source = match SourceInfo::from_circuit(&circuit) {
132        Ok(source) => source,
133        Err(e) => {
134            return Err(Error::RequestFailed(RequestFailedError {
135                source: None,
136                error: e.into(),
137            }));
138        }
139    };
140
141    let wrap_err = |error| {
142        Error::RequestFailed(RequestFailedError {
143            source: Some(source.clone()),
144            error,
145        })
146    };
147
148    req.check_circuit(&circuit).await.map_err(wrap_err)?;
149
150    // Launch the stream.
151    let mut stream = runtime
152        .timeout(begin_timeout, circuit.begin_dir_stream())
153        .await
154        .map_err(RequestError::from)
155        .map_err(wrap_err)?
156        .map_err(RequestError::from)
157        .map_err(wrap_err)?; // TODO(nickm) handle fatalities here too
158
159    // TODO: Perhaps we want separate timeouts for each phase of this.
160    // For now, we just use higher-level timeouts in `dirmgr`.
161    let r = send_request(runtime, req, &mut stream, Some(source.clone())).await;
162
163    if should_retire_circ(&r) {
164        retire_circ(&circ_mgr, &source, "Partial response");
165    }
166
167    r
168}
169
170/// Return true if `result` holds an error indicating that we should retire the
171/// circuit used for the corresponding request.
172fn should_retire_circ(result: &Result<DirResponse>) -> bool {
173    match result {
174        Err(e) => e.should_retire_circ(),
175        Ok(dr) => dr.error().map(RequestError::should_retire_circ) == Some(true),
176    }
177}
178
/// Fetch a Tor directory object from a provided stream.
///
/// Deprecated: this is a thin wrapper that forwards all of its arguments
/// to [`send_request`] unchanged.
#[deprecated(since = "0.8.1", note = "Use send_request instead.")]
pub async fn download<R, S, SP>(
    runtime: &SP,
    req: &R,
    stream: &mut S,
    source: Option<SourceInfo>,
) -> Result<DirResponse>
where
    R: request::Requestable + ?Sized,
    S: AsyncRead + AsyncWrite + Send + Unpin,
    SP: SleepProvider,
{
    send_request(runtime, req, stream, source).await
}
194
/// Fetch or upload a Tor directory object using the provided stream.
///
/// To do this, we send a simple HTTP/1.0 request for the described
/// object in `req` over `stream`, and then wait for a response.  In
/// log messages, we describe the origin of the data as coming from
/// `source`.
///
/// # Notes
///
/// It's kind of bogus to have a 'source' field here at all; we may
/// eventually want to remove it.
///
/// This function doesn't close the stream; you may want to do that
/// yourself.
///
/// The only error variant returned is [`Error::RequestFailed`].
// TODO: should the error return type change to `RequestFailedError`?
// If so, that would simplify some code in_dirmgr::bridgedesc.
pub async fn send_request<R, S, SP>(
    runtime: &SP,
    req: &R,
    stream: &mut S,
    source: Option<SourceInfo>,
) -> Result<DirResponse>
where
    R: request::Requestable + ?Sized,
    S: AsyncRead + AsyncWrite + Send + Unpin,
    SP: SleepProvider,
{
    // Helper: tag a RequestError with the `source` it came from.
    let wrap_err = |error| {
        Error::RequestFailed(RequestFailedError {
            source: source.clone(),
            error,
        })
    };

    // Capture the request's parameters before `req` is consumed below.
    let partial_ok = req.partial_response_body_ok();
    let maxlen = req.max_response_len();
    let anonymized = req.anonymized();
    let req = req.make_request().map_err(wrap_err)?;
    let encoded = util::encode_request(&req);

    // Write the request.
    stream
        .write_all(encoded.as_bytes())
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?;
    stream
        .flush()
        .await
        .map_err(RequestError::from)
        .map_err(wrap_err)?;

    let mut buffered = BufReader::new(stream);

    // Handle the response
    // TODO: should there be a separate timeout here?
    let header = read_headers(&mut buffered).await.map_err(wrap_err)?;
    if header.status != Some(200) {
        // A non-200 status is reported as an Ok(DirResponse) with an
        // empty body, not as an Err; the caller inspects the status code.
        return Ok(DirResponse::new(
            header.status.unwrap_or(0),
            header.status_message,
            None,
            vec![],
            source,
        ));
    }

    // Wrap the buffered stream so reads undo the server's Content-Encoding.
    let mut decoder =
        get_decoder(buffered, header.encoding.as_deref(), anonymized).map_err(wrap_err)?;

    let mut result = Vec::new();
    let ok = read_and_decompress(runtime, &mut decoder, maxlen, &mut result).await;

    // Decide what to do when the body download failed partway through:
    // keep the partial data only if the request said partial bodies are ok
    // and we actually got some bytes.
    let ok = match (partial_ok, ok, result.len()) {
        (true, Err(e), n) if n > 0 => {
            // Note that we _don't_ return here: we want the partial response.
            Err(e)
        }
        (_, Err(e), _) => {
            return Err(wrap_err(e));
        }
        (_, Ok(()), _) => Ok(()),
    };

    // The embedded `ok.err()` records the partial-download error, if any.
    Ok(DirResponse::new(200, None, ok.err(), result, source))
}
283
/// Read and parse HTTP/1 headers from `stream`.
///
/// Reads one (length-limited) line at a time into a buffer, re-running the
/// `httparse` parser after each line until it reports a complete response
/// header section, the headers grow too large, or we hit EOF.
async fn read_headers<S>(stream: &mut S) -> RequestResult<HeaderStatus>
where
    S: AsyncBufRead + Unpin,
{
    // Accumulated raw header bytes read so far.
    let mut buf = Vec::with_capacity(1024);

    loop {
        // TODO: it's inefficient to do this a line at a time; it would
        // probably be better to read until the CRLF CRLF ending of the
        // response.  But this should be fast enough.
        let n = read_until_limited(stream, b'\n', 2048, &mut buf).await?;

        // TODO(nickm): Better maximum and/or let this expand.
        let mut headers = [httparse::EMPTY_HEADER; 32];
        let mut response = httparse::Response::new(&mut headers);

        match response.parse(&buf[..])? {
            httparse::Status::Partial => {
                // We didn't get a whole response; we may need to try again.

                if n == 0 {
                    // We hit an EOF; no more progress can be made.
                    return Err(RequestError::TruncatedHeaders);
                }

                // TODO(nickm): Pick a better maximum
                if buf.len() >= 16384 {
                    return Err(httparse::Error::TooManyHeaders.into());
                }
            }
            httparse::Status::Complete(n_parsed) => {
                // For a non-200 status we skip the Content-Encoding lookup:
                // the caller returns early and never reads a body.
                if response.code != Some(200) {
                    return Ok(HeaderStatus {
                        status: response.code,
                        status_message: response.reason.map(str::to_owned),
                        encoding: None,
                    });
                }
                let encoding = if let Some(enc) = response
                    .headers
                    .iter()
                    .find(|h| h.name == "Content-Encoding")
                {
                    Some(String::from_utf8(enc.value.to_vec())?)
                } else {
                    None
                };
                /*
                if let Some(clen) = response.headers.iter().find(|h| h.name == "Content-Length") {
                    let clen = std::str::from_utf8(clen.value)?;
                    length = Some(clen.parse()?);
                }
                 */
                // We expect to have parsed exactly the bytes we buffered,
                // since we read only up to each newline and the header
                // section ends on a line boundary.
                assert!(n_parsed == buf.len());
                return Ok(HeaderStatus {
                    status: Some(200),
                    status_message: None,
                    encoding,
                });
            }
        }
        // NOTE(review): the `Partial` arm above already returns on `n == 0`,
        // so this check appears unreachable; kept as a defensive backstop.
        if n == 0 {
            return Err(RequestError::TruncatedHeaders);
        }
    }
}
351
/// Return value from read_headers
#[derive(Debug, Clone)]
struct HeaderStatus {
    /// HTTP status code, if one could be parsed.
    status: Option<u16>,
    /// HTTP status message associated with the status code.
    /// (Only populated for non-200 responses; see `read_headers`.)
    status_message: Option<String>,
    /// The Content-Encoding header, if any.
    /// (Only populated for 200 responses; see `read_headers`.)
    encoding: Option<String>,
}
362
/// Helper: download directory information from `stream` and
/// decompress it into a result buffer.  Assumes that `result` is empty.
///
/// If we get more than maxlen bytes after decompression, give an error.
///
/// Returns the status of our download attempt, stores any data that
/// we were able to download into `result`.  Existing contents of
/// `result` are overwritten.
///
/// The whole download must finish within a fixed timeout; on expiry we
/// return [`RequestError::DirTimeout`], keeping the bytes read so far in
/// `result`.
async fn read_and_decompress<S, SP>(
    runtime: &SP,
    mut stream: S,
    maxlen: usize,
    result: &mut Vec<u8>,
) -> RequestResult<()>
where
    S: AsyncRead + Unpin,
    SP: SleepProvider,
{
    // How many bytes we try to read per iteration.
    let buffer_window_size = 1024;
    let mut written_total: usize = 0;
    // TODO(nickm): This should be an option, and is maybe too long.
    // Though for some users it may be too short?
    let read_timeout = Duration::from_secs(10);
    // Note: one timer for the entire download, created outside the loop,
    // not a per-read timeout.
    let timer = runtime.sleep(read_timeout).fuse();
    futures::pin_mut!(timer);

    loop {
        // allocate buffer for next read
        result.resize(written_total + buffer_window_size, 0);
        let buf: &mut [u8] = &mut result[written_total..written_total + buffer_window_size];

        // Race the next read against the overall timer.
        let status = futures::select! {
            status = stream.read(buf).fuse() => status,
            _ = timer => {
                result.resize(written_total, 0); // truncate as needed
                return Err(RequestError::DirTimeout);
            }
        };
        let written_in_this_loop = match status {
            Ok(n) => n,
            Err(other) => {
                result.resize(written_total, 0); // truncate as needed
                return Err(other.into());
            }
        };

        written_total += written_in_this_loop;

        // exit conditions below

        if written_in_this_loop == 0 {
            // A zero-length read means EOF: the download is complete.
            /*
            in case we read less than `buffer_window_size` in last `read`
            we need to shrink result because otherwise we'll return those
            un-read 0s
            */
            if written_total < result.len() {
                result.resize(written_total, 0);
            }
            return Ok(());
        }

        // TODO: It would be good to detect compression bombs, but
        // that would require access to the internal stream, which
        // would in turn require some tricky programming.  For now, we
        // use the maximum length here to prevent an attacker from
        // filling our RAM.
        if written_total > maxlen {
            result.resize(maxlen, 0);
            return Err(RequestError::ResponseTooLong(written_total));
        }
    }
}
436
437/// Retire a directory circuit because of an error we've encountered on it.
438fn retire_circ<R>(circ_mgr: &Arc<CircMgr<R>>, source_info: &SourceInfo, error: &str)
439where
440    R: Runtime,
441{
442    let id = source_info.unique_circ_id();
443    info!(
444        "{}: Retiring circuit because of directory failure: {}",
445        &id, &error
446    );
447    circ_mgr.retire_circ(id);
448}
449
450/// As AsyncBufReadExt::read_until, but stops after reading `max` bytes.
451///
452/// Note that this function might not actually read any byte of value
453/// `byte`, since EOF might occur, or we might fill the buffer.
454///
455/// A return value of 0 indicates an end-of-file.
456async fn read_until_limited<S>(
457    stream: &mut S,
458    byte: u8,
459    max: usize,
460    buf: &mut Vec<u8>,
461) -> std::io::Result<usize>
462where
463    S: AsyncBufRead + Unpin,
464{
465    let mut n_added = 0;
466    loop {
467        let data = stream.fill_buf().await?;
468        if data.is_empty() {
469            // End-of-file has been reached.
470            return Ok(n_added);
471        }
472        debug_assert!(n_added < max);
473        let remaining_space = max - n_added;
474        let (available, found_byte) = match memchr(byte, data) {
475            Some(idx) => (idx + 1, true),
476            None => (data.len(), false),
477        };
478        debug_assert!(available >= 1);
479        let n_to_copy = std::cmp::min(remaining_space, available);
480        buf.extend(&data[..n_to_copy]);
481        stream.consume_unpin(n_to_copy);
482        n_added += n_to_copy;
483        if found_byte || n_added == max {
484            return Ok(n_added);
485        }
486    }
487}
488
/// Helper: Return a boxed decoder object that wraps the stream $s.
///
/// ($dec is the decoder type to construct around $s.)
macro_rules! decoder {
    ($dec:ident, $s:expr) => {{
        let mut decoder = $dec::new($s);
        // Keep decoding when the input holds several concatenated
        // compressed members, rather than stopping after the first.
        decoder.multiple_members(true);
        Ok(Box::new(decoder))
    }};
}
497
/// Wrap `stream` in an appropriate type to undo the content encoding
/// as described in `encoding`.
///
/// Any unrecognized encoding yields an error — including the optional
/// xz/zstd encodings when their feature is disabled, or when the request
/// was anonymized.
fn get_decoder<'a, S: AsyncBufRead + Unpin + Send + 'a>(
    stream: S,
    encoding: Option<&str>,
    anonymized: AnonymizedRequest,
) -> RequestResult<Box<dyn AsyncRead + Unpin + Send + 'a>> {
    use AnonymizedRequest::Direct;
    match (encoding, anonymized) {
        // No encoding (or the trivial one): pass the stream through as-is.
        (None | Some("identity"), _) => Ok(Box::new(stream)),
        (Some("deflate"), _) => decoder!(ZlibDecoder, stream),
        // We only admit to supporting these on a direct connection; otherwise,
        // a hostile directory could send them back even though we hadn't
        // requested them.
        #[cfg(feature = "xz")]
        (Some("x-tor-lzma"), Direct) => decoder!(XzDecoder, stream),
        #[cfg(feature = "zstd")]
        (Some("x-zstd"), Direct) => decoder!(ZstdDecoder, stream),
        (Some(other), _) => Err(RequestError::ContentEncoding(other.into())),
    }
}
519
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use tor_rtmock::io::stream_pair;

    #[allow(deprecated)] // TODO #1885
    use tor_rtmock::time::MockSleepProvider;

    use futures_await_test::async_test;

    /// Exercise `read_until_limited` in its three interesting cases:
    /// delimiter found, byte limit reached, and EOF without a delimiter.
    #[async_test]
    async fn test_read_until_limited() -> RequestResult<()> {
        let mut out = Vec::new();
        let bytes = b"This line eventually ends\nthen comes another\n";

        // Case 1: find a whole line.
        let mut s = &bytes[..];
        let res = read_until_limited(&mut s, b'\n', 100, &mut out).await;
        assert_eq!(res?, 26);
        assert_eq!(&out[..], b"This line eventually ends\n");

        // Case 2: reach the limit.
        let mut s = &bytes[..];
        out.clear();
        let res = read_until_limited(&mut s, b'\n', 10, &mut out).await;
        assert_eq!(res?, 10);
        assert_eq!(&out[..], b"This line ");

        // Case 3: reach EOF.
        let mut s = &bytes[..];
        out.clear();
        let res = read_until_limited(&mut s, b'Z', 100, &mut out).await;
        assert_eq!(res?, 45);
        assert_eq!(&out[..], &bytes[..]);

        Ok(())
    }

    // Basic decompression wrapper.
    //
    // Decodes `data` according to `encoding` (always as a Direct request)
    // and returns the decompression status together with whatever output
    // was produced, so callers can also check partial results.
    async fn decomp_basic(
        encoding: Option<&str>,
        data: &[u8],
        maxlen: usize,
    ) -> (RequestResult<()>, Vec<u8>) {
        // We don't need to do anything fancy here, since we aren't simulating
        // a timeout.
        #[allow(deprecated)] // TODO #1885
        let mock_time = MockSleepProvider::new(std::time::SystemTime::now());

        let mut output = Vec::new();
        let mut stream = match get_decoder(data, encoding, AnonymizedRequest::Direct) {
            Ok(s) => s,
            Err(e) => return (Err(e), output),
        };

        let r = read_and_decompress(&mock_time, &mut stream, maxlen, &mut output).await;

        (r, output)
    }

    #[async_test]
    async fn decompress_identity() -> RequestResult<()> {
        let mut text = Vec::new();
        for _ in 0..1000 {
            text.extend(b"This is a string with a nontrivial length that we'll use to make sure that the loop is executed more than once.");
        }

        let limit = 10 << 20;
        let (s, r) = decomp_basic(None, &text[..], limit).await;
        s?;
        assert_eq!(r, text);

        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
        s?;
        assert_eq!(r, text);

        // Try truncated result
        let limit = 100;
        let (s, r) = decomp_basic(Some("identity"), &text[..], limit).await;
        assert!(s.is_err());
        assert_eq!(r, &text[..100]);

        Ok(())
    }

    #[async_test]
    async fn decomp_zlib() -> RequestResult<()> {
        let compressed =
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap();

        let limit = 10 << 20;
        let (s, r) = decomp_basic(Some("deflate"), &compressed, limit).await;
        s?;
        assert_eq!(r, b"One fish Two fish Red fish Blue fish");

        Ok(())
    }

    #[cfg(feature = "zstd")]
    #[async_test]
    async fn decomp_zstd() -> RequestResult<()> {
        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
        let limit = 10 << 20;
        let (s, r) = decomp_basic(Some("x-zstd"), &compressed, limit).await;
        s?;
        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");

        Ok(())
    }

    #[cfg(feature = "xz")]
    #[async_test]
    async fn decomp_xz2() -> RequestResult<()> {
        // Not so good at tiny files...
        let compressed = hex::decode("fd377a585a000004e6d6b446020021011c00000010cf58cce00024001d5d00279b88a202ca8612cfb3c19c87c34248a570451e4851d3323d34ab8000000000000901af64854c91f600013925d6ec06651fb6f37d010000000004595a").unwrap();
        let limit = 10 << 20;
        let (s, r) = decomp_basic(Some("x-tor-lzma"), &compressed, limit).await;
        s?;
        assert_eq!(r, b"One fish Two fish Red fish Blue fish\n");

        Ok(())
    }

    /// An unrecognized content-encoding must be rejected, not passed through.
    #[async_test]
    async fn decomp_unknown() {
        let compressed = hex::decode("28b52ffd24250d0100c84f6e6520666973682054776f526564426c756520666973680a0200600c0e2509478352cb").unwrap();
        let limit = 10 << 20;
        let (s, _r) = decomp_basic(Some("x-proprietary-rle"), &compressed, limit).await;

        assert!(matches!(s, Err(RequestError::ContentEncoding(_))));
    }

    #[async_test]
    async fn decomp_bad_data() {
        let compressed = b"This is not good zlib data";
        let limit = 10 << 20;
        let (s, _r) = decomp_basic(Some("deflate"), compressed, limit).await;

        // This should possibly be a different type in the future.
        assert!(matches!(s, Err(RequestError::IoError(_))));
    }

    #[async_test]
    async fn headers_ok() -> RequestResult<()> {
        let text = b"HTTP/1.0 200 OK\r\nDate: ignored\r\nContent-Encoding: Waffles\r\n\r\n";

        let mut s = &text[..];
        let h = read_headers(&mut s).await?;

        assert_eq!(h.status, Some(200));
        assert_eq!(h.encoding.as_deref(), Some("Waffles"));

        // now try truncated
        let mut s = &text[..15];
        let h = read_headers(&mut s).await;
        assert!(matches!(h, Err(RequestError::TruncatedHeaders)));

        // now try with no encoding.
        let text = b"HTTP/1.0 404 Not found\r\n\r\n";
        let mut s = &text[..];
        let h = read_headers(&mut s).await?;

        assert_eq!(h.status, Some(404));
        assert!(h.encoding.is_none());

        Ok(())
    }

    #[async_test]
    async fn headers_bogus() -> Result<()> {
        let text = b"HTTP/999.0 WHAT EVEN\r\n\r\n";
        let mut s = &text[..];
        let h = read_headers(&mut s).await;

        assert!(h.is_err());
        assert!(matches!(h, Err(RequestError::HttparseError(_))));
        Ok(())
    }

    /// Run a trivial download example with a response provided as a binary
    /// string.
    ///
    /// Return the directory response (if any) and the request as encoded (if
    /// any.)
    fn run_download_test<Req: request::Requestable>(
        req: Req,
        response: &[u8],
    ) -> (Result<DirResponse>, RequestResult<Vec<u8>>) {
        let (mut s1, s2) = stream_pair();
        let (mut s2_r, mut s2_w) = s2.split();

        tor_rtcompat::test_with_one_runtime!(|rt| async move {
            let rt2 = rt.clone();
            // Three concurrent tasks: the client download (v1), a reader that
            // captures the client's encoded request (v2), and a writer that
            // plays back the canned server response (v3).
            let (v1, v2, v3): (
                Result<DirResponse>,
                RequestResult<Vec<u8>>,
                RequestResult<()>,
            ) = futures::join!(
                async {
                    // Run the download function.
                    let r = send_request(&rt, &req, &mut s1, None).await;
                    s1.close().await.map_err(|error| {
                        Error::RequestFailed(RequestFailedError {
                            source: None,
                            error: error.into(),
                        })
                    })?;
                    r
                },
                async {
                    // Take the request from the client, and return it in "v2"
                    let mut v = Vec::new();
                    s2_r.read_to_end(&mut v).await?;
                    Ok(v)
                },
                async {
                    // Send back a response.
                    s2_w.write_all(response).await?;
                    // We wait a moment to give the other side time to notice it
                    // has data.
                    //
                    // (Tentative diagnosis: The `async-compress` crate seems to
                    // be behave differently depending on whether the "close"
                    // comes right after the incomplete data or whether it comes
                    // after a delay.  If there's a delay, it notices the
                    // truncated data and tells us about it. But when there's
                    // _no_delay, it treats the data as an error and doesn't
                    // tell our code.)

                    // TODO: sleeping in tests is not great.
                    rt2.sleep(Duration::from_millis(50)).await;
                    s2_w.close().await?;
                    Ok(())
                }
            );

            assert!(v3.is_ok());

            (v1, v2)
        })
    }

    #[test]
    fn test_send_request() -> RequestResult<()> {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();

        let (response, request) = run_download_test(
            req,
            b"HTTP/1.0 200 OK\r\n\r\nThis is where the descs would go.",
        );

        let request = request?;
        assert!(request[..].starts_with(
            b"GET /tor/micro/d/CQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQk.z HTTP/1.0\r\n"
        ));

        let response = response.unwrap();
        assert_eq!(response.status_code(), 200);
        assert!(!response.is_partial());
        assert!(response.error().is_none());
        assert!(response.source().is_none());
        let out_ref = response.output_unchecked();
        assert_eq!(out_ref, b"This is where the descs would go.");
        let out = response.into_output_unchecked();
        assert_eq!(&out, b"This is where the descs would go.");

        Ok(())
    }

    #[test]
    fn test_download_truncated() {
        // Request only one md, so "partial ok" will not be set.
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let mut response_text: Vec<u8> =
            (*b"HTTP/1.0 200 OK\r\nContent-Encoding: deflate\r\n\r\n").into();
        // "One fish two fish" as above twice, but truncated the second time
        response_text.extend(
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5970c88").unwrap(),
        );
        response_text.extend(
            hex::decode("789cf3cf4b5548cb2cce500829cf8730825253200ca79c52881c00e5").unwrap(),
        );
        let (response, request) = run_download_test(req, &response_text);
        assert!(request.is_ok());
        assert!(response.is_err()); // The whole download should fail, since partial_ok wasn't set.

        // request two microdescs, so "partial_ok" will be set.
        let req: request::MicrodescRequest = vec![[9; 32]; 2].into_iter().collect();

        let (response, request) = run_download_test(req, &response_text);
        assert!(request.is_ok());

        // With partial_ok set, a truncated body yields a 200 response that
        // carries both the partial data and the embedded error.
        let response = response.unwrap();
        assert_eq!(response.status_code(), 200);
        assert!(response.error().is_some());
        assert!(response.is_partial());
        assert!(response.output_unchecked().len() < 37 * 2);
        assert!(response.output_unchecked().starts_with(b"One fish"));
    }

    /// A non-200 status comes back as an Ok(DirResponse), not as an Err.
    #[test]
    fn test_404() {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let response_text = b"HTTP/1.0 418 I'm a teapot\r\n\r\n";
        let (response, _request) = run_download_test(req, response_text);

        assert_eq!(response.unwrap().status_code(), 418);
    }

    #[test]
    fn test_headers_truncated() {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let response_text = b"HTTP/1.0 404 truncation happens here\r\n";
        let (response, _request) = run_download_test(req, response_text);

        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::TruncatedHeaders,
                ..
            }))
        ));

        // Try a completely empty response.
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let response_text = b"";
        let (response, _request) = run_download_test(req, response_text);

        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::TruncatedHeaders,
                ..
            }))
        ));
    }

    #[test]
    fn test_headers_too_long() {
        let req: request::MicrodescRequest = vec![[9; 32]].into_iter().collect();
        let mut response_text: Vec<u8> = (*b"HTTP/1.0 418 I'm a teapot\r\nX-Too-Many-As: ").into();
        response_text.resize(16384, b'A');
        let (response, _request) = run_download_test(req, &response_text);

        // Oversized headers should both fail the request and mark the
        // circuit for retirement.
        assert!(response.as_ref().unwrap_err().should_retire_circ());
        assert!(matches!(
            response,
            Err(Error::RequestFailed(RequestFailedError {
                error: RequestError::HttparseError(_),
                ..
            }))
        ));
    }

    // TODO: test with bad utf-8
}