arti_bench/
main.rs

1//! A simple benchmarking utility for Arti.
2//!
3//! This works by establishing a simple TCP server, and having Arti connect back to it via
4//! a `chutney` network of Tor nodes, benchmarking the upload and download bandwidth while doing so.
5
6// @@ begin lint list maintained by maint/add_warning @@
7#![allow(renamed_and_removed_lints)] // @@REMOVE_WHEN(ci_arti_stable)
8#![allow(unknown_lints)] // @@REMOVE_WHEN(ci_arti_nightly)
9#![warn(missing_docs)]
10#![warn(noop_method_call)]
11#![warn(unreachable_pub)]
12#![warn(clippy::all)]
13#![deny(clippy::await_holding_lock)]
14#![deny(clippy::cargo_common_metadata)]
15#![deny(clippy::cast_lossless)]
16#![deny(clippy::checked_conversions)]
17#![warn(clippy::cognitive_complexity)]
18#![deny(clippy::debug_assert_with_mut_call)]
19#![deny(clippy::exhaustive_enums)]
20#![deny(clippy::exhaustive_structs)]
21#![deny(clippy::expl_impl_clone_on_copy)]
22#![deny(clippy::fallible_impl_from)]
23#![deny(clippy::implicit_clone)]
24#![deny(clippy::large_stack_arrays)]
25#![warn(clippy::manual_ok_or)]
26#![deny(clippy::missing_docs_in_private_items)]
27#![warn(clippy::needless_borrow)]
28#![warn(clippy::needless_pass_by_value)]
29#![warn(clippy::option_option)]
30#![deny(clippy::print_stderr)]
31#![deny(clippy::print_stdout)]
32#![warn(clippy::rc_buffer)]
33#![deny(clippy::ref_option_ref)]
34#![warn(clippy::semicolon_if_nothing_returned)]
35#![warn(clippy::trait_duplication_in_bounds)]
36#![deny(clippy::unchecked_duration_subtraction)]
37#![deny(clippy::unnecessary_wraps)]
38#![warn(clippy::unseparated_literal_suffix)]
39#![deny(clippy::unwrap_used)]
40#![deny(clippy::mod_module_files)]
41#![allow(clippy::let_unit_value)] // This can reasonably be done for explicitness
42#![allow(clippy::uninlined_format_args)]
43#![allow(clippy::significant_drop_in_scrutinee)] // arti/-/merge_requests/588/#note_2812945
44#![allow(clippy::result_large_err)] // temporary workaround for arti#587
45#![allow(clippy::needless_raw_string_hashes)] // complained-about code is fine, often best
46#![allow(clippy::needless_lifetimes)] // See arti#1765
47#![allow(mismatched_lifetime_syntaxes)] // temporary workaround for arti#2060
48//! <!-- @@ end lint list maintained by maint/add_warning @@ -->
49// This file uses `unwrap()` a fair deal, but this is fine in test/bench code
50// because it's OK if tests and benchmarks simply crash if things go wrong.
51#![allow(clippy::unwrap_used)]
52
53use anyhow::{anyhow, Result};
54use arti::cfg::ArtiCombinedConfig;
55use arti_client::{IsolationToken, TorAddr, TorClient, TorClientConfig};
56use clap::{value_parser, Arg, ArgAction};
57use futures::StreamExt;
58use rand::distr::StandardUniform;
59use rand::Rng;
60use serde::{Deserialize, Serialize};
61use std::collections::HashMap;
62use std::ffi::OsString;
63use std::fmt;
64use std::fmt::Formatter;
65use std::future::Future;
66use std::io::{Read, Write};
67use std::net::{IpAddr, SocketAddr, TcpListener, TcpStream};
68use std::ops::Deref;
69use std::str::FromStr;
70use std::sync::Arc;
71use std::thread::JoinHandle;
72use std::time::SystemTime;
73use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
74use tokio_socks::tcp::Socks5Stream;
75use tor_config::{ConfigurationSource, ConfigurationSources};
76use tor_rtcompat::ToplevelRuntime;
77use tracing::info;
78
79/// Generate a random payload of bytes of the given size
80fn random_payload(size: usize) -> Vec<u8> {
81    rand::rng()
82        .sample_iter(StandardUniform)
83        .take(size)
84        .collect()
85}
86
87/// Timing information from the benchmarking server.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct ServerTiming {
90    /// When the connection was accepted.
91    accepted_ts: SystemTime,
92    /// When the payload was successfully written to the client.
93    copied_ts: SystemTime,
94    /// When the server received the first byte from the client.
95    first_byte_ts: SystemTime,
96    /// When the server finished reading the client's payload.
97    read_done_ts: SystemTime,
98}
99
100/// Timing information from the benchmarking client.
101#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct ClientTiming {
103    /// When the client's connection succeeded.
104    started_ts: SystemTime,
105    /// When the client received the first byte from the server.
106    first_byte_ts: SystemTime,
107    /// When the client finished reading the server's payload.
108    read_done_ts: SystemTime,
109    /// When the payload was successfully written to the server.
110    copied_ts: SystemTime,
111    /// The server's copy of the timing information.
112    server: ServerTiming,
113    /// The size of the payload downloaded from the server.
114    download_size: usize,
115    /// The size of the payload uploaded to the server.
116    upload_size: usize,
117}
118
119/// A summary of benchmarking results, generated from `ClientTiming`.
120#[derive(Debug, Copy, Clone, Serialize)]
121pub struct TimingSummary {
122    /// The time to first byte (TTFB) for the download benchmark.
123    download_ttfb_sec: f64,
124    /// The average download speed, in megabits per second.
125    download_rate_megabit: f64,
126    /// The time to first byte (TTFB) for the upload benchmark.
127    upload_ttfb_sec: f64,
128    /// The average upload speed, in megabits per second.
129    upload_rate_megabit: f64,
130}
131
132impl fmt::Display for TimingSummary {
133    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
134        write!(
135            f,
136            "{:.2} Mbit/s up (ttfb {:.2}ms), {:.2} Mbit/s down (ttfb {:.2}ms)",
137            self.upload_rate_megabit,
138            self.upload_ttfb_sec * 1000.0,
139            self.download_rate_megabit,
140            self.download_ttfb_sec * 1000.0
141        )
142    }
143}
144
145impl TimingSummary {
146    /// Generate a `TimingSummary` from the `ClientTiming` returned by a benchmark run.
147    pub fn generate(ct: &ClientTiming) -> Result<Self> {
148        let download_ttfb = ct.first_byte_ts.duration_since(ct.server.accepted_ts)?;
149        let download_time = ct.read_done_ts.duration_since(ct.first_byte_ts)?;
150        let download_rate_bps = ct.download_size as f64 / download_time.as_secs_f64();
151
152        let upload_ttfb = ct.server.first_byte_ts.duration_since(ct.read_done_ts)?;
153        let upload_time = ct
154            .server
155            .read_done_ts
156            .duration_since(ct.server.first_byte_ts)?;
157        let upload_rate_bps = ct.upload_size as f64 / upload_time.as_secs_f64();
158
159        Ok(Self {
160            download_ttfb_sec: download_ttfb.as_secs_f64(),
161            download_rate_megabit: download_rate_bps / 125_000.0,
162            upload_ttfb_sec: upload_ttfb.as_secs_f64(),
163            upload_rate_megabit: upload_rate_bps / 125_000.0,
164        })
165    }
166}
167
/// Size, in bytes, of the scratch buffer the server passes to each `read` call.
const RECV_BUF_LEN: usize = 8 * 1024;
170
171/// Run the timing routine
172#[allow(clippy::cognitive_complexity)]
173fn run_timing(mut stream: TcpStream, send: &Arc<[u8]>, receive: &Arc<[u8]>) -> Result<()> {
174    let peer_addr = stream.peer_addr()?;
175    let mut received = vec![0_u8; RECV_BUF_LEN];
176    let expected_len = receive.len();
177    let mut expected = receive.deref();
178    let mut mismatch = false;
179    let mut total_read = 0;
180
181    info!("Accepted connection from {}", peer_addr);
182    let accepted_ts = SystemTime::now();
183    let mut data: &[u8] = send.deref();
184    let copied = std::io::copy(&mut data, &mut stream)?;
185    stream.flush()?;
186    let copied_ts = SystemTime::now();
187    assert_eq!(copied, send.len() as u64);
188    info!("Copied {} bytes payload to {}.", copied, peer_addr);
189    let read = stream.read(&mut received)?;
190    if read == 0 {
191        panic!("unexpected EOF");
192    }
193    let first_byte_ts = SystemTime::now();
194    if received[0..read] != expected[0..read] {
195        mismatch = true;
196    }
197    expected = &expected[read..];
198    total_read += read;
199    while total_read < expected_len {
200        let read = stream.read(&mut received)?;
201        if read == 0 {
202            panic!("unexpected eof");
203        }
204        if received[0..read] != expected[0..read] {
205            mismatch = true;
206        }
207        expected = &expected[read..];
208        total_read += read;
209    }
210    let read_done_ts = SystemTime::now();
211    info!("Received {} bytes payload from {}.", total_read, peer_addr);
212    // Check we actually got what we thought we would get.
213    if mismatch {
214        panic!("Received data doesn't match expected; potential corruption?");
215    }
216    let st = ServerTiming {
217        accepted_ts,
218        copied_ts,
219        first_byte_ts,
220        read_done_ts,
221    };
222    serde_json::to_writer(&mut stream, &st)?;
223    info!("Wrote timing payload to {}.", peer_addr);
224    Ok(())
225}
226
227/// Runs the benchmarking TCP server, using the provided TCP listener and set of payloads.
228fn serve_payload(
229    listener: &TcpListener,
230    send: &Arc<[u8]>,
231    receive: &Arc<[u8]>,
232) -> Vec<JoinHandle<Result<()>>> {
233    info!("Listening for clients...");
234
235    listener
236        .incoming()
237        .map(|stream| {
238            let send = Arc::clone(send);
239            let receive = Arc::clone(receive);
240            std::thread::spawn(move || run_timing(stream?, &send, &receive))
241        })
242        .collect()
243}
244
245/// Runs the benchmarking client on the provided socket.
246async fn client<S: AsyncRead + AsyncWrite + Unpin>(
247    mut socket: S,
248    send: Arc<[u8]>,
249    receive: Arc<[u8]>,
250) -> Result<ClientTiming> {
251    // Do this potentially costly allocation before we do all the timing stuff.
252    let mut received = vec![0_u8; receive.len()];
253    let started_ts = SystemTime::now();
254
255    let read = socket.read(&mut received).await?;
256    if read == 0 {
257        return Err(anyhow!("unexpected EOF"));
258    }
259    let first_byte_ts = SystemTime::now();
260    socket.read_exact(&mut received[read..]).await?;
261    let read_done_ts = SystemTime::now();
262    info!("Received {} bytes payload.", received.len());
263    let mut send_data = &send as &[u8];
264
265    tokio::io::copy(&mut send_data, &mut socket).await?;
266    socket.flush().await?;
267    info!("Sent {} bytes payload.", send.len());
268    let copied_ts = SystemTime::now();
269
270    // Check we actually got what we thought we would get.
271    if received != receive.deref() {
272        panic!("Received data doesn't match expected; potential corruption?");
273    }
274    let mut json_buf = Vec::new();
275    socket.read_to_end(&mut json_buf).await?;
276    let server: ServerTiming = serde_json::from_slice(&json_buf)?;
277    Ok(ClientTiming {
278        started_ts,
279        first_byte_ts,
280        read_done_ts,
281        copied_ts,
282        server,
283        download_size: receive.len(),
284        upload_size: send.len(),
285    })
286}
287
288#[allow(clippy::cognitive_complexity)]
289fn main() -> Result<()> {
290    tracing_subscriber::fmt::init();
291
292    let matches = clap::Command::new("arti-bench")
293        .version(env!("CARGO_PKG_VERSION"))
294        .author("The Tor Project Developers")
295        .about("A simple benchmarking utility for Arti.")
296        .arg(
297            Arg::new("arti-config")
298                .short('c')
299                .long("arti-config")
300                .action(ArgAction::Set)
301                .required(true)
302                .value_name("CONFIG")
303                .value_parser(value_parser!(OsString))
304                .help(
305                    "Path to the Arti configuration to use (usually, a Chutney-generated config).",
306                ),
307        )
308        .arg(
309            Arg::new("num-samples")
310                .short('s')
311                .long("num-samples")
312                .action(ArgAction::Set)
313                .value_name("COUNT")
314                .value_parser(value_parser!(usize))
315                .default_value("3")
316                .help("How many samples to take per benchmark run.")
317        )
318        .arg(
319            Arg::new("num-streams")
320                .short('p')
321                .long("streams")
322                .aliases(["num-parallel"])
323                .action(ArgAction::Set)
324                .value_name("COUNT")
325                .value_parser(value_parser!(usize))
326                .default_value("3")
327                .help("How many simultaneous streams per circuit.")
328        )
329        .arg(
330            Arg::new("num-circuits")
331                .short('C')
332                .long("num-circuits")
333                .action(ArgAction::Set)
334                .value_name("COUNT")
335                .value_parser(value_parser!(usize))
336                .default_value("1")
337                .help("How many simultaneous circuits per run.")
338        )
339        .arg(
340            Arg::new("output")
341                .short('o')
342                .action(ArgAction::Set)
343                .value_name("/path/to/output.json")
344                .help("A path to write benchmark results to, in JSON format.")
345        )
346        .arg(
347            Arg::new("download-bytes")
348                .short('d')
349                .long("download-bytes")
350                .action(ArgAction::Set)
351                .value_name("SIZE")
352                .value_parser(value_parser!(usize))
353                .default_value("10485760")
354                .help("How much fake payload data to generate for the download benchmark."),
355        )
356        .arg(
357            Arg::new("upload-bytes")
358                .short('u')
359                .long("upload-bytes")
360                .action(ArgAction::Set)
361                .value_name("SIZE")
362                .value_parser(value_parser!(usize))
363                .default_value("10485760")
364                .help("How much fake payload data to generate for the upload benchmark."),
365        )
366        .arg(
367            Arg::new("socks-proxy")
368                .long("socks5")
369                .action(ArgAction::Set)
370                .value_name("addr:port")
371                .help("SOCKS5 proxy address for a node to benchmark through as well (usually a Chutney node). Optional."),
372        )
373        .get_matches();
374    info!("Parsing Arti configuration...");
375    let mut config_sources = ConfigurationSources::new_empty();
376    matches
377        .get_many::<OsString>("arti-config")
378        .unwrap_or_default()
379        .for_each(|f| {
380            config_sources.push_source(
381                ConfigurationSource::from_path(f),
382                tor_config::sources::MustRead::MustRead,
383            );
384        });
385
386    // TODO really we ought to get this from the arti configuration, or something.
387    // But this is OK for now since we are a benchmarking tool.
388    let mistrust = fs_mistrust::Mistrust::new_dangerously_trust_everyone();
389    config_sources.set_mistrust(mistrust);
390
391    let cfg = config_sources.load()?;
392    let (_config, tcc) = tor_config::resolve::<ArtiCombinedConfig>(cfg)?;
393    info!("Binding local TCP listener...");
394    let listener = TcpListener::bind("0.0.0.0:0")?;
395    let local_addr = listener.local_addr()?;
396    let connect_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), local_addr.port());
397    info!("Bound to {}.", local_addr);
398    let upload_bytes = *matches.get_one::<usize>("upload-bytes").unwrap();
399    let download_bytes = *matches.get_one::<usize>("download-bytes").unwrap();
400    let samples = *matches.get_one::<usize>("num-samples").unwrap();
401    let streams_per_circ = *matches.get_one::<usize>("num-streams").unwrap();
402    let circs_per_sample = *matches.get_one::<usize>("num-circuits").unwrap();
403    info!("Generating test payloads, please wait...");
404    let upload_payload = random_payload(upload_bytes).into();
405    let download_payload = random_payload(download_bytes).into();
406    info!(
407        "Generated payloads ({} upload, {} download)",
408        upload_bytes, download_bytes
409    );
410    let up = Arc::clone(&upload_payload);
411    let dp = Arc::clone(&download_payload);
412    let _handle = std::thread::spawn(move || -> Result<()> {
413        serve_payload(&listener, &dp, &up)
414            .into_iter()
415            .try_for_each(|handle| handle.join().expect("failed to join thread"))
416    });
417
418    let mut benchmark = Benchmark {
419        connect_addr,
420        samples,
421        streams_per_circ,
422        circs_per_sample,
423        upload_payload,
424        download_payload,
425        runtime: tor_rtcompat::tokio::TokioNativeTlsRuntime::create()?,
426        results: Default::default(),
427    };
428
429    benchmark.without_arti()?;
430    if let Some(addr) = matches.get_one::<String>("socks-proxy") {
431        benchmark.with_proxy(addr)?;
432    }
433    benchmark.with_arti(tcc)?;
434
435    info!("Benchmarking complete.");
436
437    for (ty, results) in benchmark.results.iter() {
438        info!(
439            "Information for benchmark type {:?} ({} samples taken):",
440            ty, benchmark.samples
441        );
442        info!("  upload rate: {} Mbit/s", results.upload_rate_megabit);
443        info!("download rate: {} Mbit/s", results.upload_rate_megabit);
444        info!("    TTFB (up): {} msec", results.upload_ttfb_msec);
445        info!("  TTFB (down): {} msec", results.download_ttfb_msec);
446    }
447
448    if let Some(output) = matches.get_one::<String>("output") {
449        info!("Writing benchmark results to {}...", output);
450        let file = std::fs::File::create(output)?;
451        serde_json::to_writer(
452            &file,
453            &BenchmarkSummary {
454                crate_version: env!("CARGO_PKG_VERSION").to_string(),
455                results: benchmark.results,
456            },
457        )?;
458    }
459
460    Ok(())
461}
462
463/// A helper struct for running benchmarks
464#[allow(clippy::missing_docs_in_private_items)]
465struct Benchmark<R>
466where
467    R: ToplevelRuntime,
468{
469    runtime: R,
470    connect_addr: SocketAddr,
471    samples: usize,
472    streams_per_circ: usize,
473    circs_per_sample: usize,
474    upload_payload: Arc<[u8]>,
475    download_payload: Arc<[u8]>,
476    /// All benchmark results conducted, indexed by benchmark type.
477    results: HashMap<BenchmarkType, BenchmarkResults>,
478}
479
480/// The type of benchmark conducted.
481#[derive(Clone, Copy, Serialize, Deserialize, Hash, Debug, PartialEq, Eq)]
482enum BenchmarkType {
483    /// Use the benchmark server on its own, without using any proxy.
484    ///
485    /// This is useful to get an idea of how well the benchmarking utility performs on its own.
486    RawLoopback,
487    /// Benchmark via a SOCKS5 proxy (usually that of a chutney node).
488    Socks,
489    /// Benchmark via Arti.
490    Arti,
491}
492
493#[derive(Clone, Serialize, Debug)]
494/// Some information about a set of benchmark samples collected during multiple runs.
495struct Statistic {
496    /// The mean value of all samples.
497    mean: f64,
498    /// The low-median value of all samples.
499    /// # Important note
500    ///
501    /// This is only the median if an odd number of samples were collected; otherwise,
502    /// it is the `(number of samples / 2)`th sample after the samples are sorted.
503    median: f64,
504    /// The minimum sample observed.
505    min: f64,
506    /// The maximum sample observed.
507    max: f64,
508    /// The standard deviation of the set of samples.
509    stddev: f64,
510}
511
512impl fmt::Display for Statistic {
513    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
514        let Statistic {
515            mean,
516            median,
517            min,
518            max,
519            stddev,
520        } = self;
521        write!(
522            f,
523            "min/mean/median/max/stddev = {:>7.2}/{:>7.2}/{:>7.2}/{:>7.2}/{:>7.2}",
524            min, mean, median, max, stddev
525        )
526    }
527}
528
529impl Statistic {
530    /// Generate a summary of the provided `samples`.
531    ///
532    /// # Panics
533    ///
534    /// Panics if `samples` is empty.
535    fn from_samples(mut samples: Vec<f64>) -> Self {
536        let n_samples = samples.len();
537        float_ord::sort(&mut samples);
538        let mean = samples.iter().sum::<f64>() / n_samples as f64;
539        // \Sigma (x_i - \mu)^2
540        let samples_minus_mean_sum = samples.iter().map(|xi| (xi - mean).powf(2.0)).sum::<f64>();
541        let stddev = (samples_minus_mean_sum / n_samples as f64).sqrt();
542        Statistic {
543            mean,
544            median: samples[n_samples / 2],
545            min: samples[0],
546            max: samples[n_samples - 1],
547            stddev,
548        }
549    }
550}
551
552/// A set of benchmark results for a given `BenchmarkType`, including information about averages.
553#[derive(Clone, Serialize, Debug)]
554struct BenchmarkResults {
555    /// The type of benchmark conducted.
556    ty: BenchmarkType,
557    /// The number of times the benchmark was run.
558    samples: usize,
559    /// The number of concurrent streams per circuit used during the run.
560    streams_per_circ: usize,
561    /// The number of circuits used during the run.
562    circuits: usize,
563    /// The time to first byte (TTFB) for the download benchmark, in milliseconds.
564    download_ttfb_msec: Statistic,
565    /// The average download speed, in megabits per second.
566    download_rate_megabit: Statistic,
567    /// The time to first byte (TTFB) for the upload benchmark, in milliseconds.
568    upload_ttfb_msec: Statistic,
569    /// The average upload speed, in megabits per second.
570    upload_rate_megabit: Statistic,
571
572    /// The raw benchmark results.
573    results_raw: Vec<TimingSummary>,
574}
575
576impl BenchmarkResults {
577    /// Generate summarized benchmark results from raw run data.
578    fn generate(
579        ty: BenchmarkType,
580        streams_per_circ: usize,
581        circuits: usize,
582        raw: Vec<TimingSummary>,
583    ) -> Self {
584        let download_ttfb_msecs = raw
585            .iter()
586            .map(|s| s.download_ttfb_sec * 1000.0)
587            .collect::<Vec<_>>();
588        let download_rate_megabits = raw
589            .iter()
590            .map(|s| s.download_rate_megabit)
591            .collect::<Vec<_>>();
592        let upload_ttfb_msecs = raw
593            .iter()
594            .map(|s| s.upload_ttfb_sec * 1000.0)
595            .collect::<Vec<_>>();
596        let upload_rate_megabits = raw
597            .iter()
598            .map(|s| s.upload_rate_megabit)
599            .collect::<Vec<_>>();
600        let samples = raw.len();
601        BenchmarkResults {
602            ty,
603            samples,
604            streams_per_circ,
605            circuits,
606            download_ttfb_msec: Statistic::from_samples(download_ttfb_msecs),
607            download_rate_megabit: Statistic::from_samples(download_rate_megabits),
608            upload_ttfb_msec: Statistic::from_samples(upload_ttfb_msecs),
609            upload_rate_megabit: Statistic::from_samples(upload_rate_megabits),
610            results_raw: raw,
611        }
612    }
613}
614
615/// A summary of all benchmarks conducted throughout the invocation of `arti-bench`.
616///
617/// Designed to be stored as an artifact and compared against other later runs.
618#[derive(Clone, Serialize, Debug)]
619struct BenchmarkSummary {
620    /// The version of `arti-bench` used to generate the benchmark results.
621    crate_version: String,
622    /// All benchmark results conducted, indexed by benchmark type.
623    results: HashMap<BenchmarkType, BenchmarkResults>,
624}
625
626impl<R: ToplevelRuntime> Benchmark<R> {
627    /// Run a type of benchmark (`ty`), performing `self.samples` benchmark
628    /// runs, using `self.circs_per_sample` concurrent circuits, and
629    /// `self.streams_per_circ` concurrent streams on each circuit.
630    ///
631    /// Uses `stream_generator`, function that returns futures that themselves
632    /// generate streams, in order to obtain the required number of streams to
633    /// run the test over.  The function takes an index of the current run.
634    fn run<F, G, S, E>(&mut self, ty: BenchmarkType, mut stream_generator: F) -> Result<()>
635    where
636        F: FnMut(usize) -> G,
637        G: Future<Output = Result<S, E>>,
638        S: AsyncRead + AsyncWrite + Unpin,
639        E: std::error::Error + Send + Sync + 'static,
640    {
641        let mut results = vec![];
642        for n in 0..self.samples {
643            let total_streams = self.streams_per_circ * self.circs_per_sample;
644            let futures = (0..total_streams)
645                .map(|_| {
646                    let up = Arc::clone(&self.upload_payload);
647                    let dp = Arc::clone(&self.download_payload);
648                    let stream = stream_generator(n);
649                    Box::pin(async move { client(stream.await?, up, dp).await })
650                })
651                .collect::<futures::stream::FuturesUnordered<_>>()
652                .collect::<Vec<_>>();
653            info!(
654                "Benchmarking {:?} with {} connections, run {}/{}...",
655                ty,
656                self.streams_per_circ,
657                n + 1,
658                self.samples
659            );
660            let stats = self
661                .runtime
662                .block_on(futures)
663                .into_iter()
664                .map(|x| x.and_then(|x| TimingSummary::generate(&x)))
665                .collect::<Result<Vec<_>>>()?;
666            results.extend(stats);
667        }
668        let results =
669            BenchmarkResults::generate(ty, self.streams_per_circ, self.circs_per_sample, results);
670        self.results.insert(ty, results);
671        Ok(())
672    }
673
674    /// Benchmark without Arti on loopback.
675    fn without_arti(&mut self) -> Result<()> {
676        let ca = self.connect_addr;
677        self.run(BenchmarkType::RawLoopback, |_| {
678            tokio::net::TcpStream::connect(ca)
679        })
680    }
681
682    /// Benchmark through a SOCKS5 proxy at address `addr`.
683    fn with_proxy(&mut self, addr: &str) -> Result<()> {
684        let ca = self.connect_addr;
685        let mut iso = StreamIsolationTracker::new(self.streams_per_circ);
686
687        self.run(BenchmarkType::Socks, |run| {
688            // Tor uses the username,password tuple of socks authentication do decide how to isolate streams.
689            let iso_string = format!("{:?}", iso.next_in(run));
690            async move {
691                Socks5Stream::connect_with_password(addr, ca, &iso_string, &iso_string).await
692            }
693        })
694    }
695
696    /// Benchmark through Arti, using the provided `TorClientConfig`.
697    fn with_arti(&mut self, tcc: TorClientConfig) -> Result<()> {
698        info!("Starting Arti...");
699        let tor_client = self.runtime.block_on(
700            TorClient::with_runtime(self.runtime.clone())
701                .config(tcc)
702                .create_bootstrapped(),
703        )?;
704
705        let addr = TorAddr::dangerously_from(self.connect_addr)?;
706
707        let mut iso = StreamIsolationTracker::new(self.streams_per_circ);
708
709        self.run(BenchmarkType::Arti, |run| {
710            let mut prefs = arti_client::StreamPrefs::new();
711            prefs.set_isolation(iso.next_in(run));
712
713            tor_client.connect(addr.clone())
714        })
715    }
716}
717
718/// Helper type: track a StreamIsolation token over a set of runs.
719///
720/// We want to return a new token every `streams_per_circ` calls for each run,
721/// but always give a new token when a new run begins.
722#[derive(Debug, Clone)]
723struct StreamIsolationTracker {
724    /// The number of streams to assign to each circuit.
725    streams_per_circ: usize,
726    /// The current run index.
727    cur_run: usize,
728    /// The stream index within the run that we expect on the _next_ call to `next_in`.
729    next_stream: usize,
730    /// The isolation token we're currently handing out.
731    cur_token: IsolationToken,
732}
733
734impl StreamIsolationTracker {
735    /// Construct a new StreamIsolationTracker.
736    fn new(streams_per_circ: usize) -> Self {
737        Self {
738            streams_per_circ,
739            cur_run: 0,
740            next_stream: 0,
741            cur_token: IsolationToken::new(),
742        }
743    }
744    /// Return the isolation token to use for the next stream in the given
745    /// `run`.  Requires that runs are not interleaved.
746    fn next_in(&mut self, run: usize) -> IsolationToken {
747        if run != self.cur_run {
748            self.cur_run = run;
749            self.next_stream = 0;
750            self.cur_token = IsolationToken::new();
751        } else if self.next_stream % self.streams_per_circ == 0 {
752            self.cur_token = IsolationToken::new();
753        }
754        self.next_stream += 1;
755
756        self.cur_token
757    }
758}
759
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::StreamIsolationTracker;

    #[test]
    fn test_iso_tracker() {
        let mut tracker = StreamIsolationTracker::new(2);
        let first_run = (0..9).map(|_| tracker.next_in(0)).collect::<Vec<_>>();
        let second_run = (0..6).map(|_| tracker.next_in(1)).collect::<Vec<_>>();

        // Within a run, consecutive pairs of streams share a token.
        assert_eq!(first_run[0], first_run[1]);
        assert_ne!(first_run[1], first_run[2]);
        assert_eq!(first_run[2], first_run[3]);
        assert_eq!(second_run[0], second_run[1]);
        assert_ne!(second_run[1], second_run[2]);

        // A new run never reuses a token handed out during the previous one.
        assert!(!first_run.contains(&second_run[0]));
    }
}
789}