use std::{net, sync::Arc, time::Duration};
use crate::traits::*;
use crate::{CoarseInstant, CoarseTimeProvider};
use async_trait::async_trait;
use educe::Educe;
use futures::{future::FutureObj, task::Spawn};
use std::io::Result as IoResult;
use std::time::{Instant, SystemTime};
use tor_general_addr::unix;
/// A runtime assembled from separately supplied components.
///
/// There is one type parameter per component: task spawning (`SpawnR`),
/// sleeping/timekeeping (`SleepR`), coarse time (`CoarseTimeR`), TCP
/// (`TcpR`), Unix sockets (`UnixR`), TLS (`TlsR`) and UDP (`UdpR`).
/// The trait implementations on this type forward each call to the
/// corresponding component.
///
/// All components are stored together in a single `Arc`-wrapped `Inner`.
//
// NOTE(review): `Clone` is derived through `educe` instead of
// `#[derive(Clone)]`, presumably so that cloning does not require every
// type parameter to be `Clone` -- confirm against the `educe` docs.
#[derive(Educe)]
#[educe(Clone)] pub struct CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
/// The shared collection of components that make up this runtime.
inner: Arc<Inner<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>>,
}
/// The components of a `CompoundRuntime`, one field per type parameter.
///
/// Kept as a separate struct so that `CompoundRuntime` can hold all of the
/// components behind a single `Arc`.
struct Inner<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
/// Component used by the `Spawn`, `SpawnBlocking` and `BlockOn` impls.
spawn: SpawnR,
/// Component used by the `SleepProvider` impl.
sleep: SleepR,
/// Component used by the `CoarseTimeProvider` impl.
coarse_time: CoarseTimeR,
/// Component used by the `NetStreamProvider<net::SocketAddr>` impl.
tcp: TcpR,
/// Component used by the `NetStreamProvider<unix::SocketAddr>` impl.
unix: UnixR,
/// Component used by the `TlsProvider` impl.
tls: TlsR,
/// Component used by the `UdpProvider` impl.
udp: UdpR,
}
impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
    CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
{
    /// Build a `CompoundRuntime` out of its seven components.
    ///
    /// Every argument is moved, unchanged, into the slot of the same name.
    /// No trait bounds are imposed here; the bounds live on the individual
    /// trait implementations for `CompoundRuntime`.
    pub fn new(
        spawn: SpawnR,
        sleep: SleepR,
        coarse_time: CoarseTimeR,
        tcp: TcpR,
        unix: UnixR,
        tls: TlsR,
        udp: UdpR,
    ) -> Self {
        let parts = Inner {
            spawn,
            sleep,
            coarse_time,
            tcp,
            unix,
            tls,
            udp,
        };
        // The components are unconstrained generics in this impl, so they
        // are not known to be `Send`/`Sync`; the lint is silenced on purpose.
        #[allow(clippy::arc_with_non_send_sync)]
        let inner = Arc::new(parts);
        Self { inner }
    }
}
impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Spawn
    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
where
    SpawnR: Spawn,
{
    /// Spawn `future` by handing it to the `spawn` component.
    ///
    /// The component's result is returned unchanged.
    #[inline]
    #[track_caller]
    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), futures::task::SpawnError> {
        <SpawnR as Spawn>::spawn_obj(&self.inner.spawn, future)
    }
}
impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> SpawnBlocking
    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
where
    SpawnR: SpawnBlocking,
    SleepR: Clone + Send + Sync + 'static,
    CoarseTimeR: Clone + Send + Sync + 'static,
    TcpR: Clone + Send + Sync + 'static,
    UnixR: Clone + Send + Sync + 'static,
    TlsR: Clone + Send + Sync + 'static,
    UdpR: Clone + Send + Sync + 'static,
{
    /// The handle type is whatever the `spawn` component uses.
    type Handle<T: Send + 'static> = SpawnR::Handle<T>;

    /// Run the blocking closure `f` via the `spawn` component and return
    /// that component's handle for it.
    #[inline]
    #[track_caller]
    fn spawn_blocking<F, T>(&self, f: F) -> SpawnR::Handle<T>
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        <SpawnR as SpawnBlocking>::spawn_blocking(&self.inner.spawn, f)
    }
}
impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> BlockOn
    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
where
    SpawnR: BlockOn,
    SleepR: Clone + Send + Sync + 'static,
    CoarseTimeR: Clone + Send + Sync + 'static,
    TcpR: Clone + Send + Sync + 'static,
    UnixR: Clone + Send + Sync + 'static,
    TlsR: Clone + Send + Sync + 'static,
    UdpR: Clone + Send + Sync + 'static,
{
    /// Drive `future` to completion using the `spawn` component and return
    /// the future's output.
    #[inline]
    #[track_caller]
    fn block_on<F: futures::Future>(&self, future: F) -> F::Output {
        <SpawnR as BlockOn>::block_on(&self.inner.spawn, future)
    }
}
impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> SleepProvider
    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
where
    SleepR: SleepProvider,
    SpawnR: Clone + Send + Sync + 'static,
    CoarseTimeR: Clone + Send + Sync + 'static,
    TcpR: Clone + Send + Sync + 'static,
    UnixR: Clone + Send + Sync + 'static,
    TlsR: Clone + Send + Sync + 'static,
    UdpR: Clone + Send + Sync + 'static,
{
    /// The sleep future is the one produced by the `sleep` component.
    type SleepFuture = SleepR::SleepFuture;

    /// Ask the `sleep` component for a future covering `duration`.
    #[inline]
    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
        <SleepR as SleepProvider>::sleep(&self.inner.sleep, duration)
    }

    /// Report the `sleep` component's notion of the current `Instant`.
    #[inline]
    fn now(&self) -> Instant {
        <SleepR as SleepProvider>::now(&self.inner.sleep)
    }

    /// Report the `sleep` component's notion of the current `SystemTime`.
    #[inline]
    fn wallclock(&self) -> SystemTime {
        <SleepR as SleepProvider>::wallclock(&self.inner.sleep)
    }
}
impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> CoarseTimeProvider
    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
where
    CoarseTimeR: CoarseTimeProvider,
    SleepR: Clone + Send + Sync + 'static,
    SpawnR: Clone + Send + Sync + 'static,
    CoarseTimeR: Clone + Send + Sync + 'static,
    TcpR: Clone + Send + Sync + 'static,
    UnixR: Clone + Send + Sync + 'static,
    TlsR: Clone + Send + Sync + 'static,
    UdpR: Clone + Send + Sync + 'static,
{
    /// Report the `coarse_time` component's current `CoarseInstant`.
    #[inline]
    fn now_coarse(&self) -> CoarseInstant {
        <CoarseTimeR as CoarseTimeProvider>::now_coarse(&self.inner.coarse_time)
    }
}
#[async_trait]
impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<net::SocketAddr>
    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
where
    TcpR: NetStreamProvider<net::SocketAddr>,
    SpawnR: Send + Sync + 'static,
    SleepR: Send + Sync + 'static,
    CoarseTimeR: Send + Sync + 'static,
    TcpR: Send + Sync + 'static,
    UnixR: Clone + Send + Sync + 'static,
    TlsR: Send + Sync + 'static,
    UdpR: Send + Sync + 'static,
{
    /// Stream type of the `tcp` component.
    type Stream = TcpR::Stream;
    /// Listener type of the `tcp` component.
    type Listener = TcpR::Listener;

    /// Connect to `addr` through the `tcp` component.
    #[inline]
    async fn connect(&self, addr: &net::SocketAddr) -> IoResult<Self::Stream> {
        <TcpR as NetStreamProvider<net::SocketAddr>>::connect(&self.inner.tcp, addr).await
    }

    /// Open a listener on `addr` through the `tcp` component.
    #[inline]
    async fn listen(&self, addr: &net::SocketAddr) -> IoResult<Self::Listener> {
        <TcpR as NetStreamProvider<net::SocketAddr>>::listen(&self.inner.tcp, addr).await
    }
}
#[async_trait]
impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<unix::SocketAddr>
    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
where
    UnixR: NetStreamProvider<unix::SocketAddr>,
    SpawnR: Send + Sync + 'static,
    SleepR: Send + Sync + 'static,
    CoarseTimeR: Send + Sync + 'static,
    TcpR: Send + Sync + 'static,
    UnixR: Clone + Send + Sync + 'static,
    TlsR: Send + Sync + 'static,
    UdpR: Send + Sync + 'static,
{
    /// Stream type of the `unix` component.
    type Stream = UnixR::Stream;
    /// Listener type of the `unix` component.
    type Listener = UnixR::Listener;

    /// Connect to `addr` through the `unix` component.
    #[inline]
    async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
        <UnixR as NetStreamProvider<unix::SocketAddr>>::connect(&self.inner.unix, addr).await
    }

    /// Open a listener on `addr` through the `unix` component.
    #[inline]
    async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
        <UnixR as NetStreamProvider<unix::SocketAddr>>::listen(&self.inner.unix, addr).await
    }
}
// NOTE(review): the `TcpR: NetStreamProvider` bound is not used by the method
// bodies in this impl; presumably it is kept for API consistency -- confirm.
impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR, S> TlsProvider<S>
    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
where
    TcpR: NetStreamProvider,
    TlsR: TlsProvider<S>,
    UnixR: Clone + Send + Sync + 'static,
    SleepR: Clone + Send + Sync + 'static,
    CoarseTimeR: Clone + Send + Sync + 'static,
    SpawnR: Clone + Send + Sync + 'static,
    UdpR: Clone + Send + Sync + 'static,
    S: StreamOps,
{
    /// Connector type of the `tls` component.
    type Connector = TlsR::Connector;
    /// TLS stream type of the `tls` component.
    type TlsStream = TlsR::TlsStream;

    /// Obtain a connector from the `tls` component.
    #[inline]
    fn tls_connector(&self) -> Self::Connector {
        <TlsR as TlsProvider<S>>::tls_connector(&self.inner.tls)
    }

    /// Report whatever the `tls` component says about keying-material export.
    #[inline]
    fn supports_keying_material_export(&self) -> bool {
        <TlsR as TlsProvider<S>>::supports_keying_material_export(&self.inner.tls)
    }
}
impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> std::fmt::Debug
    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
{
    /// Format as `CompoundRuntime { .. }`.
    ///
    /// No component is printed, so this impl places no `Debug` bound on any
    /// of the type parameters.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("CompoundRuntime");
        builder.finish_non_exhaustive()
    }
}
#[async_trait]
impl<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> UdpProvider
    for CompoundRuntime<SpawnR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
where
    UdpR: UdpProvider,
    SpawnR: Send + Sync + 'static,
    SleepR: Send + Sync + 'static,
    CoarseTimeR: Send + Sync + 'static,
    TcpR: Send + Sync + 'static,
    UnixR: Clone + Send + Sync + 'static,
    TlsR: Send + Sync + 'static,
    UdpR: Send + Sync + 'static,
{
    /// Socket type of the `udp` component.
    type UdpSocket = UdpR::UdpSocket;

    /// Bind a UDP socket to `addr` through the `udp` component.
    #[inline]
    async fn bind(&self, addr: &net::SocketAddr) -> IoResult<Self::UdpSocket> {
        <UdpR as UdpProvider>::bind(&self.inner.udp, addr).await
    }
}
/// Private module holding the marker trait used as a supertrait of
/// `RuntimeSubstExt`.
///
/// The module itself is not `pub`, so code outside this module's parent
/// cannot name `Sealed`.
mod sealed {
/// Marker trait with no methods; required by `RuntimeSubstExt`.
#[allow(unreachable_pub)]
pub trait Sealed {}
}
/// Extension trait for building a `CompoundRuntime` from an existing
/// runtime with exactly one component swapped out.
///
/// Each method returns a `CompoundRuntime` whose type parameters are all
/// `Self` except for the one replaced by `T`.
///
/// The `sealed::Sealed` supertrait lives in a private module.
pub trait RuntimeSubstExt: sealed::Sealed + Sized {
/// Return a new runtime that uses `new_tcp` as its TCP component and this
/// runtime for every other component.
fn with_tcp_provider<T>(
&self,
new_tcp: T,
) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self>;
/// Return a new runtime that uses `new_sleep` as its sleep component and
/// this runtime for every other component.
fn with_sleep_provider<T>(
&self,
new_sleep: T,
) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self>;
/// Return a new runtime that uses `new_coarse_time` as its coarse-time
/// component and this runtime for every other component.
fn with_coarse_time_provider<T>(
&self,
new_coarse_time: T,
) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self>;
}
/// Every `Runtime` is `Sealed`, which satisfies the supertrait requirement
/// of `RuntimeSubstExt`.
impl<R: Runtime> sealed::Sealed for R {}
/// Blanket implementation: any `Runtime` can have one component replaced.
///
/// Each method clones `self` once for every slot that is not being
/// replaced and passes the pieces to `CompoundRuntime::new`.
impl<R: Runtime + Sized> RuntimeSubstExt for R {
    /// Use `new_tcp` for TCP and clones of `self` for everything else.
    fn with_tcp_provider<T>(
        &self,
        new_tcp: T,
    ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self> {
        let spawn = self.clone();
        let sleep = self.clone();
        let coarse_time = self.clone();
        let unix = self.clone();
        let tls = self.clone();
        let udp = self.clone();
        CompoundRuntime::new(spawn, sleep, coarse_time, new_tcp, unix, tls, udp)
    }

    /// Use `new_sleep` for sleeping and clones of `self` for everything else.
    fn with_sleep_provider<T>(
        &self,
        new_sleep: T,
    ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self> {
        let spawn = self.clone();
        let coarse_time = self.clone();
        let tcp = self.clone();
        let unix = self.clone();
        let tls = self.clone();
        let udp = self.clone();
        CompoundRuntime::new(spawn, new_sleep, coarse_time, tcp, unix, tls, udp)
    }

    /// Use `new_coarse_time` for coarse time and clones of `self` for
    /// everything else.
    fn with_coarse_time_provider<T>(
        &self,
        new_coarse_time: T,
    ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self> {
        let spawn = self.clone();
        let sleep = self.clone();
        let tcp = self.clone();
        let unix = self.clone();
        let tls = self.clone();
        let udp = self.clone();
        CompoundRuntime::new(spawn, sleep, new_coarse_time, tcp, unix, tls, udp)
    }
}