1use std::{net, sync::Arc, time::Duration};
5
6use crate::traits::*;
7use crate::{CoarseInstant, CoarseTimeProvider};
8use async_trait::async_trait;
9use educe::Educe;
10use futures::{future::FutureObj, task::Spawn};
11use std::future::Future;
12use std::io::Result as IoResult;
13use std::time::{Instant, SystemTime};
14use tor_general_addr::unix;
15
16#[derive(Educe)]
30#[educe(Clone)] pub struct CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
32 inner: Arc<Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>>,
37}
38
/// The individual components held by a `CompoundRuntime`, one per capability.
struct Inner<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> {
    /// Used for spawning tasks, blocking work, and top-level `block_on`.
    spawn: TaskR,
    /// Used for sleeping and for reading `Instant`/`SystemTime` clocks.
    sleep: SleepR,
    /// Used for reading the coarse monotonic clock.
    coarse_time: CoarseTimeR,
    /// Used for TCP (`net::SocketAddr`) stream connections and listeners.
    tcp: TcpR,
    /// Used for Unix-domain (`unix::SocketAddr`) stream connections and listeners.
    unix: UnixR,
    /// Used for building TLS connectors.
    tls: TlsR,
    /// Used for binding UDP sockets.
    udp: UdpR,
}
56
57impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
58 CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
59{
60 pub fn new(
62 spawn: TaskR,
63 sleep: SleepR,
64 coarse_time: CoarseTimeR,
65 tcp: TcpR,
66 unix: UnixR,
67 tls: TlsR,
68 udp: UdpR,
69 ) -> Self {
70 #[allow(clippy::arc_with_non_send_sync)]
71 CompoundRuntime {
72 inner: Arc::new(Inner {
73 spawn,
74 sleep,
75 coarse_time,
76 tcp,
77 unix,
78 tls,
79 udp,
80 }),
81 }
82 }
83}
84
85impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Spawn
86 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
87where
88 TaskR: Spawn,
89{
90 #[inline]
91 #[track_caller]
92 fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), futures::task::SpawnError> {
93 self.inner.spawn.spawn_obj(future)
94 }
95}
96
97impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> Blocking
98 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
99where
100 TaskR: Blocking,
101 SleepR: Clone + Send + Sync + 'static,
102 CoarseTimeR: Clone + Send + Sync + 'static,
103 TcpR: Clone + Send + Sync + 'static,
104 UnixR: Clone + Send + Sync + 'static,
105 TlsR: Clone + Send + Sync + 'static,
106 UdpR: Clone + Send + Sync + 'static,
107{
108 type ThreadHandle<T: Send + 'static> = TaskR::ThreadHandle<T>;
109
110 #[inline]
111 #[track_caller]
112 fn spawn_blocking<F, T>(&self, f: F) -> TaskR::ThreadHandle<T>
113 where
114 F: FnOnce() -> T + Send + 'static,
115 T: Send + 'static,
116 {
117 self.inner.spawn.spawn_blocking(f)
118 }
119
120 #[inline]
121 #[track_caller]
122 fn reenter_block_on<F>(&self, future: F) -> F::Output
123 where
124 F: Future,
125 F::Output: Send + 'static,
126 {
127 self.inner.spawn.reenter_block_on(future)
128 }
129
130 #[track_caller]
131 fn blocking_io<F, T>(&self, f: F) -> impl futures::Future<Output = T>
132 where
133 F: FnOnce() -> T + Send + 'static,
134 T: Send + 'static,
135 {
136 self.inner.spawn.blocking_io(f)
137 }
138}
139
140impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> ToplevelBlockOn
141 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
142where
143 TaskR: ToplevelBlockOn,
144 SleepR: Clone + Send + Sync + 'static,
145 CoarseTimeR: Clone + Send + Sync + 'static,
146 TcpR: Clone + Send + Sync + 'static,
147 UnixR: Clone + Send + Sync + 'static,
148 TlsR: Clone + Send + Sync + 'static,
149 UdpR: Clone + Send + Sync + 'static,
150{
151 #[inline]
152 #[track_caller]
153 fn block_on<F: futures::Future>(&self, future: F) -> F::Output {
154 self.inner.spawn.block_on(future)
155 }
156}
157
158impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> SleepProvider
159 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
160where
161 SleepR: SleepProvider,
162 TaskR: Clone + Send + Sync + 'static,
163 CoarseTimeR: Clone + Send + Sync + 'static,
164 TcpR: Clone + Send + Sync + 'static,
165 UnixR: Clone + Send + Sync + 'static,
166 TlsR: Clone + Send + Sync + 'static,
167 UdpR: Clone + Send + Sync + 'static,
168{
169 type SleepFuture = SleepR::SleepFuture;
170
171 #[inline]
172 fn sleep(&self, duration: Duration) -> Self::SleepFuture {
173 self.inner.sleep.sleep(duration)
174 }
175
176 #[inline]
177 fn now(&self) -> Instant {
178 self.inner.sleep.now()
179 }
180
181 #[inline]
182 fn wallclock(&self) -> SystemTime {
183 self.inner.sleep.wallclock()
184 }
185}
186
187impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> CoarseTimeProvider
188 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
189where
190 CoarseTimeR: CoarseTimeProvider,
191 SleepR: Clone + Send + Sync + 'static,
192 TaskR: Clone + Send + Sync + 'static,
193 CoarseTimeR: Clone + Send + Sync + 'static,
194 TcpR: Clone + Send + Sync + 'static,
195 UnixR: Clone + Send + Sync + 'static,
196 TlsR: Clone + Send + Sync + 'static,
197 UdpR: Clone + Send + Sync + 'static,
198{
199 #[inline]
200 fn now_coarse(&self) -> CoarseInstant {
201 self.inner.coarse_time.now_coarse()
202 }
203}
204
205#[async_trait]
206impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<net::SocketAddr>
207 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
208where
209 TcpR: NetStreamProvider<net::SocketAddr>,
210 TaskR: Send + Sync + 'static,
211 SleepR: Send + Sync + 'static,
212 CoarseTimeR: Send + Sync + 'static,
213 TcpR: Send + Sync + 'static,
214 UnixR: Clone + Send + Sync + 'static,
215 TlsR: Send + Sync + 'static,
216 UdpR: Send + Sync + 'static,
217{
218 type Stream = TcpR::Stream;
219
220 type Listener = TcpR::Listener;
221
222 #[inline]
223 async fn connect(&self, addr: &net::SocketAddr) -> IoResult<Self::Stream> {
224 self.inner.tcp.connect(addr).await
225 }
226
227 #[inline]
228 async fn listen(&self, addr: &net::SocketAddr) -> IoResult<Self::Listener> {
229 self.inner.tcp.listen(addr).await
230 }
231}
232
233#[async_trait]
234impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> NetStreamProvider<unix::SocketAddr>
235 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
236where
237 UnixR: NetStreamProvider<unix::SocketAddr>,
238 TaskR: Send + Sync + 'static,
239 SleepR: Send + Sync + 'static,
240 CoarseTimeR: Send + Sync + 'static,
241 TcpR: Send + Sync + 'static,
242 UnixR: Clone + Send + Sync + 'static,
243 TlsR: Send + Sync + 'static,
244 UdpR: Send + Sync + 'static,
245{
246 type Stream = UnixR::Stream;
247
248 type Listener = UnixR::Listener;
249
250 #[inline]
251 async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
252 self.inner.unix.connect(addr).await
253 }
254
255 #[inline]
256 async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
257 self.inner.unix.listen(addr).await
258 }
259}
260
261impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR, S> TlsProvider<S>
262 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
263where
264 TcpR: NetStreamProvider,
265 TlsR: TlsProvider<S>,
266 UnixR: Clone + Send + Sync + 'static,
267 SleepR: Clone + Send + Sync + 'static,
268 CoarseTimeR: Clone + Send + Sync + 'static,
269 TaskR: Clone + Send + Sync + 'static,
270 UdpR: Clone + Send + Sync + 'static,
271 S: StreamOps,
272{
273 type Connector = TlsR::Connector;
274 type TlsStream = TlsR::TlsStream;
275
276 #[inline]
277 fn tls_connector(&self) -> Self::Connector {
278 self.inner.tls.tls_connector()
279 }
280
281 #[inline]
282 fn supports_keying_material_export(&self) -> bool {
283 self.inner.tls.supports_keying_material_export()
284 }
285}
286
287impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> std::fmt::Debug
288 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
289{
290 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291 f.debug_struct("CompoundRuntime").finish_non_exhaustive()
292 }
293}
294
295#[async_trait]
296impl<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR> UdpProvider
297 for CompoundRuntime<TaskR, SleepR, CoarseTimeR, TcpR, UnixR, TlsR, UdpR>
298where
299 UdpR: UdpProvider,
300 TaskR: Send + Sync + 'static,
301 SleepR: Send + Sync + 'static,
302 CoarseTimeR: Send + Sync + 'static,
303 TcpR: Send + Sync + 'static,
304 UnixR: Clone + Send + Sync + 'static,
305 TlsR: Send + Sync + 'static,
306 UdpR: Send + Sync + 'static,
307{
308 type UdpSocket = UdpR::UdpSocket;
309
310 #[inline]
311 async fn bind(&self, addr: &net::SocketAddr) -> IoResult<Self::UdpSocket> {
312 self.inner.udp.bind(addr).await
313 }
314}
315
/// Private home of the marker trait that seals `RuntimeSubstExt`.
mod sealed {
    /// Marker supertrait. Because the enclosing module is private, code
    /// outside the parent module cannot name it and so cannot implement any
    /// trait that has it as a supertrait.
    // `pub` lets this trait serve as a supertrait of a public trait; the
    // private module is what actually keeps it unreachable, so the
    // `unreachable_pub` lint is expected here.
    #[allow(unreachable_pub)]
    pub trait Sealed {}
}
322pub trait RuntimeSubstExt: sealed::Sealed + Sized {
328 fn with_tcp_provider<T>(
330 &self,
331 new_tcp: T,
332 ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self>;
333 fn with_sleep_provider<T>(
335 &self,
336 new_sleep: T,
337 ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self>;
338 fn with_coarse_time_provider<T>(
340 &self,
341 new_coarse_time: T,
342 ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self>;
343}
344impl<R: Runtime> sealed::Sealed for R {}
345impl<R: Runtime + Sized> RuntimeSubstExt for R {
346 fn with_tcp_provider<T>(
347 &self,
348 new_tcp: T,
349 ) -> CompoundRuntime<Self, Self, Self, T, Self, Self, Self> {
350 CompoundRuntime::new(
351 self.clone(),
352 self.clone(),
353 self.clone(),
354 new_tcp,
355 self.clone(),
356 self.clone(),
357 self.clone(),
358 )
359 }
360
361 fn with_sleep_provider<T>(
362 &self,
363 new_sleep: T,
364 ) -> CompoundRuntime<Self, T, Self, Self, Self, Self, Self> {
365 CompoundRuntime::new(
366 self.clone(),
367 new_sleep,
368 self.clone(),
369 self.clone(),
370 self.clone(),
371 self.clone(),
372 self.clone(),
373 )
374 }
375
376 fn with_coarse_time_provider<T>(
377 &self,
378 new_coarse_time: T,
379 ) -> CompoundRuntime<Self, Self, T, Self, Self, Self, Self> {
380 CompoundRuntime::new(
381 self.clone(),
382 self.clone(),
383 new_coarse_time,
384 self.clone(),
385 self.clone(),
386 self.clone(),
387 self.clone(),
388 )
389 }
390}