arti_bench/
main.rs

1//! A simple benchmarking utility for Arti.
2//!
3//! This works by establishing a simple TCP server, and having Arti connect back to it via
4//! a `chutney` network of Tor nodes, benchmarking the upload and download bandwidth while doing so.
5
6// @@ begin lint list maintained by maint/add_warning @@
7#![allow(renamed_and_removed_lints)] // @@REMOVE_WHEN(ci_arti_stable)
8#![allow(unknown_lints)] // @@REMOVE_WHEN(ci_arti_nightly)
9#![warn(missing_docs)]
10#![warn(noop_method_call)]
11#![warn(unreachable_pub)]
12#![warn(clippy::all)]
13#![deny(clippy::await_holding_lock)]
14#![deny(clippy::cargo_common_metadata)]
15#![deny(clippy::cast_lossless)]
16#![deny(clippy::checked_conversions)]
17#![warn(clippy::cognitive_complexity)]
18#![deny(clippy::debug_assert_with_mut_call)]
19#![deny(clippy::exhaustive_enums)]
20#![deny(clippy::exhaustive_structs)]
21#![deny(clippy::expl_impl_clone_on_copy)]
22#![deny(clippy::fallible_impl_from)]
23#![deny(clippy::implicit_clone)]
24#![deny(clippy::large_stack_arrays)]
25#![warn(clippy::manual_ok_or)]
26#![deny(clippy::missing_docs_in_private_items)]
27#![warn(clippy::needless_borrow)]
28#![warn(clippy::needless_pass_by_value)]
29#![warn(clippy::option_option)]
30#![deny(clippy::print_stderr)]
31#![deny(clippy::print_stdout)]
32#![warn(clippy::rc_buffer)]
33#![deny(clippy::ref_option_ref)]
34#![warn(clippy::semicolon_if_nothing_returned)]
35#![warn(clippy::trait_duplication_in_bounds)]
36#![deny(clippy::unchecked_duration_subtraction)]
37#![deny(clippy::unnecessary_wraps)]
38#![warn(clippy::unseparated_literal_suffix)]
39#![deny(clippy::unwrap_used)]
40#![deny(clippy::mod_module_files)]
41#![allow(clippy::let_unit_value)] // This can reasonably be done for explicitness
42#![allow(clippy::uninlined_format_args)]
43#![allow(clippy::significant_drop_in_scrutinee)] // arti/-/merge_requests/588/#note_2812945
44#![allow(clippy::result_large_err)] // temporary workaround for arti#587
45#![allow(clippy::needless_raw_string_hashes)] // complained-about code is fine, often best
46#![allow(clippy::needless_lifetimes)] // See arti#1765
47//! <!-- @@ end lint list maintained by maint/add_warning @@ -->
48// This file uses `unwrap()` a fair deal, but this is fine in test/bench code
49// because it's OK if tests and benchmarks simply crash if things go wrong.
50#![allow(clippy::unwrap_used)]
51
52use anyhow::{anyhow, Result};
53use arti::cfg::ArtiCombinedConfig;
54use arti_client::{IsolationToken, TorAddr, TorClient, TorClientConfig};
55use clap::{value_parser, Arg, ArgAction};
56use futures::StreamExt;
57use rand::distr::StandardUniform;
58use rand::Rng;
59use serde::{Deserialize, Serialize};
60use std::collections::HashMap;
61use std::ffi::OsString;
62use std::fmt;
63use std::fmt::Formatter;
64use std::future::Future;
65use std::io::{Read, Write};
66use std::net::{IpAddr, SocketAddr, TcpListener, TcpStream};
67use std::ops::Deref;
68use std::str::FromStr;
69use std::sync::Arc;
70use std::thread::JoinHandle;
71use std::time::SystemTime;
72use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
73use tokio_socks::tcp::Socks5Stream;
74use tor_config::{ConfigurationSource, ConfigurationSources};
75use tor_rtcompat::ToplevelRuntime;
76use tracing::info;
77
78/// Generate a random payload of bytes of the given size
79fn random_payload(size: usize) -> Vec<u8> {
80    rand::rng()
81        .sample_iter(StandardUniform)
82        .take(size)
83        .collect()
84}
85
/// Timing information from the benchmarking server.
///
/// Filled in by `run_timing` for one accepted connection and written back to
/// the client as JSON once the client's upload has been fully read. All
/// timestamps are `SystemTime` values taken on the server side.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerTiming {
    /// When the connection was accepted.
    accepted_ts: SystemTime,
    /// When the payload was successfully written to the client.
    ///
    /// Taken after the whole download payload was copied and the socket flushed.
    copied_ts: SystemTime,
    /// When the server received the first byte from the client.
    first_byte_ts: SystemTime,
    /// When the server finished reading the client's payload.
    read_done_ts: SystemTime,
}
98
/// Timing information from the benchmarking client.
///
/// Built by `client` once a single benchmark stream has completed. It combines
/// the client's own `SystemTime` timestamps with the server's `ServerTiming`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientTiming {
    /// When the client's connection succeeded.
    ///
    /// Taken when the benchmark client starts running on an already-established
    /// stream, after its receive buffer has been allocated.
    started_ts: SystemTime,
    /// When the client received the first byte from the server.
    first_byte_ts: SystemTime,
    /// When the client finished reading the server's payload.
    read_done_ts: SystemTime,
    /// When the payload was successfully written to the server.
    copied_ts: SystemTime,
    /// The server's copy of the timing information.
    server: ServerTiming,
    /// The size of the payload downloaded from the server, in bytes.
    download_size: usize,
    /// The size of the payload uploaded to the server, in bytes.
    upload_size: usize,
}
117
/// A summary of benchmarking results, generated from `ClientTiming`.
///
/// One `TimingSummary` is produced for every benchmark stream.
#[derive(Debug, Copy, Clone, Serialize)]
pub struct TimingSummary {
    /// The time to first byte (TTFB) for the download benchmark, in seconds.
    download_ttfb_sec: f64,
    /// The average download speed, in megabits (10^6 bits) per second.
    download_rate_megabit: f64,
    /// The time to first byte (TTFB) for the upload benchmark, in seconds.
    upload_ttfb_sec: f64,
    /// The average upload speed, in megabits (10^6 bits) per second.
    upload_rate_megabit: f64,
}
130
131impl fmt::Display for TimingSummary {
132    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
133        write!(
134            f,
135            "{:.2} Mbit/s up (ttfb {:.2}ms), {:.2} Mbit/s down (ttfb {:.2}ms)",
136            self.upload_rate_megabit,
137            self.upload_ttfb_sec * 1000.0,
138            self.download_rate_megabit,
139            self.download_ttfb_sec * 1000.0
140        )
141    }
142}
143
144impl TimingSummary {
145    /// Generate a `TimingSummary` from the `ClientTiming` returned by a benchmark run.
146    pub fn generate(ct: &ClientTiming) -> Result<Self> {
147        let download_ttfb = ct.first_byte_ts.duration_since(ct.server.accepted_ts)?;
148        let download_time = ct.read_done_ts.duration_since(ct.first_byte_ts)?;
149        let download_rate_bps = ct.download_size as f64 / download_time.as_secs_f64();
150
151        let upload_ttfb = ct.server.first_byte_ts.duration_since(ct.read_done_ts)?;
152        let upload_time = ct
153            .server
154            .read_done_ts
155            .duration_since(ct.server.first_byte_ts)?;
156        let upload_rate_bps = ct.upload_size as f64 / upload_time.as_secs_f64();
157
158        Ok(Self {
159            download_ttfb_sec: download_ttfb.as_secs_f64(),
160            download_rate_megabit: download_rate_bps / 125_000.0,
161            upload_ttfb_sec: upload_ttfb.as_secs_f64(),
162            upload_rate_megabit: upload_rate_bps / 125_000.0,
163        })
164    }
165}
166
/// How much should we be willing to read at a time?
///
/// This is the size, in bytes, of the buffer the benchmark server reads the
/// client's upload into (8 KiB).
const RECV_BUF_LEN: usize = 8 * 1024;
169
170/// Run the timing routine
171#[allow(clippy::cognitive_complexity)]
172fn run_timing(mut stream: TcpStream, send: &Arc<[u8]>, receive: &Arc<[u8]>) -> Result<()> {
173    let peer_addr = stream.peer_addr()?;
174    let mut received = vec![0_u8; RECV_BUF_LEN];
175    let expected_len = receive.len();
176    let mut expected = receive.deref();
177    let mut mismatch = false;
178    let mut total_read = 0;
179
180    info!("Accepted connection from {}", peer_addr);
181    let accepted_ts = SystemTime::now();
182    let mut data: &[u8] = send.deref();
183    let copied = std::io::copy(&mut data, &mut stream)?;
184    stream.flush()?;
185    let copied_ts = SystemTime::now();
186    assert_eq!(copied, send.len() as u64);
187    info!("Copied {} bytes payload to {}.", copied, peer_addr);
188    let read = stream.read(&mut received)?;
189    if read == 0 {
190        panic!("unexpected EOF");
191    }
192    let first_byte_ts = SystemTime::now();
193    if received[0..read] != expected[0..read] {
194        mismatch = true;
195    }
196    expected = &expected[read..];
197    total_read += read;
198    while total_read < expected_len {
199        let read = stream.read(&mut received)?;
200        if read == 0 {
201            panic!("unexpected eof");
202        }
203        if received[0..read] != expected[0..read] {
204            mismatch = true;
205        }
206        expected = &expected[read..];
207        total_read += read;
208    }
209    let read_done_ts = SystemTime::now();
210    info!("Received {} bytes payload from {}.", total_read, peer_addr);
211    // Check we actually got what we thought we would get.
212    if mismatch {
213        panic!("Received data doesn't match expected; potential corruption?");
214    }
215    let st = ServerTiming {
216        accepted_ts,
217        copied_ts,
218        first_byte_ts,
219        read_done_ts,
220    };
221    serde_json::to_writer(&mut stream, &st)?;
222    info!("Wrote timing payload to {}.", peer_addr);
223    Ok(())
224}
225
226/// Runs the benchmarking TCP server, using the provided TCP listener and set of payloads.
227fn serve_payload(
228    listener: &TcpListener,
229    send: &Arc<[u8]>,
230    receive: &Arc<[u8]>,
231) -> Vec<JoinHandle<Result<()>>> {
232    info!("Listening for clients...");
233
234    listener
235        .incoming()
236        .map(|stream| {
237            let send = Arc::clone(send);
238            let receive = Arc::clone(receive);
239            std::thread::spawn(move || run_timing(stream?, &send, &receive))
240        })
241        .collect()
242}
243
244/// Runs the benchmarking client on the provided socket.
245async fn client<S: AsyncRead + AsyncWrite + Unpin>(
246    mut socket: S,
247    send: Arc<[u8]>,
248    receive: Arc<[u8]>,
249) -> Result<ClientTiming> {
250    // Do this potentially costly allocation before we do all the timing stuff.
251    let mut received = vec![0_u8; receive.len()];
252    let started_ts = SystemTime::now();
253
254    let read = socket.read(&mut received).await?;
255    if read == 0 {
256        return Err(anyhow!("unexpected EOF"));
257    }
258    let first_byte_ts = SystemTime::now();
259    socket.read_exact(&mut received[read..]).await?;
260    let read_done_ts = SystemTime::now();
261    info!("Received {} bytes payload.", received.len());
262    let mut send_data = &send as &[u8];
263
264    tokio::io::copy(&mut send_data, &mut socket).await?;
265    socket.flush().await?;
266    info!("Sent {} bytes payload.", send.len());
267    let copied_ts = SystemTime::now();
268
269    // Check we actually got what we thought we would get.
270    if received != receive.deref() {
271        panic!("Received data doesn't match expected; potential corruption?");
272    }
273    let mut json_buf = Vec::new();
274    socket.read_to_end(&mut json_buf).await?;
275    let server: ServerTiming = serde_json::from_slice(&json_buf)?;
276    Ok(ClientTiming {
277        started_ts,
278        first_byte_ts,
279        read_done_ts,
280        copied_ts,
281        server,
282        download_size: receive.len(),
283        upload_size: send.len(),
284    })
285}
286
287#[allow(clippy::cognitive_complexity)]
288fn main() -> Result<()> {
289    tracing_subscriber::fmt::init();
290
291    let matches = clap::Command::new("arti-bench")
292        .version(env!("CARGO_PKG_VERSION"))
293        .author("The Tor Project Developers")
294        .about("A simple benchmarking utility for Arti.")
295        .arg(
296            Arg::new("arti-config")
297                .short('c')
298                .long("arti-config")
299                .action(ArgAction::Set)
300                .required(true)
301                .value_name("CONFIG")
302                .value_parser(value_parser!(OsString))
303                .help(
304                    "Path to the Arti configuration to use (usually, a Chutney-generated config).",
305                ),
306        )
307        .arg(
308            Arg::new("num-samples")
309                .short('s')
310                .long("num-samples")
311                .action(ArgAction::Set)
312                .value_name("COUNT")
313                .value_parser(value_parser!(usize))
314                .default_value("3")
315                .help("How many samples to take per benchmark run.")
316        )
317        .arg(
318            Arg::new("num-streams")
319                .short('p')
320                .long("streams")
321                .aliases(["num-parallel"])
322                .action(ArgAction::Set)
323                .value_name("COUNT")
324                .value_parser(value_parser!(usize))
325                .default_value("3")
326                .help("How many simultaneous streams per circuit.")
327        )
328        .arg(
329            Arg::new("num-circuits")
330                .short('C')
331                .long("num-circuits")
332                .action(ArgAction::Set)
333                .value_name("COUNT")
334                .value_parser(value_parser!(usize))
335                .default_value("1")
336                .help("How many simultaneous circuits per run.")
337        )
338        .arg(
339            Arg::new("output")
340                .short('o')
341                .action(ArgAction::Set)
342                .value_name("/path/to/output.json")
343                .help("A path to write benchmark results to, in JSON format.")
344        )
345        .arg(
346            Arg::new("download-bytes")
347                .short('d')
348                .long("download-bytes")
349                .action(ArgAction::Set)
350                .value_name("SIZE")
351                .value_parser(value_parser!(usize))
352                .default_value("10485760")
353                .help("How much fake payload data to generate for the download benchmark."),
354        )
355        .arg(
356            Arg::new("upload-bytes")
357                .short('u')
358                .long("upload-bytes")
359                .action(ArgAction::Set)
360                .value_name("SIZE")
361                .value_parser(value_parser!(usize))
362                .default_value("10485760")
363                .help("How much fake payload data to generate for the upload benchmark."),
364        )
365        .arg(
366            Arg::new("socks-proxy")
367                .long("socks5")
368                .action(ArgAction::Set)
369                .value_name("addr:port")
370                .help("SOCKS5 proxy address for a node to benchmark through as well (usually a Chutney node). Optional."),
371        )
372        .get_matches();
373    info!("Parsing Arti configuration...");
374    let mut config_sources = ConfigurationSources::new_empty();
375    matches
376        .get_many::<OsString>("arti-config")
377        .unwrap_or_default()
378        .for_each(|f| {
379            config_sources.push_source(
380                ConfigurationSource::from_path(f),
381                tor_config::sources::MustRead::MustRead,
382            );
383        });
384
385    // TODO really we ought to get this from the arti configuration, or something.
386    // But this is OK for now since we are a benchmarking tool.
387    let mistrust = fs_mistrust::Mistrust::new_dangerously_trust_everyone();
388    config_sources.set_mistrust(mistrust);
389
390    let cfg = config_sources.load()?;
391    let (_config, tcc) = tor_config::resolve::<ArtiCombinedConfig>(cfg)?;
392    info!("Binding local TCP listener...");
393    let listener = TcpListener::bind("0.0.0.0:0")?;
394    let local_addr = listener.local_addr()?;
395    let connect_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), local_addr.port());
396    info!("Bound to {}.", local_addr);
397    let upload_bytes = *matches.get_one::<usize>("upload-bytes").unwrap();
398    let download_bytes = *matches.get_one::<usize>("download-bytes").unwrap();
399    let samples = *matches.get_one::<usize>("num-samples").unwrap();
400    let streams_per_circ = *matches.get_one::<usize>("num-streams").unwrap();
401    let circs_per_sample = *matches.get_one::<usize>("num-circuits").unwrap();
402    info!("Generating test payloads, please wait...");
403    let upload_payload = random_payload(upload_bytes).into();
404    let download_payload = random_payload(download_bytes).into();
405    info!(
406        "Generated payloads ({} upload, {} download)",
407        upload_bytes, download_bytes
408    );
409    let up = Arc::clone(&upload_payload);
410    let dp = Arc::clone(&download_payload);
411    let _handle = std::thread::spawn(move || -> Result<()> {
412        serve_payload(&listener, &dp, &up)
413            .into_iter()
414            .try_for_each(|handle| handle.join().expect("failed to join thread"))
415    });
416
417    let mut benchmark = Benchmark {
418        connect_addr,
419        samples,
420        streams_per_circ,
421        circs_per_sample,
422        upload_payload,
423        download_payload,
424        runtime: tor_rtcompat::tokio::TokioNativeTlsRuntime::create()?,
425        results: Default::default(),
426    };
427
428    benchmark.without_arti()?;
429    if let Some(addr) = matches.get_one::<String>("socks-proxy") {
430        benchmark.with_proxy(addr)?;
431    }
432    benchmark.with_arti(tcc)?;
433
434    info!("Benchmarking complete.");
435
436    for (ty, results) in benchmark.results.iter() {
437        info!(
438            "Information for benchmark type {:?} ({} samples taken):",
439            ty, benchmark.samples
440        );
441        info!("  upload rate: {} Mbit/s", results.upload_rate_megabit);
442        info!("download rate: {} Mbit/s", results.upload_rate_megabit);
443        info!("    TTFB (up): {} msec", results.upload_ttfb_msec);
444        info!("  TTFB (down): {} msec", results.download_ttfb_msec);
445    }
446
447    if let Some(output) = matches.get_one::<String>("output") {
448        info!("Writing benchmark results to {}...", output);
449        let file = std::fs::File::create(output)?;
450        serde_json::to_writer(
451            &file,
452            &BenchmarkSummary {
453                crate_version: env!("CARGO_PKG_VERSION").to_string(),
454                results: benchmark.results,
455            },
456        )?;
457    }
458
459    Ok(())
460}
461
462/// A helper struct for running benchmarks
463#[allow(clippy::missing_docs_in_private_items)]
464struct Benchmark<R>
465where
466    R: ToplevelRuntime,
467{
468    runtime: R,
469    connect_addr: SocketAddr,
470    samples: usize,
471    streams_per_circ: usize,
472    circs_per_sample: usize,
473    upload_payload: Arc<[u8]>,
474    download_payload: Arc<[u8]>,
475    /// All benchmark results conducted, indexed by benchmark type.
476    results: HashMap<BenchmarkType, BenchmarkResults>,
477}
478
/// The type of benchmark conducted.
///
/// Used as the key under which each benchmark's results are stored.
#[derive(Clone, Copy, Serialize, Deserialize, Hash, Debug, PartialEq, Eq)]
enum BenchmarkType {
    /// Use the benchmark server on its own, without using any proxy.
    ///
    /// This is useful to get an idea of how well the benchmarking utility performs on its own.
    RawLoopback,
    /// Benchmark via a SOCKS5 proxy (usually that of a chutney node).
    ///
    /// Only run when a proxy address is given with `--socks5`.
    Socks,
    /// Benchmark via Arti.
    Arti,
}
491
492#[derive(Clone, Serialize, Debug)]
493/// Some information about a set of benchmark samples collected during multiple runs.
494struct Statistic {
495    /// The mean value of all samples.
496    mean: f64,
497    /// The low-median value of all samples.
498    /// # Important note
499    ///
500    /// This is only the median if an odd number of samples were collected; otherwise,
501    /// it is the `(number of samples / 2)`th sample after the samples are sorted.
502    median: f64,
503    /// The minimum sample observed.
504    min: f64,
505    /// The maximum sample observed.
506    max: f64,
507    /// The standard deviation of the set of samples.
508    stddev: f64,
509}
510
511impl fmt::Display for Statistic {
512    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
513        let Statistic {
514            mean,
515            median,
516            min,
517            max,
518            stddev,
519        } = self;
520        write!(
521            f,
522            "min/mean/median/max/stddev = {:>7.2}/{:>7.2}/{:>7.2}/{:>7.2}/{:>7.2}",
523            min, mean, median, max, stddev
524        )
525    }
526}
527
528impl Statistic {
529    /// Generate a summary of the provided `samples`.
530    ///
531    /// # Panics
532    ///
533    /// Panics if `samples` is empty.
534    fn from_samples(mut samples: Vec<f64>) -> Self {
535        let n_samples = samples.len();
536        float_ord::sort(&mut samples);
537        let mean = samples.iter().sum::<f64>() / n_samples as f64;
538        // \Sigma (x_i - \mu)^2
539        let samples_minus_mean_sum = samples.iter().map(|xi| (xi - mean).powf(2.0)).sum::<f64>();
540        let stddev = (samples_minus_mean_sum / n_samples as f64).sqrt();
541        Statistic {
542            mean,
543            median: samples[n_samples / 2],
544            min: samples[0],
545            max: samples[n_samples - 1],
546            stddev,
547        }
548    }
549}
550
/// A set of benchmark results for a given `BenchmarkType`, including information about averages.
#[derive(Clone, Serialize, Debug)]
struct BenchmarkResults {
    /// The type of benchmark conducted.
    ty: BenchmarkType,
    /// The number of individual measurements summarized here.
    ///
    /// NOTE: this is the length of `results_raw`, which holds one entry per
    /// stream across all runs. It is *not* the number of benchmark runs.
    samples: usize,
    /// The number of concurrent streams per circuit used during the run.
    streams_per_circ: usize,
    /// The number of concurrent circuits used in each run.
    circuits: usize,
    /// The time to first byte (TTFB) for the download benchmark, in milliseconds.
    download_ttfb_msec: Statistic,
    /// The average download speed, in megabits per second.
    download_rate_megabit: Statistic,
    /// The time to first byte (TTFB) for the upload benchmark, in milliseconds.
    upload_ttfb_msec: Statistic,
    /// The average upload speed, in megabits per second.
    upload_rate_megabit: Statistic,

    /// The raw benchmark results, one `TimingSummary` per measured stream.
    results_raw: Vec<TimingSummary>,
}
574
575impl BenchmarkResults {
576    /// Generate summarized benchmark results from raw run data.
577    fn generate(
578        ty: BenchmarkType,
579        streams_per_circ: usize,
580        circuits: usize,
581        raw: Vec<TimingSummary>,
582    ) -> Self {
583        let download_ttfb_msecs = raw
584            .iter()
585            .map(|s| s.download_ttfb_sec * 1000.0)
586            .collect::<Vec<_>>();
587        let download_rate_megabits = raw
588            .iter()
589            .map(|s| s.download_rate_megabit)
590            .collect::<Vec<_>>();
591        let upload_ttfb_msecs = raw
592            .iter()
593            .map(|s| s.upload_ttfb_sec * 1000.0)
594            .collect::<Vec<_>>();
595        let upload_rate_megabits = raw
596            .iter()
597            .map(|s| s.upload_rate_megabit)
598            .collect::<Vec<_>>();
599        let samples = raw.len();
600        BenchmarkResults {
601            ty,
602            samples,
603            streams_per_circ,
604            circuits,
605            download_ttfb_msec: Statistic::from_samples(download_ttfb_msecs),
606            download_rate_megabit: Statistic::from_samples(download_rate_megabits),
607            upload_ttfb_msec: Statistic::from_samples(upload_ttfb_msecs),
608            upload_rate_megabit: Statistic::from_samples(upload_rate_megabits),
609            results_raw: raw,
610        }
611    }
612}
613
/// A summary of all benchmarks conducted throughout the invocation of `arti-bench`.
///
/// Designed to be stored as an artifact and compared against other later runs.
/// `main` serializes it as JSON to the path given with `-o`.
#[derive(Clone, Serialize, Debug)]
struct BenchmarkSummary {
    /// The version of `arti-bench` used to generate the benchmark results.
    crate_version: String,
    /// All benchmark results conducted, indexed by benchmark type.
    results: HashMap<BenchmarkType, BenchmarkResults>,
}
624
625impl<R: ToplevelRuntime> Benchmark<R> {
626    /// Run a type of benchmark (`ty`), performing `self.samples` benchmark
627    /// runs, using `self.circs_per_sample` concurrent circuits, and
628    /// `self.streams_per_circ` concurrent streams on each circuit.
629    ///
630    /// Uses `stream_generator`, function that returns futures that themselves
631    /// generate streams, in order to obtain the required number of streams to
632    /// run the test over.  The function takes an index of the current run.
633    fn run<F, G, S, E>(&mut self, ty: BenchmarkType, mut stream_generator: F) -> Result<()>
634    where
635        F: FnMut(usize) -> G,
636        G: Future<Output = Result<S, E>>,
637        S: AsyncRead + AsyncWrite + Unpin,
638        E: std::error::Error + Send + Sync + 'static,
639    {
640        let mut results = vec![];
641        for n in 0..self.samples {
642            let total_streams = self.streams_per_circ * self.circs_per_sample;
643            let futures = (0..total_streams)
644                .map(|_| {
645                    let up = Arc::clone(&self.upload_payload);
646                    let dp = Arc::clone(&self.download_payload);
647                    let stream = stream_generator(n);
648                    Box::pin(async move { client(stream.await?, up, dp).await })
649                })
650                .collect::<futures::stream::FuturesUnordered<_>>()
651                .collect::<Vec<_>>();
652            info!(
653                "Benchmarking {:?} with {} connections, run {}/{}...",
654                ty,
655                self.streams_per_circ,
656                n + 1,
657                self.samples
658            );
659            let stats = self
660                .runtime
661                .block_on(futures)
662                .into_iter()
663                .map(|x| x.and_then(|x| TimingSummary::generate(&x)))
664                .collect::<Result<Vec<_>>>()?;
665            results.extend(stats);
666        }
667        let results =
668            BenchmarkResults::generate(ty, self.streams_per_circ, self.circs_per_sample, results);
669        self.results.insert(ty, results);
670        Ok(())
671    }
672
673    /// Benchmark without Arti on loopback.
674    fn without_arti(&mut self) -> Result<()> {
675        let ca = self.connect_addr;
676        self.run(BenchmarkType::RawLoopback, |_| {
677            tokio::net::TcpStream::connect(ca)
678        })
679    }
680
681    /// Benchmark through a SOCKS5 proxy at address `addr`.
682    fn with_proxy(&mut self, addr: &str) -> Result<()> {
683        let ca = self.connect_addr;
684        let mut iso = StreamIsolationTracker::new(self.streams_per_circ);
685
686        self.run(BenchmarkType::Socks, |run| {
687            // Tor uses the username,password tuple of socks authentication do decide how to isolate streams.
688            let iso_string = format!("{:?}", iso.next_in(run));
689            async move {
690                Socks5Stream::connect_with_password(addr, ca, &iso_string, &iso_string).await
691            }
692        })
693    }
694
695    /// Benchmark through Arti, using the provided `TorClientConfig`.
696    fn with_arti(&mut self, tcc: TorClientConfig) -> Result<()> {
697        info!("Starting Arti...");
698        let tor_client = self.runtime.block_on(
699            TorClient::with_runtime(self.runtime.clone())
700                .config(tcc)
701                .create_bootstrapped(),
702        )?;
703
704        let addr = TorAddr::dangerously_from(self.connect_addr)?;
705
706        let mut iso = StreamIsolationTracker::new(self.streams_per_circ);
707
708        self.run(BenchmarkType::Arti, |run| {
709            let mut prefs = arti_client::StreamPrefs::new();
710            prefs.set_isolation(iso.next_in(run));
711
712            tor_client.connect(addr.clone())
713        })
714    }
715}
716
/// Helper type: track an `IsolationToken` over a set of runs.
///
/// We want to return a new token every `streams_per_circ` calls for each run,
/// but always give a new token when a new run begins. `Benchmark::with_proxy`
/// and `Benchmark::with_arti` use it to group the streams of each run.
#[derive(Debug, Clone)]
struct StreamIsolationTracker {
    /// The number of streams to assign to each circuit.
    streams_per_circ: usize,
    /// The current run index.
    cur_run: usize,
    /// The stream index within the run that we expect on the _next_ call to `next_in`.
    next_stream: usize,
    /// The isolation token we're currently handing out.
    cur_token: IsolationToken,
}
732
733impl StreamIsolationTracker {
734    /// Construct a new StreamIsolationTracker.
735    fn new(streams_per_circ: usize) -> Self {
736        Self {
737            streams_per_circ,
738            cur_run: 0,
739            next_stream: 0,
740            cur_token: IsolationToken::new(),
741        }
742    }
743    /// Return the isolation token to use for the next stream in the given
744    /// `run`.  Requires that runs are not interleaved.
745    fn next_in(&mut self, run: usize) -> IsolationToken {
746        if run != self.cur_run {
747            self.cur_run = run;
748            self.next_stream = 0;
749            self.cur_token = IsolationToken::new();
750        } else if self.next_stream % self.streams_per_circ == 0 {
751            self.cur_token = IsolationToken::new();
752        }
753        self.next_stream += 1;
754
755        self.cur_token
756    }
757}
758
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::StreamIsolationTracker;

    /// Tokens are shared by groups of `streams_per_circ` consecutive streams,
    /// and a new run starts with a token not used in the previous run.
    #[test]
    fn test_iso_tracker() {
        let mut tracker = StreamIsolationTracker::new(2);
        let run0: Vec<_> = std::iter::repeat_with(|| tracker.next_in(0))
            .take(9)
            .collect();
        let run1: Vec<_> = std::iter::repeat_with(|| tracker.next_in(1))
            .take(6)
            .collect();

        // Within a run, consecutive streams are paired onto the same token.
        assert_eq!(run0[0], run0[1]);
        assert_ne!(run0[1], run0[2]);
        assert_eq!(run0[2], run0[3]);
        assert_eq!(run1[0], run1[1]);
        assert_ne!(run1[1], run1[2]);

        // The first token of the new run was never handed out in the old one.
        assert!(!run0.contains(&run1[0]));
    }
}