1#![forbid(unsafe_code)] use super::io::{stream_pair, LocalStream};
11use super::MockNetRuntime;
12use crate::util::mpsc_channel;
13use core::fmt;
14use tor_rtcompat::tls::TlsConnector;
15use tor_rtcompat::{
16 CertifiedConn, NetStreamListener, NetStreamProvider, Runtime, StreamOps, TlsProvider,
17};
18use tor_rtcompat::{UdpProvider, UdpSocket};
19
20use async_trait::async_trait;
21use futures::channel::mpsc;
22use futures::io::{AsyncRead, AsyncWrite};
23use futures::lock::Mutex as AsyncMutex;
24use futures::sink::SinkExt;
25use futures::stream::{Stream, StreamExt};
26use futures::FutureExt;
27use std::collections::HashMap;
28use std::fmt::Formatter;
29use std::io::{self, Error as IoError, ErrorKind, Result as IoResult};
30use std::net::{IpAddr, SocketAddr};
31use std::pin::Pin;
32use std::sync::atomic::{AtomicU16, Ordering};
33use std::sync::{Arc, Mutex};
34use std::task::{Context, Poll};
35use thiserror::Error;
36use void::Void;
37
38type ConnSender = mpsc::Sender<(LocalStream, SocketAddr)>;
41type ConnReceiver = mpsc::Receiver<(LocalStream, SocketAddr)>;
43
44#[derive(Default)]
51pub struct MockNetwork {
52 listening: Mutex<HashMap<SocketAddr, AddrBehavior>>,
54}
55
56#[derive(Clone)]
58struct ListenerEntry {
59 send: ConnSender,
62
63 tls_cert: Option<Vec<u8>>,
66}
67
68#[derive(Clone)]
70enum AddrBehavior {
71 Listener(ListenerEntry),
73 Timeout,
75}
76
77#[derive(Clone)]
115pub struct MockNetProvider {
116 inner: Arc<MockNetProviderInner>,
121}
122
123impl fmt::Debug for MockNetProvider {
124 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
125 f.debug_struct("MockNetProvider").finish_non_exhaustive()
126 }
127}
128
129struct MockNetProviderInner {
134 addrs: Vec<IpAddr>,
136 net: Arc<MockNetwork>,
138 next_port: AtomicU16,
143}
144
145pub struct MockNetListener {
149 addr: SocketAddr,
151 receiver: AsyncMutex<ConnReceiver>,
155}
156
157pub struct ProviderBuilder {
161 addrs: Vec<IpAddr>,
163 net: Arc<MockNetwork>,
165}
166
167impl Default for MockNetProvider {
168 fn default() -> Self {
169 Arc::new(MockNetwork::default()).builder().provider()
170 }
171}
172
173impl MockNetwork {
174 pub fn new() -> Arc<Self> {
176 Default::default()
177 }
178
179 pub fn builder(self: &Arc<Self>) -> ProviderBuilder {
192 ProviderBuilder {
193 addrs: vec![],
194 net: Arc::clone(self),
195 }
196 }
197
198 pub fn add_blackhole(&self, address: SocketAddr) -> IoResult<()> {
200 let mut listener_map = self.listening.lock().expect("Poisoned lock for listener");
201 if listener_map.contains_key(&address) {
202 return Err(err(ErrorKind::AddrInUse));
203 }
204 listener_map.insert(address, AddrBehavior::Timeout);
205 Ok(())
206 }
207
208 async fn send_connection(
217 &self,
218 source_addr: SocketAddr,
219 target_addr: SocketAddr,
220 peer_stream: LocalStream,
221 ) -> IoResult<Option<Vec<u8>>> {
222 let entry = {
223 let listener_map = self.listening.lock().expect("Poisoned lock for listener");
224 listener_map.get(&target_addr).cloned()
225 };
226 match entry {
227 Some(AddrBehavior::Listener(mut entry)) => {
228 if entry.send.send((peer_stream, source_addr)).await.is_ok() {
229 return Ok(entry.tls_cert);
230 }
231 Err(err(ErrorKind::ConnectionRefused))
232 }
233 Some(AddrBehavior::Timeout) => futures::future::pending().await,
234 None => Err(err(ErrorKind::ConnectionRefused)),
235 }
236 }
237
238 fn add_listener(&self, addr: SocketAddr, tls_cert: Option<Vec<u8>>) -> IoResult<ConnReceiver> {
246 let mut listener_map = self.listening.lock().expect("Poisoned lock for listener");
247 if listener_map.contains_key(&addr) {
248 return Err(err(ErrorKind::AddrInUse));
250 }
251
252 let (send, recv) = mpsc_channel(16);
253
254 let entry = ListenerEntry { send, tls_cert };
255
256 listener_map.insert(addr, AddrBehavior::Listener(entry));
257
258 Ok(recv)
259 }
260}
261
262impl ProviderBuilder {
263 pub fn add_address(&mut self, addr: IpAddr) -> &mut Self {
265 self.addrs.push(addr);
266 self
267 }
268 pub fn runtime<R: Runtime>(&self, runtime: R) -> super::MockNetRuntime<R> {
271 MockNetRuntime::new(runtime, self.provider())
272 }
273 pub fn provider(&self) -> MockNetProvider {
275 let inner = MockNetProviderInner {
276 addrs: self.addrs.clone(),
277 net: Arc::clone(&self.net),
278 next_port: AtomicU16::new(1),
279 };
280 MockNetProvider {
281 inner: Arc::new(inner),
282 }
283 }
284}
285
286impl NetStreamListener for MockNetListener {
287 type Stream = LocalStream;
288
289 type Incoming = Self;
290
291 fn local_addr(&self) -> IoResult<SocketAddr> {
292 Ok(self.addr)
293 }
294
295 fn incoming(self) -> Self {
296 self
297 }
298}
299
300impl Stream for MockNetListener {
301 type Item = IoResult<(LocalStream, SocketAddr)>;
302 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
303 let mut recv = futures::ready!(self.receiver.lock().poll_unpin(cx));
304 match recv.poll_next_unpin(cx) {
305 Poll::Pending => Poll::Pending,
306 Poll::Ready(None) => Poll::Ready(None),
307 Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
308 }
309 }
310}
311
312#[derive(Debug)]
314#[non_exhaustive]
315pub struct MockUdpSocket {
316 void: Void,
321}
322
323#[async_trait]
324impl UdpProvider for MockNetProvider {
325 type UdpSocket = MockUdpSocket;
326
327 async fn bind(&self, addr: &SocketAddr) -> IoResult<MockUdpSocket> {
328 let _ = addr; Err(io::ErrorKind::Unsupported.into())
330 }
331}
332
333#[allow(clippy::diverging_sub_expression)] #[async_trait]
335impl UdpSocket for MockUdpSocket {
336 async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
337 void::unreachable((self.void, buf).0)
341 }
342 async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
343 void::unreachable((self.void, buf, target).0)
344 }
345 fn local_addr(&self) -> IoResult<SocketAddr> {
346 void::unreachable(self.void)
347 }
348}
349
350impl MockNetProvider {
351 fn get_addr_in_family(&self, other: &IpAddr) -> Option<IpAddr> {
354 self.inner
355 .addrs
356 .iter()
357 .find(|a| a.is_ipv4() == other.is_ipv4())
358 .copied()
359 }
360
361 fn arbitrary_port(&self) -> u16 {
369 let next = self.inner.next_port.fetch_add(1, Ordering::Relaxed);
370 assert!(next != 0);
371 next
372 }
373
374 fn get_origin_addr_for(&self, addr: &SocketAddr) -> IoResult<SocketAddr> {
381 let my_addr = self
382 .get_addr_in_family(&addr.ip())
383 .ok_or_else(|| err(ErrorKind::AddrNotAvailable))?;
384 Ok(SocketAddr::new(my_addr, self.arbitrary_port()))
385 }
386
387 fn get_listener_addr(&self, spec: &SocketAddr) -> IoResult<SocketAddr> {
397 let ipaddr = {
398 let ip = spec.ip();
399 if ip.is_unspecified() {
400 self.get_addr_in_family(&ip)
401 .ok_or_else(|| err(ErrorKind::AddrNotAvailable))?
402 } else if self.inner.addrs.iter().any(|a| a == &ip) {
403 ip
404 } else {
405 return Err(err(ErrorKind::AddrNotAvailable));
406 }
407 };
408 let port = {
409 if spec.port() == 0 {
410 self.arbitrary_port()
411 } else {
412 spec.port()
413 }
414 };
415
416 Ok(SocketAddr::new(ipaddr, port))
417 }
418
419 pub fn listen_tls(&self, addr: &SocketAddr, tls_cert: Vec<u8>) -> IoResult<MockNetListener> {
425 let addr = self.get_listener_addr(addr)?;
426
427 let receiver = AsyncMutex::new(self.inner.net.add_listener(addr, Some(tls_cert))?);
428
429 Ok(MockNetListener { addr, receiver })
430 }
431}
432
433#[async_trait]
434impl NetStreamProvider for MockNetProvider {
435 type Stream = LocalStream;
436 type Listener = MockNetListener;
437
438 async fn connect(&self, addr: &SocketAddr) -> IoResult<LocalStream> {
439 let my_addr = self.get_origin_addr_for(addr)?;
440 let (mut mine, theirs) = stream_pair();
441
442 let cert = self
443 .inner
444 .net
445 .send_connection(my_addr, *addr, theirs)
446 .await?;
447
448 mine.tls_cert = cert;
449
450 Ok(mine)
451 }
452
453 async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::Listener> {
454 let addr = self.get_listener_addr(addr)?;
455
456 let receiver = AsyncMutex::new(self.inner.net.add_listener(addr, None)?);
457
458 Ok(MockNetListener { addr, receiver })
459 }
460}
461
462#[async_trait]
463impl TlsProvider<LocalStream> for MockNetProvider {
464 type Connector = MockTlsConnector;
465 type TlsStream = MockTlsStream;
466
467 fn tls_connector(&self) -> MockTlsConnector {
468 MockTlsConnector {}
469 }
470
471 fn supports_keying_material_export(&self) -> bool {
472 false
473 }
474}
475
476#[derive(Clone)]
481#[non_exhaustive]
482pub struct MockTlsConnector;
483
484pub struct MockTlsStream {
493 peer_cert: Option<Vec<u8>>,
495 stream: LocalStream,
497}
498
499#[async_trait]
500impl TlsConnector<LocalStream> for MockTlsConnector {
501 type Conn = MockTlsStream;
502
503 async fn negotiate_unvalidated(
504 &self,
505 mut stream: LocalStream,
506 _sni_hostname: &str,
507 ) -> IoResult<MockTlsStream> {
508 let peer_cert = stream.tls_cert.take();
509
510 if peer_cert.is_none() {
511 return Err(std::io::Error::other("attempted to wrap non-TLS stream!"));
512 }
513
514 Ok(MockTlsStream { peer_cert, stream })
515 }
516}
517
518impl CertifiedConn for MockTlsStream {
519 fn peer_certificate(&self) -> IoResult<Option<Vec<u8>>> {
520 Ok(self.peer_cert.clone())
521 }
522 fn export_keying_material(
523 &self,
524 _len: usize,
525 _label: &[u8],
526 _context: Option<&[u8]>,
527 ) -> IoResult<Vec<u8>> {
528 Ok(Vec::new())
529 }
530}
531
532impl AsyncRead for MockTlsStream {
533 fn poll_read(
534 mut self: Pin<&mut Self>,
535 cx: &mut Context<'_>,
536 buf: &mut [u8],
537 ) -> Poll<IoResult<usize>> {
538 Pin::new(&mut self.stream).poll_read(cx, buf)
539 }
540}
541impl AsyncWrite for MockTlsStream {
542 fn poll_write(
543 mut self: Pin<&mut Self>,
544 cx: &mut Context<'_>,
545 buf: &[u8],
546 ) -> Poll<IoResult<usize>> {
547 Pin::new(&mut self.stream).poll_write(cx, buf)
548 }
549 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
550 Pin::new(&mut self.stream).poll_flush(cx)
551 }
552 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
553 Pin::new(&mut self.stream).poll_close(cx)
554 }
555}
556
557impl StreamOps for MockTlsStream {
558 fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
559 Err(std::io::Error::new(
560 std::io::ErrorKind::Unsupported,
561 "not supported on non-StreamOps stream!",
562 ))
563 }
564
565 fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
566 Box::new(tor_rtcompat::NoOpStreamOpsHandle::default())
567 }
568}
569
570#[derive(Clone, Error, Debug)]
572#[non_exhaustive]
573pub enum MockNetError {
574 #[error("Invalid operation on mock network")]
576 BadOp,
577}
578
579fn err(k: ErrorKind) -> IoError {
581 IoError::new(k, MockNetError::BadOp)
582}
583
// Unit tests for the mock network. Not compiled when running under Miri.
#[cfg(all(test, not(miri)))]
mod test {
    // Lints relaxed for test code.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use tor_rtcompat::test_with_all_runtimes;

    /// Build two providers on one fresh network: the first owns 192.0.2.55,
    /// the second owns 198.51.100.7.
    fn client_pair() -> (MockNetProvider, MockNetProvider) {
        let net = MockNetwork::new();
        let client1 = net
            .builder()
            .add_address("192.0.2.55".parse().unwrap())
            .provider();
        let client2 = net
            .builder()
            .add_address("198.51.100.7".parse().unwrap())
            .provider();

        (client1, client2)
    }

    /// Bytes written by a connecting client arrive intact at the listener,
    /// which also sees the client's address; connecting to an address where
    /// nothing listens fails.
    #[test]
    fn end_to_end() {
        test_with_all_runtimes!(|_rt| async {
            let (client1, client2) = client_pair();
            let lis = client2.listen(&"0.0.0.0:99".parse().unwrap()).await?;
            let address = lis.local_addr()?;

            // Run the client side and the listener side concurrently.
            let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
                async {
                    let mut conn = client1.connect(&address).await?;
                    conn.write_all(b"This is totally a network.").await?;
                    conn.close().await?;

                    // No listener was registered at this address.
                    let a2 = "192.0.2.200:99".parse().unwrap();
                    let cant_connect = client1.connect(&a2).await;
                    assert!(cant_connect.is_err());
                    Ok(())
                },
                async {
                    let (mut conn, a) = lis.incoming().next().await.expect("closed?")?;
                    assert_eq!(a.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                    let mut inp = Vec::new();
                    conn.read_to_end(&mut inp).await?;
                    assert_eq!(&inp[..], &b"This is totally a network."[..]);
                    Ok(())
                }
            );
            r1?;
            r2?;
            IoResult::Ok(())
        });
    }

    /// `get_listener_addr` maps unspecified IPs to our address in the same
    /// family, keeps explicit ports, picks distinct nonzero ports for port 0,
    /// and rejects IPs we do not own.
    #[test]
    fn pick_listener_addr() -> IoResult<()> {
        let net = MockNetwork::new();
        let ip4 = "192.0.2.55".parse().unwrap();
        let ip6 = "2001:db8::7".parse().unwrap();
        let client = net.builder().add_address(ip4).add_address(ip6).provider();

        let a1 = client.get_listener_addr(&"0.0.0.0:99".parse().unwrap())?;
        assert_eq!(a1.ip(), ip4);
        assert_eq!(a1.port(), 99);
        let a2 = client.get_listener_addr(&"192.0.2.55:100".parse().unwrap())?;
        assert_eq!(a2.ip(), ip4);
        assert_eq!(a2.port(), 100);
        let a3 = client.get_listener_addr(&"192.0.2.55:0".parse().unwrap())?;
        assert_eq!(a3.ip(), ip4);
        assert!(a3.port() != 0);
        let a4 = client.get_listener_addr(&"0.0.0.0:0".parse().unwrap())?;
        assert_eq!(a4.ip(), ip4);
        assert!(a4.port() != 0);
        assert!(a4.port() != a3.port());
        let a5 = client.get_listener_addr(&"[::]:99".parse().unwrap())?;
        assert_eq!(a5.ip(), ip6);
        assert_eq!(a5.port(), 99);
        let a6 = client.get_listener_addr(&"[2001:db8::7]:100".parse().unwrap())?;
        assert_eq!(a6.ip(), ip6);
        assert_eq!(a6.port(), 100);

        // Addresses this provider does not own must be rejected.
        let e1 = client.get_listener_addr(&"192.0.2.56:0".parse().unwrap());
        let e2 = client.get_listener_addr(&"[2001:db8::8]:0".parse().unwrap());
        assert!(e1.is_err());
        assert!(e2.is_err());

        IoResult::Ok(())
    }

    /// A listener's incoming stream yields several successive connections,
    /// each tagged with the connecting client's address.
    #[test]
    fn listener_stream() {
        test_with_all_runtimes!(|_rt| async {
            let (client1, client2) = client_pair();

            let lis = client2.listen(&"0.0.0.0:99".parse().unwrap()).await?;
            let address = lis.local_addr()?;
            let mut incoming = lis.incoming();

            let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
                async {
                    for _ in 0..3_u8 {
                        let mut c = client1.connect(&address).await?;
                        c.close().await?;
                    }
                    Ok(())
                },
                async {
                    for _ in 0..3_u8 {
                        let (mut c, a) = incoming.next().await.unwrap()?;
                        let mut v = Vec::new();
                        let _ = c.read_to_end(&mut v).await?;
                        assert_eq!(a.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                    }
                    Ok(())
                }
            );
            r1?;
            r2?;
            IoResult::Ok(())
        });
    }

    /// A client connecting to a TLS listener can wrap the stream with the mock
    /// TLS connector, sees the listener's certificate, and exchanges data in
    /// both directions.
    #[test]
    fn tls_basics() {
        // The providers are created once, outside the runtime closure.
        // Listening on port 0 picks a fresh arbitrary port on every
        // `listen_tls` call, so repeated runs of the closure do not collide.
        let (client1, client2) = client_pair();
        let cert = b"I am certified for something I assure you.";

        test_with_all_runtimes!(|_rt| async {
            let lis = client2
                .listen_tls(&"0.0.0.0:0".parse().unwrap(), cert[..].into())
                .unwrap();
            let address = lis.local_addr().unwrap();

            let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
                async {
                    let connector = client1.tls_connector();
                    let conn = client1.connect(&address).await?;
                    let mut conn = connector
                        .negotiate_unvalidated(conn, "zombo.example.com")
                        .await?;
                    assert_eq!(&conn.peer_certificate()?.unwrap()[..], &cert[..]);
                    conn.write_all(b"This is totally encrypted.").await?;
                    let mut v = Vec::new();
                    conn.read_to_end(&mut v).await?;
                    conn.close().await?;
                    assert_eq!(v[..], b"Yup, your secrets is safe"[..]);
                    Ok(())
                },
                async {
                    let (mut conn, a) = lis.incoming().next().await.expect("closed?")?;
                    assert_eq!(a.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                    // 26 bytes: the length of the client's message.
                    let mut inp = [0_u8; 26];
                    conn.read_exact(&mut inp[..]).await?;
                    assert_eq!(&inp[..], &b"This is totally encrypted."[..]);
                    conn.write_all(b"Yup, your secrets is safe").await?;
                    Ok(())
                }
            );
            r1?;
            r2?;
            IoResult::Ok(())
        });
    }
}