1#![forbid(unsafe_code)] use super::io::{stream_pair, LocalStream};
11use super::MockNetRuntime;
12use crate::util::mpsc_channel;
13use core::fmt;
14use tor_rtcompat::tls::TlsConnector;
15use tor_rtcompat::{
16 CertifiedConn, NetStreamListener, NetStreamProvider, Runtime, StreamOps, TlsProvider,
17};
18use tor_rtcompat::{UdpProvider, UdpSocket};
19
20use async_trait::async_trait;
21use futures::channel::mpsc;
22use futures::io::{AsyncRead, AsyncWrite};
23use futures::lock::Mutex as AsyncMutex;
24use futures::sink::SinkExt;
25use futures::stream::{Stream, StreamExt};
26use futures::FutureExt;
27use std::collections::HashMap;
28use std::fmt::Formatter;
29use std::io::{self, Error as IoError, ErrorKind, Result as IoResult};
30use std::net::{IpAddr, SocketAddr};
31use std::pin::Pin;
32use std::sync::atomic::{AtomicU16, Ordering};
33use std::sync::{Arc, Mutex};
34use std::task::{Context, Poll};
35use thiserror::Error;
36use void::Void;
37
38type ConnSender = mpsc::Sender<(LocalStream, SocketAddr)>;
41type ConnReceiver = mpsc::Receiver<(LocalStream, SocketAddr)>;
43
44#[derive(Default)]
51pub struct MockNetwork {
52 listening: Mutex<HashMap<SocketAddr, AddrBehavior>>,
54}
55
56#[derive(Clone)]
58struct ListenerEntry {
59 send: ConnSender,
62
63 tls_cert: Option<Vec<u8>>,
66}
67
68#[derive(Clone)]
70enum AddrBehavior {
71 Listener(ListenerEntry),
73 Timeout,
75}
76
77#[derive(Clone)]
115pub struct MockNetProvider {
116 inner: Arc<MockNetProviderInner>,
121}
122
123impl fmt::Debug for MockNetProvider {
124 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
125 f.debug_struct("MockNetProvider").finish_non_exhaustive()
126 }
127}
128
129struct MockNetProviderInner {
134 addrs: Vec<IpAddr>,
136 net: Arc<MockNetwork>,
138 next_port: AtomicU16,
143}
144
145pub struct MockNetListener {
149 addr: SocketAddr,
151 receiver: AsyncMutex<ConnReceiver>,
155}
156
157pub struct ProviderBuilder {
161 addrs: Vec<IpAddr>,
163 net: Arc<MockNetwork>,
165}
166
167impl Default for MockNetProvider {
168 fn default() -> Self {
169 Arc::new(MockNetwork::default()).builder().provider()
170 }
171}
172
173impl MockNetwork {
174 pub fn new() -> Arc<Self> {
176 Default::default()
177 }
178
179 pub fn builder(self: &Arc<Self>) -> ProviderBuilder {
192 ProviderBuilder {
193 addrs: vec![],
194 net: Arc::clone(self),
195 }
196 }
197
198 pub fn add_blackhole(&self, address: SocketAddr) -> IoResult<()> {
200 let mut listener_map = self.listening.lock().expect("Poisoned lock for listener");
201 if listener_map.contains_key(&address) {
202 return Err(err(ErrorKind::AddrInUse));
203 }
204 listener_map.insert(address, AddrBehavior::Timeout);
205 Ok(())
206 }
207
208 async fn send_connection(
217 &self,
218 source_addr: SocketAddr,
219 target_addr: SocketAddr,
220 peer_stream: LocalStream,
221 ) -> IoResult<Option<Vec<u8>>> {
222 let entry = {
223 let listener_map = self.listening.lock().expect("Poisoned lock for listener");
224 listener_map.get(&target_addr).cloned()
225 };
226 match entry {
227 Some(AddrBehavior::Listener(mut entry)) => {
228 if entry.send.send((peer_stream, source_addr)).await.is_ok() {
229 return Ok(entry.tls_cert);
230 }
231 Err(err(ErrorKind::ConnectionRefused))
232 }
233 Some(AddrBehavior::Timeout) => futures::future::pending().await,
234 None => Err(err(ErrorKind::ConnectionRefused)),
235 }
236 }
237
238 fn add_listener(&self, addr: SocketAddr, tls_cert: Option<Vec<u8>>) -> IoResult<ConnReceiver> {
246 let mut listener_map = self.listening.lock().expect("Poisoned lock for listener");
247 if listener_map.contains_key(&addr) {
248 return Err(err(ErrorKind::AddrInUse));
250 }
251
252 let (send, recv) = mpsc_channel(16);
253
254 let entry = ListenerEntry { send, tls_cert };
255
256 listener_map.insert(addr, AddrBehavior::Listener(entry));
257
258 Ok(recv)
259 }
260}
261
262impl ProviderBuilder {
263 pub fn add_address(&mut self, addr: IpAddr) -> &mut Self {
265 self.addrs.push(addr);
266 self
267 }
268 pub fn runtime<R: Runtime>(&self, runtime: R) -> super::MockNetRuntime<R> {
271 MockNetRuntime::new(runtime, self.provider())
272 }
273 pub fn provider(&self) -> MockNetProvider {
275 let inner = MockNetProviderInner {
276 addrs: self.addrs.clone(),
277 net: Arc::clone(&self.net),
278 next_port: AtomicU16::new(1),
279 };
280 MockNetProvider {
281 inner: Arc::new(inner),
282 }
283 }
284}
285
286impl NetStreamListener for MockNetListener {
287 type Stream = LocalStream;
288
289 type Incoming = Self;
290
291 fn local_addr(&self) -> IoResult<SocketAddr> {
292 Ok(self.addr)
293 }
294
295 fn incoming(self) -> Self {
296 self
297 }
298}
299
300impl Stream for MockNetListener {
301 type Item = IoResult<(LocalStream, SocketAddr)>;
302 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
303 let mut recv = futures::ready!(self.receiver.lock().poll_unpin(cx));
304 match recv.poll_next_unpin(cx) {
305 Poll::Pending => Poll::Pending,
306 Poll::Ready(None) => Poll::Ready(None),
307 Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
308 }
309 }
310}
311
312#[derive(Debug)]
314#[non_exhaustive]
315pub struct MockUdpSocket {
316 void: Void,
321}
322
323#[async_trait]
324impl UdpProvider for MockNetProvider {
325 type UdpSocket = MockUdpSocket;
326
327 async fn bind(&self, addr: &SocketAddr) -> IoResult<MockUdpSocket> {
328 let _ = addr; Err(io::ErrorKind::Unsupported.into())
330 }
331}
332
333#[allow(clippy::diverging_sub_expression)] #[async_trait]
335impl UdpSocket for MockUdpSocket {
336 async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
337 void::unreachable((self.void, buf).0)
341 }
342 async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
343 void::unreachable((self.void, buf, target).0)
344 }
345 fn local_addr(&self) -> IoResult<SocketAddr> {
346 void::unreachable(self.void)
347 }
348}
349
350impl MockNetProvider {
351 fn get_addr_in_family(&self, other: &IpAddr) -> Option<IpAddr> {
354 self.inner
355 .addrs
356 .iter()
357 .find(|a| a.is_ipv4() == other.is_ipv4())
358 .copied()
359 }
360
361 fn arbitrary_port(&self) -> u16 {
369 let next = self.inner.next_port.fetch_add(1, Ordering::Relaxed);
370 assert!(next != 0);
371 next
372 }
373
374 fn get_origin_addr_for(&self, addr: &SocketAddr) -> IoResult<SocketAddr> {
381 let my_addr = self
382 .get_addr_in_family(&addr.ip())
383 .ok_or_else(|| err(ErrorKind::AddrNotAvailable))?;
384 Ok(SocketAddr::new(my_addr, self.arbitrary_port()))
385 }
386
387 fn get_listener_addr(&self, spec: &SocketAddr) -> IoResult<SocketAddr> {
397 let ipaddr = {
398 let ip = spec.ip();
399 if ip.is_unspecified() {
400 self.get_addr_in_family(&ip)
401 .ok_or_else(|| err(ErrorKind::AddrNotAvailable))?
402 } else if self.inner.addrs.iter().any(|a| a == &ip) {
403 ip
404 } else {
405 return Err(err(ErrorKind::AddrNotAvailable));
406 }
407 };
408 let port = {
409 if spec.port() == 0 {
410 self.arbitrary_port()
411 } else {
412 spec.port()
413 }
414 };
415
416 Ok(SocketAddr::new(ipaddr, port))
417 }
418
419 pub fn listen_tls(&self, addr: &SocketAddr, tls_cert: Vec<u8>) -> IoResult<MockNetListener> {
425 let addr = self.get_listener_addr(addr)?;
426
427 let receiver = AsyncMutex::new(self.inner.net.add_listener(addr, Some(tls_cert))?);
428
429 Ok(MockNetListener { addr, receiver })
430 }
431}
432
433#[async_trait]
434impl NetStreamProvider for MockNetProvider {
435 type Stream = LocalStream;
436 type Listener = MockNetListener;
437
438 async fn connect(&self, addr: &SocketAddr) -> IoResult<LocalStream> {
439 let my_addr = self.get_origin_addr_for(addr)?;
440 let (mut mine, theirs) = stream_pair();
441
442 let cert = self
443 .inner
444 .net
445 .send_connection(my_addr, *addr, theirs)
446 .await?;
447
448 mine.tls_cert = cert;
449
450 Ok(mine)
451 }
452
453 async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::Listener> {
454 let addr = self.get_listener_addr(addr)?;
455
456 let receiver = AsyncMutex::new(self.inner.net.add_listener(addr, None)?);
457
458 Ok(MockNetListener { addr, receiver })
459 }
460}
461
462#[async_trait]
463impl TlsProvider<LocalStream> for MockNetProvider {
464 type Connector = MockTlsConnector;
465 type TlsStream = MockTlsStream;
466
467 fn tls_connector(&self) -> MockTlsConnector {
468 MockTlsConnector {}
469 }
470
471 fn supports_keying_material_export(&self) -> bool {
472 false
473 }
474}
475
/// A fake TLS connector for the mock network.
///
/// It performs no handshake. It only checks that the stream came from a
/// listener created with [`MockNetProvider::listen_tls`].
#[derive(Clone)]
#[non_exhaustive]
pub struct MockTlsConnector;
483
484pub struct MockTlsStream {
493 peer_cert: Option<Vec<u8>>,
495 stream: LocalStream,
497}
498
499#[async_trait]
500impl TlsConnector<LocalStream> for MockTlsConnector {
501 type Conn = MockTlsStream;
502
503 async fn negotiate_unvalidated(
504 &self,
505 mut stream: LocalStream,
506 _sni_hostname: &str,
507 ) -> IoResult<MockTlsStream> {
508 let peer_cert = stream.tls_cert.take();
509
510 if peer_cert.is_none() {
511 return Err(std::io::Error::new(
512 std::io::ErrorKind::Other,
513 "attempted to wrap non-TLS stream!",
514 ));
515 }
516
517 Ok(MockTlsStream { peer_cert, stream })
518 }
519}
520
521impl CertifiedConn for MockTlsStream {
522 fn peer_certificate(&self) -> IoResult<Option<Vec<u8>>> {
523 Ok(self.peer_cert.clone())
524 }
525 fn export_keying_material(
526 &self,
527 _len: usize,
528 _label: &[u8],
529 _context: Option<&[u8]>,
530 ) -> IoResult<Vec<u8>> {
531 Ok(Vec::new())
532 }
533}
534
535impl AsyncRead for MockTlsStream {
536 fn poll_read(
537 mut self: Pin<&mut Self>,
538 cx: &mut Context<'_>,
539 buf: &mut [u8],
540 ) -> Poll<IoResult<usize>> {
541 Pin::new(&mut self.stream).poll_read(cx, buf)
542 }
543}
544impl AsyncWrite for MockTlsStream {
545 fn poll_write(
546 mut self: Pin<&mut Self>,
547 cx: &mut Context<'_>,
548 buf: &[u8],
549 ) -> Poll<IoResult<usize>> {
550 Pin::new(&mut self.stream).poll_write(cx, buf)
551 }
552 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
553 Pin::new(&mut self.stream).poll_flush(cx)
554 }
555 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
556 Pin::new(&mut self.stream).poll_close(cx)
557 }
558}
559
560impl StreamOps for MockTlsStream {
561 fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
562 Err(std::io::Error::new(
563 std::io::ErrorKind::Unsupported,
564 "not supported on non-StreamOps stream!",
565 ))
566 }
567
568 fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
569 Box::new(tor_rtcompat::NoOpStreamOpsHandle::default())
570 }
571}
572
573#[derive(Clone, Error, Debug)]
575#[non_exhaustive]
576pub enum MockNetError {
577 #[error("Invalid operation on mock network")]
579 BadOp,
580}
581
582fn err(k: ErrorKind) -> IoError {
584 IoError::new(k, MockNetError::BadOp)
585}
586
#[cfg(all(test, not(miri)))]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use tor_rtcompat::test_with_all_runtimes;

    /// Return two providers on one shared network: the first at
    /// 192.0.2.55, the second at 198.51.100.7.
    fn client_pair() -> (MockNetProvider, MockNetProvider) {
        let net = MockNetwork::new();
        let first = net
            .builder()
            .add_address("192.0.2.55".parse().unwrap())
            .provider();
        let second = net
            .builder()
            .add_address("198.51.100.7".parse().unwrap())
            .provider();

        (first, second)
    }

    #[test]
    fn end_to_end() {
        test_with_all_runtimes!(|_rt| async {
            let (client1, client2) = client_pair();
            let listener = client2.listen(&"0.0.0.0:99".parse().unwrap()).await?;
            let address = listener.local_addr()?;

            let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
                async {
                    let mut conn = client1.connect(&address).await?;
                    conn.write_all(b"This is totally a network.").await?;
                    conn.close().await?;

                    // Nothing is bound at this address, so connecting must fail.
                    let unbound = "192.0.2.200:99".parse().unwrap();
                    let attempt = client1.connect(&unbound).await;
                    assert!(attempt.is_err());
                    Ok(())
                },
                async {
                    let (mut conn, peer) = listener.incoming().next().await.expect("closed?")?;
                    assert_eq!(peer.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                    let mut received = Vec::new();
                    conn.read_to_end(&mut received).await?;
                    assert_eq!(&received[..], &b"This is totally a network."[..]);
                    Ok(())
                }
            );
            r1?;
            r2?;
            IoResult::Ok(())
        });
    }

    #[test]
    fn pick_listener_addr() -> IoResult<()> {
        let net = MockNetwork::new();
        let ip4 = "192.0.2.55".parse().unwrap();
        let ip6 = "2001:db8::7".parse().unwrap();
        let client = net.builder().add_address(ip4).add_address(ip6).provider();

        // Unspecified IPv4 address, explicit port.
        let a1 = client.get_listener_addr(&"0.0.0.0:99".parse().unwrap())?;
        assert_eq!(a1.ip(), ip4);
        assert_eq!(a1.port(), 99);
        // Our own IPv4 address, explicit port.
        let a2 = client.get_listener_addr(&"192.0.2.55:100".parse().unwrap())?;
        assert_eq!(a2.ip(), ip4);
        assert_eq!(a2.port(), 100);
        // Our own IPv4 address, arbitrary port.
        let a3 = client.get_listener_addr(&"192.0.2.55:0".parse().unwrap())?;
        assert_eq!(a3.ip(), ip4);
        assert!(a3.port() != 0);
        // Unspecified address and arbitrary port: the port must be fresh.
        let a4 = client.get_listener_addr(&"0.0.0.0:0".parse().unwrap())?;
        assert_eq!(a4.ip(), ip4);
        assert!(a4.port() != 0);
        assert!(a4.port() != a3.port());
        // IPv6 equivalents.
        let a5 = client.get_listener_addr(&"[::]:99".parse().unwrap())?;
        assert_eq!(a5.ip(), ip6);
        assert_eq!(a5.port(), 99);
        let a6 = client.get_listener_addr(&"[2001:db8::7]:100".parse().unwrap())?;
        assert_eq!(a6.ip(), ip6);
        assert_eq!(a6.port(), 100);

        // Addresses that do not belong to this provider are rejected.
        let e1 = client.get_listener_addr(&"192.0.2.56:0".parse().unwrap());
        let e2 = client.get_listener_addr(&"[2001:db8::8]:0".parse().unwrap());
        assert!(e1.is_err());
        assert!(e2.is_err());

        IoResult::Ok(())
    }

    #[test]
    fn listener_stream() {
        test_with_all_runtimes!(|_rt| async {
            let (client1, client2) = client_pair();

            let listener = client2.listen(&"0.0.0.0:99".parse().unwrap()).await?;
            let address = listener.local_addr()?;
            let mut incoming = listener.incoming();

            let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
                async {
                    for _ in 0..3_u8 {
                        let mut conn = client1.connect(&address).await?;
                        conn.close().await?;
                    }
                    Ok(())
                },
                async {
                    for _ in 0..3_u8 {
                        let (mut conn, peer) = incoming.next().await.unwrap()?;
                        let mut sink = Vec::new();
                        let _ = conn.read_to_end(&mut sink).await?;
                        assert_eq!(peer.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                    }
                    Ok(())
                }
            );
            r1?;
            r2?;
            IoResult::Ok(())
        });
    }

    #[test]
    fn tls_basics() {
        let (client1, client2) = client_pair();
        let cert = b"I am certified for something I assure you.";

        test_with_all_runtimes!(|_rt| async {
            let listener = client2
                .listen_tls(&"0.0.0.0:0".parse().unwrap(), cert[..].into())
                .unwrap();
            let address = listener.local_addr().unwrap();

            let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
                async {
                    let connector = client1.tls_connector();
                    let plain = client1.connect(&address).await?;
                    let mut conn = connector
                        .negotiate_unvalidated(plain, "zombo.example.com")
                        .await?;
                    assert_eq!(&conn.peer_certificate()?.unwrap()[..], &cert[..]);
                    conn.write_all(b"This is totally encrypted.").await?;
                    let mut reply = Vec::new();
                    conn.read_to_end(&mut reply).await?;
                    conn.close().await?;
                    assert_eq!(reply[..], b"Yup, your secrets is safe"[..]);
                    Ok(())
                },
                async {
                    let (mut conn, peer) = listener.incoming().next().await.expect("closed?")?;
                    assert_eq!(peer.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                    let mut request = [0_u8; 26];
                    conn.read_exact(&mut request[..]).await?;
                    assert_eq!(&request[..], &b"This is totally encrypted."[..]);
                    conn.write_all(b"Yup, your secrets is safe").await?;
                    Ok(())
                }
            );
            r1?;
            r2?;
            IoResult::Ok(())
        });
    }
}