tor_rtcompat/
compound.rs

//! Define [`CompoundRuntime`], a runtime that can be built from several
//! component pieces.
3
4use std::{net, sync::Arc, time::Duration};
5
6use crate::traits::*;
7use crate::{CoarseInstant, CoarseTimeProvider};
8use async_trait::async_trait;
9use educe::Educe;
10use futures::{future::FutureObj, task::Spawn};
11use std::future::Future;
12use std::io::Result as IoResult;
13use std::time::{Instant, SystemTime};
14use tor_general_addr::unix;
15
16/// A runtime made of several parts, each of which implements one trait-group.
17///
18/// The `TaskR` component should implement [`Spawn`], [`Blocking`] and maybe [`ToplevelBlockOn`];
19/// the `SleepR` component should implement [`SleepProvider`];
20/// the `CoarseTimeR` component should implement [`CoarseTimeProvider`];
21/// the `TcpR` component should implement [`NetStreamProvider`] for [`net::SocketAddr`];
22/// the `UnixR` component should implement [`NetStreamProvider`] for [`unix::SocketAddr`];
23/// and
24/// the `TlsR` component should implement [`TlsProvider`].
25///
26/// You can use this structure to create new runtimes in two ways: either by
27/// overriding a single part of an existing runtime, or by building an entirely
28/// new runtime from pieces.
29#[derive(Educe)]
30#[educe(Clone)] // #[derive(Clone)] wrongly infers Clone bounds on the generic parameters
31pub struct CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
32    /// The actual collection of Runtime objects.
33    ///
34    /// We wrap this in an Arc rather than requiring that each item implement
35    /// Clone, though we could change our minds later on.
36    inner: Arc<Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>>,
37}
38
/// The set of component objects that together implement the traits making up
/// a [`Runtime`].
struct Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
    /// The task component: `Spawn`, `Blocking` and `ToplevelBlockOn` calls
    /// are forwarded here.
    spawn: TaskR,
    /// The `SleepProvider` component.
    sleep: SleepR,
    /// The `CoarseTimeProvider` component.
    coarse_time: CoarseTimeR,
    /// The `NetStreamProvider<net::SocketAddr>` component.
    tcp: TcpR,
    /// The `NetStreamProvider<unix::SocketAddr>` component.
    unix: UnixR,
    /// The `TlsProvider` component.
    tls: TlsR,
    /// The `UdpProvider` component.
    udp: UdpR,
}
56
57impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
58    CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
59{
60    /// Construct a new CompoundRuntime from its components.
61    pub fn new(
62        spawn: TaskR,
63        sleep: SleepR,
64        coarse_time: CoarseTimeR,
65        tcp: TcpR,
66        unix: UnixR,
67        tls: TlsR,
68        udp: UdpR,
69    ) -> Self {
70        #[allow(clippy::arc_with_non_send_sync)]
71        CompoundRuntime {
72            inner: Arc::new(Inner {
73                spawn,
74                sleep,
75                coarse_time,
76                tcp,
77                unix,
78                tls,
79                udp,
80            }),
81        }
82    }
83}
84
85impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Spawn
86    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
87where
88    TaskR: Spawn,
89{
90    #[inline]
91    #[track_caller]
92    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), futures::task::SpawnError> {
93        self.inner.spawn.spawn_obj(future)
94    }
95}
96
97impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Blocking
98    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
99where
100    TaskR: Blocking,
101    SleepR: Clone + Send + Sync + 'static,
102    CoarseTimeR: Clone + Send + Sync + 'static,
103    TcpR: Clone + Send + Sync + 'static,
104    UnixR: Clone + Send + Sync + 'static,
105    TlsR: Clone + Send + Sync + 'static,
106    UdpR: Clone + Send + Sync + 'static,
107{
108    type ThreadHandle<T: Send + 'static> = TaskR::ThreadHandle<T>;
109
110    #[inline]
111    #[track_caller]
112    fn spawn_blocking<F, T>(&self, f: F) -> TaskR::ThreadHandle<T>
113    where
114        F: FnOnce() -> T + Send + 'static,
115        T: Send + 'static,
116    {
117        self.inner.spawn.spawn_blocking(f)
118    }
119
120    #[inline]
121    #[track_caller]
122    fn reenter_block_on<F>(&self, future: F) -> F::Output
123    where
124        F: Future,
125        F::Output: Send + 'static,
126    {
127        self.inner.spawn.reenter_block_on(future)
128    }
129
130    #[track_caller]
131    fn blocking_io<F, T>(&self, f: F) -> impl futures::Future<Output = T>
132    where
133        F: FnOnce() -> T + Send + 'static,
134        T: Send + 'static,
135    {
136        self.inner.spawn.blocking_io(f)
137    }
138}
139
140impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> ToplevelBlockOn
141    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
142where
143    TaskR: ToplevelBlockOn,
144    SleepR: Clone + Send + Sync + 'static,
145    CoarseTimeR: Clone + Send + Sync + 'static,
146    TcpR: Clone + Send + Sync + 'static,
147    UnixR: Clone + Send + Sync + 'static,
148    TlsR: Clone + Send + Sync + 'static,
149    UdpR: Clone + Send + Sync + 'static,
150{
151    #[inline]
152    #[track_caller]
153    fn block_on<F: futures::Future>(&self, future: F) -> F::Output {
154        self.inner.spawn.block_on(future)
155    }
156}
157
158impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> SleepProvider
159    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
160where
161    SleepR: SleepProvider,
162    TaskR: Clone + Send + Sync + 'static,
163    CoarseTimeR: Clone + Send + Sync + 'static,
164    TcpR: Clone + Send + Sync + 'static,
165    UnixR: Clone + Send + Sync + 'static,
166    TlsR: Clone + Send + Sync + 'static,
167    UdpR: Clone + Send + Sync + 'static,
168{
169    type SleepFuture = SleepR::SleepFuture;
170
171    #[inline]
172    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
173        self.inner.sleep.sleep(duration)
174    }
175
176    #[inline]
177    fn now(&self) -> Instant {
178        self.inner.sleep.now()
179    }
180
181    #[inline]
182    fn wallclock(&self) -> SystemTime {
183        self.inner.sleep.wallclock()
184    }
185}
186
187impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> CoarseTimeProvider
188    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
189where
190    CoarseTimeR: CoarseTimeProvider,
191    SleepR: Clone + Send + Sync + 'static,
192    TaskR: Clone + Send + Sync + 'static,
193    CoarseTimeR: Clone + Send + Sync + 'static,
194    TcpR: Clone + Send + Sync + 'static,
195    UnixR: Clone + Send + Sync + 'static,
196    TlsR: Clone + Send + Sync + 'static,
197    UdpR: Clone + Send + Sync + 'static,
198{
199    #[inline]
200    fn now_coarse(&self) -> CoarseInstant {
201        self.inner.coarse_time.now_coarse()
202    }
203}
204
205#[async_trait]
206impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<net::SocketAddr>
207    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
208where
209    TcpR: NetStreamProvider<net::SocketAddr>,
210    TaskR: Send + Sync + 'static,
211    SleepR: Send + Sync + 'static,
212    CoarseTimeR: Send + Sync + 'static,
213    TcpR: Send + Sync + 'static,
214    UnixR: Clone + Send + Sync + 'static,
215    TlsR: Send + Sync + 'static,
216    UdpR: Send + Sync + 'static,
217{
218    type Stream = TcpR::Stream;
219
220    type Listener = TcpR::Listener;
221
222    #[inline]
223    async fn connect(&self, addr: &net::SocketAddr) -> IoResult<Self::Stream> {
224        self.inner.tcp.connect(addr).await
225    }
226
227    #[inline]
228    async fn listen(&self, addr: &net::SocketAddr) -> IoResult<Self::Listener> {
229        self.inner.tcp.listen(addr).await
230    }
231}
232
233#[async_trait]
234impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<unix::SocketAddr>
235    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
236where
237    UnixR: NetStreamProvider<unix::SocketAddr>,
238    TaskR: Send + Sync + 'static,
239    SleepR: Send + Sync + 'static,
240    CoarseTimeR: Send + Sync + 'static,
241    TcpR: Send + Sync + 'static,
242    UnixR: Clone + Send + Sync + 'static,
243    TlsR: Send + Sync + 'static,
244    UdpR: Send + Sync + 'static,
245{
246    type Stream = UnixR::Stream;
247
248    type Listener = UnixR::Listener;
249
250    #[inline]
251    async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
252        self.inner.unix.connect(addr).await
253    }
254
255    #[inline]
256    async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
257        self.inner.unix.listen(addr).await
258    }
259}
260
261impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR, S> TlsProvider<S>
262    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
263where
264    TcpR: NetStreamProvider,
265    TlsR: TlsProvider<S>,
266    UnixR: Clone + Send + Sync + 'static,
267    SleepR: Clone + Send + Sync + 'static,
268    CoarseTimeR: Clone + Send + Sync + 'static,
269    TaskR: Clone + Send + Sync + 'static,
270    UdpR: Clone + Send + Sync + 'static,
271    S: StreamOps,
272{
273    type Connector = TlsR::Connector;
274    type TlsStream = TlsR::TlsStream;
275
276    #[inline]
277    fn tls_connector(&self) -> Self::Connector {
278        self.inner.tls.tls_connector()
279    }
280
281    #[inline]
282    fn supports_keying_material_export(&self) -> bool {
283        self.inner.tls.supports_keying_material_export()
284    }
285}
286
287impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> std::fmt::Debug
288    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
289{
290    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291        f.debug_struct("CompoundRuntime").finish_non_exhaustive()
292    }
293}
294
295#[async_trait]
296impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> UdpProvider
297    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
298where
299    UdpR: UdpProvider,
300    TaskR: Send + Sync + 'static,
301    SleepR: Send + Sync + 'static,
302    CoarseTimeR: Send + Sync + 'static,
303    TcpR: Send + Sync + 'static,
304    UnixR: Clone + Send + Sync + 'static,
305    TlsR: Send + Sync + 'static,
306    UdpR: Send + Sync + 'static,
307{
308    type UdpSocket = UdpR::UdpSocket;
309
310    #[inline]
311    async fn bind(&self, addr: &net::SocketAddr) -> IoResult<Self::UdpSocket> {
312        self.inner.udp.bind(addr).await
313    }
314}
315
/// Private module used to seal [`RuntimeSubstExt`].
mod sealed {
    /// Supertrait of `RuntimeSubstExt`, declared inside this private module
    /// so that it can be used to seal that trait.
    #[allow(unreachable_pub)]
    pub trait Sealed {}
}
322/// Extension trait on Runtime:
323/// Construct new Runtimes that replace part of an original runtime.
324///
325/// (If you need to do more complicated versions of this, you should likely construct
326/// CompoundRuntime directly.)
327pub trait RuntimeSubstExt: sealed::Sealed + Sized {
328    /// Return a new runtime wrapping this runtime, but replacing its TCP NetStreamProvider.
329    fn with_tcp_provider<T>(
330        &self,
331        new_tcp: T,
332    ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self>;
333    /// Return a new runtime wrapping this runtime, but replacing its SleepProvider.
334    fn with_sleep_provider<T>(
335        &self,
336        new_sleep: T,
337    ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self>;
338    /// Return a new runtime wrapping this runtime, but replacing its CoarseTimeProvider.
339    fn with_coarse_time_provider<T>(
340        &self,
341        new_coarse_time: T,
342    ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self>;
343}
344impl<R: Runtime> sealed::Sealed for R {}
345impl<R: Runtime + Sized> RuntimeSubstExt for R {
346    fn with_tcp_provider<T>(
347        &self,
348        new_tcp: T,
349    ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self> {
350        CompoundRuntime::new(
351            self.clone(),
352            self.clone(),
353            self.clone(),
354            new_tcp,
355            self.clone(),
356            self.clone(),
357            self.clone(),
358        )
359    }
360
361    fn with_sleep_provider<T>(
362        &self,
363        new_sleep: T,
364    ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self> {
365        CompoundRuntime::new(
366            self.clone(),
367            new_sleep,
368            self.clone(),
369            self.clone(),
370            self.clone(),
371            self.clone(),
372            self.clone(),
373        )
374    }
375
376    fn with_coarse_time_provider<T>(
377        &self,
378        new_coarse_time: T,
379    ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self> {
380        CompoundRuntime::new(
381            self.clone(),
382            self.clone(),
383            new_coarse_time,
384            self.clone(),
385            self.clone(),
386            self.clone(),
387            self.clone(),
388        )
389    }
390}