//! Define [`CompoundRuntime`], a runtime that can be built from several
//! component pieces.

use std::{net, sync::Arc, time::Duration};
5

            
6
use crate::traits::*;
7
use crate::{CoarseInstant, CoarseTimeProvider};
8
use async_trait::async_trait;
9
use educe::Educe;
10
use futures::{future::FutureObj, task::Spawn};
11
use std::future::Future;
12
use std::io::Result as IoResult;
13
use std::time::{Instant, SystemTime};
14
use tor_general_addr::unix;
15

            
16
/// A runtime made of several parts, each of which implements one trait-group.
17
///
18
/// The `TaskR` component should implement [`Spawn`], [`Blocking`] and maybe [`ToplevelBlockOn`];
19
/// the `SleepR` component should implement [`SleepProvider`];
20
/// the `CoarseTimeR` component should implement [`CoarseTimeProvider`];
21
/// the `TcpR` component should implement [`NetStreamProvider`] for [`net::SocketAddr`];
22
/// the `UnixR` component should implement [`NetStreamProvider`] for [`unix::SocketAddr`];
23
/// and
24
/// the `TlsR` component should implement [`TlsProvider`].
25
///
26
/// You can use this structure to create new runtimes in two ways: either by
27
/// overriding a single part of an existing runtime, or by building an entirely
28
/// new runtime from pieces.
29
16543
#[derive(Educe)]
30
#[educe(Clone)] // #[derive(Clone)] wrongly infers Clone bounds on the generic parameters
31
pub struct CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
32
    /// The actual collection of Runtime objects.
33
    ///
34
    /// We wrap this in an Arc rather than requiring that each item implement
35
    /// Clone, though we could change our minds later on.
36
    inner: Arc<Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>>,
37
}

/// A collection of objects implementing the traits that make up a [`Runtime`].
///
/// Each field holds the component responsible for one trait-group of a
/// [`CompoundRuntime`].
struct Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
    /// A `Spawn`, `Blocking` and (optionally) `ToplevelBlockOn` implementation.
    spawn: TaskR,
    /// A `SleepProvider` implementation.
    sleep: SleepR,
    /// A `CoarseTimeProvider` implementation.
    coarse_time: CoarseTimeR,
    /// A `NetStreamProvider<net::SocketAddr>` implementation.
    tcp: TcpR,
    /// A `NetStreamProvider<unix::SocketAddr>` implementation.
    unix: UnixR,
    /// A `TlsProvider<S>` implementation, for whichever stream types `S` it supports.
    tls: TlsR,
    /// A `UdpProvider` implementation.
    udp: UdpR,
}
56

            
57
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
58
    CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
59
{
60
    /// Construct a new CompoundRuntime from its components.
61
18120
    pub fn new(
62
18120
        spawn: TaskR,
63
18120
        sleep: SleepR,
64
18120
        coarse_time: CoarseTimeR,
65
18120
        tcp: TcpR,
66
18120
        unix: UnixR,
67
18120
        tls: TlsR,
68
18120
        udp: UdpR,
69
18120
    ) -> Self {
70
18120
        #[allow(clippy::arc_with_non_send_sync)]
71
18120
        CompoundRuntime {
72
18120
            inner: Arc::new(Inner {
73
18120
                spawn,
74
18120
                sleep,
75
18120
                coarse_time,
76
18120
                tcp,
77
18120
                unix,
78
18120
                tls,
79
18120
                udp,
80
18120
            }),
81
18120
        }
82
18120
    }
83
}
84

            
85
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Spawn
86
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
87
where
88
    TaskR: Spawn,
89
{
90
    #[inline]
91
    #[track_caller]
92
824
    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), futures::task::SpawnError> {
93
824
        self.inner.spawn.spawn_obj(future)
94
824
    }
95
}
96

            
97
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Blocking
98
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
99
where
100
    TaskR: Blocking,
101
    SleepR: Clone + Send + Sync + 'static,
102
    CoarseTimeR: Clone + Send + Sync + 'static,
103
    TcpR: Clone + Send + Sync + 'static,
104
    UnixR: Clone + Send + Sync + 'static,
105
    TlsR: Clone + Send + Sync + 'static,
106
    UdpR: Clone + Send + Sync + 'static,
107
{
108
    type ThreadHandle<T: Send + 'static> = TaskR::ThreadHandle<T>;
109

            
110
    #[inline]
111
    #[track_caller]
112
    fn spawn_blocking<F, T>(&self, f: F) -> TaskR::ThreadHandle<T>
113
    where
114
        F: FnOnce() -> T + Send + 'static,
115
        T: Send + 'static,
116
    {
117
        self.inner.spawn.spawn_blocking(f)
118
    }
119

            
120
    #[inline]
121
    #[track_caller]
122
    fn reenter_block_on<F>(&self, future: F) -> F::Output
123
    where
124
        F: Future,
125
        F::Output: Send + 'static,
126
    {
127
        self.inner.spawn.reenter_block_on(future)
128
    }
129

            
130
    #[track_caller]
131
    fn blocking_io<F, T>(&self, f: F) -> impl futures::Future<Output = T>
132
    where
133
        F: FnOnce() -> T + Send + 'static,
134
        T: Send + 'static,
135
    {
136
        self.inner.spawn.blocking_io(f)
137
    }
138
}
139

            
140
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> ToplevelBlockOn
141
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
142
where
143
    TaskR: ToplevelBlockOn,
144
    SleepR: Clone + Send + Sync + 'static,
145
    CoarseTimeR: Clone + Send + Sync + 'static,
146
    TcpR: Clone + Send + Sync + 'static,
147
    UnixR: Clone + Send + Sync + 'static,
148
    TlsR: Clone + Send + Sync + 'static,
149
    UdpR: Clone + Send + Sync + 'static,
150
{
151
    #[inline]
152
    #[track_caller]
153
720
    fn block_on<F: futures::Future>(&self, future: F) -> F::Output {
154
720
        self.inner.spawn.block_on(future)
155
720
    }
156
}
157

            
158
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> SleepProvider
159
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
160
where
161
    SleepR: SleepProvider,
162
    TaskR: Clone + Send + Sync + 'static,
163
    CoarseTimeR: Clone + Send + Sync + 'static,
164
    TcpR: Clone + Send + Sync + 'static,
165
    UnixR: Clone + Send + Sync + 'static,
166
    TlsR: Clone + Send + Sync + 'static,
167
    UdpR: Clone + Send + Sync + 'static,
168
{
169
    type SleepFuture = SleepR::SleepFuture;
170

            
171
    #[inline]
172
7782
    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
173
7782
        self.inner.sleep.sleep(duration)
174
7782
    }
175

            
176
    #[inline]
177
    fn now(&self) -> Instant {
178
        self.inner.sleep.now()
179
    }
180

            
181
    #[inline]
182
36
    fn wallclock(&self) -> SystemTime {
183
36
        self.inner.sleep.wallclock()
184
36
    }
185
}
186

            
187
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> CoarseTimeProvider
188
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
189
where
190
    CoarseTimeR: CoarseTimeProvider,
191
    SleepR: Clone + Send + Sync + 'static,
192
    TaskR: Clone + Send + Sync + 'static,
193
    CoarseTimeR: Clone + Send + Sync + 'static,
194
    TcpR: Clone + Send + Sync + 'static,
195
    UnixR: Clone + Send + Sync + 'static,
196
    TlsR: Clone + Send + Sync + 'static,
197
    UdpR: Clone + Send + Sync + 'static,
198
{
199
    #[inline]
200
13703
    fn now_coarse(&self) -> CoarseInstant {
201
13703
        self.inner.coarse_time.now_coarse()
202
13703
    }
203
}
204

            
205
#[async_trait]
206
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<net::SocketAddr>
207
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
208
where
209
    TcpR: NetStreamProvider<net::SocketAddr>,
210
    TaskR: Send + Sync + 'static,
211
    SleepR: Send + Sync + 'static,
212
    CoarseTimeR: Send + Sync + 'static,
213
    TcpR: Send + Sync + 'static,
214
    UnixR: Clone + Send + Sync + 'static,
215
    TlsR: Send + Sync + 'static,
216
    UdpR: Send + Sync + 'static,
217
{
218
    type Stream = TcpR::Stream;
219

            
220
    type Listener = TcpR::Listener;
221

            
222
    #[inline]
223
52
    async fn connect(&self, addr: &net::SocketAddr) -> IoResult<Self::Stream> {
224
52
        self.inner.tcp.connect(addr).await
225
104
    }
226

            
227
    #[inline]
228
12
    async fn listen(&self, addr: &net::SocketAddr) -> IoResult<Self::Listener> {
229
12
        self.inner.tcp.listen(addr).await
230
24
    }
231
}
232

            
233
#[async_trait]
234
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<unix::SocketAddr>
235
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
236
where
237
    UnixR: NetStreamProvider<unix::SocketAddr>,
238
    TaskR: Send + Sync + 'static,
239
    SleepR: Send + Sync + 'static,
240
    CoarseTimeR: Send + Sync + 'static,
241
    TcpR: Send + Sync + 'static,
242
    UnixR: Clone + Send + Sync + 'static,
243
    TlsR: Send + Sync + 'static,
244
    UdpR: Send + Sync + 'static,
245
{
246
    type Stream = UnixR::Stream;
247

            
248
    type Listener = UnixR::Listener;
249

            
250
    #[inline]
251
    async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
252
        self.inner.unix.connect(addr).await
253
    }
254

            
255
    #[inline]
256
    async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
257
        self.inner.unix.listen(addr).await
258
    }
259
}
260

            
261
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR, S> TlsProvider<S>
262
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
263
where
264
    TcpR: NetStreamProvider,
265
    TlsR: TlsProvider<S>,
266
    UnixR: Clone + Send + Sync + 'static,
267
    SleepR: Clone + Send + Sync + 'static,
268
    CoarseTimeR: Clone + Send + Sync + 'static,
269
    TaskR: Clone + Send + Sync + 'static,
270
    UdpR: Clone + Send + Sync + 'static,
271
    S: StreamOps,
272
{
273
    type Connector = TlsR::Connector;
274
    type TlsStream = TlsR::TlsStream;
275

            
276
    #[inline]
277
22
    fn tls_connector(&self) -> Self::Connector {
278
22
        self.inner.tls.tls_connector()
279
22
    }
280

            
281
    #[inline]
282
    fn supports_keying_material_export(&self) -> bool {
283
        self.inner.tls.supports_keying_material_export()
284
    }
285
}
286

            
287
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> std::fmt::Debug
288
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
289
{
290
2
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291
2
        f.debug_struct("CompoundRuntime").finish_non_exhaustive()
292
2
    }
293
}
294

            
295
#[async_trait]
296
impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> UdpProvider
297
    for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
298
where
299
    UdpR: UdpProvider,
300
    TaskR: Send + Sync + 'static,
301
    SleepR: Send + Sync + 'static,
302
    CoarseTimeR: Send + Sync + 'static,
303
    TcpR: Send + Sync + 'static,
304
    UnixR: Clone + Send + Sync + 'static,
305
    TlsR: Send + Sync + 'static,
306
    UdpR: Send + Sync + 'static,
307
{
308
    type UdpSocket = UdpR::UdpSocket;
309

            
310
    #[inline]
311
12
    async fn bind(&self, addr: &net::SocketAddr) -> IoResult<Self::UdpSocket> {
312
12
        self.inner.udp.bind(addr).await
313
24
    }
314
}

/// Private module holding the supertrait that seals [`RuntimeSubstExt`].
///
/// Since this module is private, `Sealed` cannot be named (and therefore not
/// implemented) outside this crate.
mod sealed {
    // `Sealed` has to be `pub` so it can be a supertrait of a public trait,
    // even though it is intentionally unreachable from outside the crate.
    #![allow(unreachable_pub)]

    /// Marker trait used to seal `RuntimeSubstExt`.
    pub trait Sealed {}
}
322
/// Extension trait on Runtime:
323
/// Construct new Runtimes that replace part of an original runtime.
324
///
325
/// (If you need to do more complicated versions of this, you should likely construct
326
/// CompoundRuntime directly.)
327
pub trait RuntimeSubstExt: sealed::Sealed + Sized {
328
    /// Return a new runtime wrapping this runtime, but replacing its TCP NetStreamProvider.
329
    fn with_tcp_provider<T>(
330
        &self,
331
        new_tcp: T,
332
    ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self>;
333
    /// Return a new runtime wrapping this runtime, but replacing its SleepProvider.
334
    fn with_sleep_provider<T>(
335
        &self,
336
        new_sleep: T,
337
    ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self>;
338
    /// Return a new runtime wrapping this runtime, but replacing its CoarseTimeProvider.
339
    fn with_coarse_time_provider<T>(
340
        &self,
341
        new_coarse_time: T,
342
    ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self>;
343
}
344
impl<R: Runtime> sealed::Sealed for R {}
345
impl<R: Runtime + Sized> RuntimeSubstExt for R {
346
    fn with_tcp_provider<T>(
347
        &self,
348
        new_tcp: T,
349
    ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self> {
350
        CompoundRuntime::new(
351
            self.clone(),
352
            self.clone(),
353
            self.clone(),
354
            new_tcp,
355
            self.clone(),
356
            self.clone(),
357
            self.clone(),
358
        )
359
    }
360

            
361
12
    fn with_sleep_provider<T>(
362
12
        &self,
363
12
        new_sleep: T,
364
12
    ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self> {
365
12
        CompoundRuntime::new(
366
12
            self.clone(),
367
12
            new_sleep,
368
12
            self.clone(),
369
12
            self.clone(),
370
12
            self.clone(),
371
12
            self.clone(),
372
12
            self.clone(),
373
12
        )
374
12
    }
375

            
376
12
    fn with_coarse_time_provider<T>(
377
12
        &self,
378
12
        new_coarse_time: T,
379
12
    ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self> {
380
12
        CompoundRuntime::new(
381
12
            self.clone(),
382
12
            self.clone(),
383
12
            new_coarse_time,
384
12
            self.clone(),
385
12
            self.clone(),
386
12
            self.clone(),
387
12
            self.clone(),
388
12
        )
389
12
    }
390
}