#![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))]
#![doc = include_str!("../README.md")]
#![allow(renamed_and_removed_lints)] #![allow(unknown_lints)] #![warn(missing_docs)]
#![warn(noop_method_call)]
#![warn(unreachable_pub)]
#![warn(clippy::all)]
#![deny(clippy::await_holding_lock)]
#![deny(clippy::cargo_common_metadata)]
#![deny(clippy::cast_lossless)]
#![deny(clippy::checked_conversions)]
#![warn(clippy::cognitive_complexity)]
#![deny(clippy::debug_assert_with_mut_call)]
#![deny(clippy::exhaustive_enums)]
#![deny(clippy::exhaustive_structs)]
#![deny(clippy::expl_impl_clone_on_copy)]
#![deny(clippy::fallible_impl_from)]
#![deny(clippy::implicit_clone)]
#![deny(clippy::large_stack_arrays)]
#![warn(clippy::manual_ok_or)]
#![deny(clippy::missing_docs_in_private_items)]
#![warn(clippy::needless_borrow)]
#![warn(clippy::needless_pass_by_value)]
#![warn(clippy::option_option)]
#![deny(clippy::print_stderr)]
#![deny(clippy::print_stdout)]
#![warn(clippy::rc_buffer)]
#![deny(clippy::ref_option_ref)]
#![warn(clippy::semicolon_if_nothing_returned)]
#![warn(clippy::trait_duplication_in_bounds)]
#![deny(clippy::unchecked_duration_subtraction)]
#![deny(clippy::unnecessary_wraps)]
#![warn(clippy::unseparated_literal_suffix)]
#![deny(clippy::unwrap_used)]
#![allow(clippy::let_unit_value)] #![allow(clippy::uninlined_format_args)]
#![allow(clippy::significant_drop_in_scrutinee)] #![allow(clippy::result_large_err)] #![allow(clippy::needless_raw_string_hashes)] use std::future::Future;
use std::io::Error;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use arti_client::{DataStream, IntoTorAddr, TorClient};
use educe::Educe;
use hyper::client::connect::{Connected, Connection};
use hyper::http::uri::Scheme;
use hyper::http::Uri;
use hyper::service::Service;
use pin_project::pin_project;
use thiserror::Error;
use tls_api::TlsConnector as TlsConn; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tor_rtcompat::Runtime;
#[derive(Error, Clone, Debug)]
#[non_exhaustive]
pub enum ConnectionError {
#[error("unsupported URI scheme in {uri:?}")]
UnsupportedUriScheme {
uri: Uri,
},
#[error("Missing hostname in {uri:?}")]
MissingHostname {
uri: Uri,
},
#[error("Tor connection failed")]
Arti(#[from] arti_client::Error),
#[error("TLS connection failed")]
TLS(#[source] Arc<anyhow::Error>),
}
/// Map each [`ConnectionError`] onto Arti's [`tor_error::ErrorKind`] taxonomy.
impl tor_error::HasKind for ConnectionError {
    fn kind(&self) -> tor_error::ErrorKind {
        use tor_error::ErrorKind;
        match self {
            // Only `http` and `https` are implemented.
            Self::UnsupportedUriScheme { .. } => ErrorKind::NotImplemented,
            // The caller supplied a URI we cannot connect to.
            Self::MissingHostname { .. } => ErrorKind::BadApiUsage,
            // Defer to the wrapped Arti error's own classification.
            Self::Arti(inner) => inner.kind(),
            // TLS failures are attributed to the remote side.
            Self::TLS(_) => ErrorKind::RemoteProtocolViolation,
        }
    }
}
/// A hyper connector that opens its connections over the Tor network.
///
/// `http` URIs get a bare Tor stream; `https` URIs additionally have a TLS
/// session established over that stream using the `TC` connector.
#[derive(Educe)]
// NOTE(review): presumably `educe` is used instead of `#[derive(Clone)]` so that
// cloning does not require a `TC: Clone` bound (std's derive would add one).
#[educe(Clone)]
pub struct ArtiHttpConnector<R: Runtime, TC: TlsConn> {
    /// Tor client used to open the underlying data streams.
    client: TorClient<R>,
    /// TLS connector used for `https` URIs; shared by every clone of this connector.
    tls_conn: Arc<TC>,
}

impl<R: Runtime, TC: TlsConn> ArtiHttpConnector<R, TC> {
    /// Create a connector from a Tor client and a TLS connector.
    pub fn new(client: TorClient<R>, tls_conn: TC) -> Self {
        Self {
            client,
            tls_conn: Arc::new(tls_conn),
        }
    }
}
/// A connection produced by [`ArtiHttpConnector`]: a Tor stream, optionally wrapped in TLS.
#[pin_project]
pub struct ArtiHttpConnection<TC>
where
    TC: TlsConn,
{
    /// The stream actually carrying the bytes (structurally pinned).
    #[pin]
    inner: MaybeHttpsStream<TC>,
}
/// The stream inside an [`ArtiHttpConnection`]: plain Tor stream or TLS over it.
#[pin_project(project = MaybeHttpsStreamProj)]
enum MaybeHttpsStream<TC>
where
    TC: TlsConn,
{
    /// Plain stream for `http` URIs; pinned through its own `Box`, so not a `#[pin]` field.
    Http(Pin<Box<DataStream>>),
    /// TLS stream for `https` URIs; structurally pinned.
    Https(#[pin] TC::TlsStream),
}
/// hyper's per-connection metadata hook.
///
/// Only the default [`Connected`] value is reported; no extra metadata is supplied.
impl<TC> Connection for ArtiHttpConnection<TC>
where
    TC: TlsConn,
{
    fn connected(&self) -> Connected {
        Connected::new()
    }
}
/// Reads are forwarded to whichever stream (plain or TLS) the connection wraps.
impl<TC: TlsConn> AsyncRead for ArtiHttpConnection<TC> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let this = self.project();
        match this.inner.project() {
            // Re-borrow the boxed stream's pin as `Pin<&mut DataStream>`.
            MaybeHttpsStreamProj::Http(plain) => plain.as_mut().poll_read(cx, buf),
            MaybeHttpsStreamProj::Https(tls) => tls.poll_read(cx, buf),
        }
    }
}
/// Writes, flushes and shutdowns are forwarded to whichever stream the connection wraps.
impl<TC: TlsConn> AsyncWrite for ArtiHttpConnection<TC> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, Error>> {
        let this = self.project();
        match this.inner.project() {
            MaybeHttpsStreamProj::Http(plain) => plain.as_mut().poll_write(cx, buf),
            MaybeHttpsStreamProj::Https(tls) => tls.poll_write(cx, buf),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        let this = self.project();
        match this.inner.project() {
            MaybeHttpsStreamProj::Http(plain) => plain.as_mut().poll_flush(cx),
            MaybeHttpsStreamProj::Https(tls) => tls.poll_flush(cx),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        let this = self.project();
        match this.inner.project() {
            MaybeHttpsStreamProj::Http(plain) => plain.as_mut().poll_shutdown(cx),
            MaybeHttpsStreamProj::Https(tls) => tls.poll_shutdown(cx),
        }
    }
}
/// Whether a connection is wrapped in TLS, as chosen from the URI scheme.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum UseTls {
    /// Plain-text connection (`http`).
    Bare,
    /// TLS-wrapped connection (`https`).
    Tls,
}
/// Split a request URI into the pieces needed to open a connection.
///
/// Returns `(host, port, use_tls)`:
///  * `use_tls` is [`UseTls::Bare`] for `http` and [`UseTls::Tls`] for `https`;
///  * `port` is the URI's explicit port, or the scheme default (80 / 443);
///  * `host` is the URI's host. For IPv6 literals, the square brackets that
///    `Uri::host` keeps (e.g. `[::1]`) are removed, so the value can be parsed
///    as an IP address and used as a TLS server name.
///
/// # Errors
///
/// * [`ConnectionError::UnsupportedUriScheme`] if the scheme is missing or is
///   neither `http` nor `https`;
/// * [`ConnectionError::MissingHostname`] if the URI has no host, or an empty one.
fn uri_to_host_port_tls(uri: Uri) -> Result<(String, u16, UseTls), ConnectionError> {
    let use_tls = {
        // Compare against the `Scheme` constants with `==` rather than matching.
        let scheme = uri.scheme();
        if scheme == Some(&Scheme::HTTP) {
            UseTls::Bare
        } else if scheme == Some(&Scheme::HTTPS) {
            UseTls::Tls
        } else {
            return Err(ConnectionError::UnsupportedUriScheme { uri });
        }
    };

    let host = match uri.host() {
        Some(h) => h,
        None => return Err(ConnectionError::MissingHostname { uri }),
    };
    // `Uri::host` returns IPv6 literals in bracketed form (`[2001:db8::1]`).
    // Address parsers such as `std::net::IpAddr`'s `FromStr` reject the
    // brackets, and they are not valid in a TLS server name, so strip them.
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    // An authority like `:80` yields an empty host; report it as missing
    // instead of letting address construction fail later.
    if host.is_empty() {
        return Err(ConnectionError::MissingHostname { uri });
    }

    let port = uri.port_u16().unwrap_or(match use_tls {
        UseTls::Tls => 443,
        UseTls::Bare => 80,
    });

    Ok((host.to_owned(), port, use_tls))
}
impl<R: Runtime, TC: TlsConn> Service<Uri> for ArtiHttpConnector<R, TC> {
type Response = ArtiHttpConnection<TC>;
type Error = ConnectionError;
#[allow(clippy::type_complexity)]
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Uri) -> Self::Future {
let client = self.client.clone();
let tls_conn = self.tls_conn.clone();
Box::pin(async move {
let (host, port, use_tls) = uri_to_host_port_tls(req)?;
let addr = (&host as &str, port)
.into_tor_addr()
.map_err(arti_client::Error::from)?;
let ds = client.connect(addr).await?;
let inner = match use_tls {
UseTls::Tls => {
let conn = tls_conn
.connect_impl_tls_stream(&host, ds)
.await
.map_err(|e| ConnectionError::TLS(e.into()))?;
MaybeHttpsStream::Https(conn)
}
UseTls::Bare => MaybeHttpsStream::Http(Box::new(ds).into()),
};
Ok(ArtiHttpConnection { inner })
})
}
}
/// Unit tests for URI decomposition.
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;

    /// Parse a fixed test URL; panics if it is malformed.
    fn make_uri(url: &str) -> Uri {
        url.parse().expect("Unable to parse uri")
    }

    /// Run `url` through `uri_to_host_port_tls`, which must succeed.
    fn parse_ok(url: &str) -> (String, u16, UseTls) {
        uri_to_host_port_tls(make_uri(url)).expect("function should return Result")
    }

    #[test]
    fn check_supported_uri_schemes() {
        // Anything other than http/https (including no scheme at all) is rejected.
        for rejected in [
            "wss://torproject.org",
            "file://torproject.org",
            "ftp://torproject.org",
            "vnc://torproject.org",
            "/no/scheme",
        ] {
            let outcome = uri_to_host_port_tls(make_uri(rejected));
            assert!(outcome.is_err());
        }

        for (accepted, expected_tls) in [
            ("https://torproject.org", UseTls::Tls),
            ("http://torproject.org", UseTls::Bare),
        ] {
            let (_, _, got_tls) = parse_ok(accepted);
            assert_eq!(got_tls, expected_tls);
        }
    }

    #[test]
    fn get_correct_port_and_tls_from_uri() {
        // Explicit ports win; otherwise the scheme's default is used.
        let cases = [
            ("https://torproject.org:999", 999, UseTls::Tls),
            ("https://torproject.org:80", 80, UseTls::Tls),
            ("https://torproject.org", 443, UseTls::Tls),
            ("http://torproject.org:999", 999, UseTls::Bare),
            ("http://torproject.org:443", 443, UseTls::Bare),
            ("http://torproject.org", 80, UseTls::Bare),
        ];
        for (url, expected_port, expected_tls) in cases {
            let (_, got_port, got_tls) = parse_ok(url);
            assert_eq!(got_port, expected_port);
            assert_eq!(got_tls, expected_tls);
        }
    }

    #[test]
    fn get_correct_host_from_uri() {
        let cases = [
            ("https://torproject.org", "torproject.org"),
            ("http://torproject.org", "torproject.org"),
        ];
        for (url, expected_host) in cases {
            let (got_host, _, _) = parse_ok(url);
            assert_eq!(got_host, expected_host);
        }
    }
}