1
//! Implements a simple mock network for testing purposes.
2

            
3
// Note: There are lots of opportunities here for making the network
4
// more and more realistic, but please remember that this module only
5
// exists for writing unit tests.  Let's resist the temptation to add
6
// things we don't need.
7

            
8
#![forbid(unsafe_code)] // if you remove this, enable (or write) miri tests (git grep miri)
9

            
10
use super::io::{stream_pair, LocalStream};
11
use super::MockNetRuntime;
12
use crate::util::mpsc_channel;
13
use core::fmt;
14
use tor_rtcompat::tls::TlsConnector;
15
use tor_rtcompat::{
16
    CertifiedConn, NetStreamListener, NetStreamProvider, Runtime, StreamOps, TlsProvider,
17
};
18
use tor_rtcompat::{UdpProvider, UdpSocket};
19

            
20
use async_trait::async_trait;
21
use futures::channel::mpsc;
22
use futures::io::{AsyncRead, AsyncWrite};
23
use futures::lock::Mutex as AsyncMutex;
24
use futures::sink::SinkExt;
25
use futures::stream::{Stream, StreamExt};
26
use futures::FutureExt;
27
use std::collections::HashMap;
28
use std::fmt::Formatter;
29
use std::io::{self, Error as IoError, ErrorKind, Result as IoResult};
30
use std::net::{IpAddr, SocketAddr};
31
use std::pin::Pin;
32
use std::sync::atomic::{AtomicU16, Ordering};
33
use std::sync::{Arc, Mutex};
34
use std::task::{Context, Poll};
35
use thiserror::Error;
36
use void::Void;
37

            
38
/// A channel sender that we use to send incoming connections to
39
/// listeners.
40
type ConnSender = mpsc::Sender<(LocalStream, SocketAddr)>;
41
/// A channel receiver that listeners use to receive incoming connections.
42
type ConnReceiver = mpsc::Receiver<(LocalStream, SocketAddr)>;
43

            
44
/// A simulated Internet, for testing.
45
///
46
/// We simulate TCP streams only, and skip all the details. Connection
47
/// are implemented using [`LocalStream`]. The MockNetwork object is
48
/// shared by a large set of MockNetworkProviders, each of which has
49
/// its own view of its address(es) on the network.
50
#[derive(Default)]
51
pub struct MockNetwork {
52
    /// A map from address to the entries about listeners there.
53
    listening: Mutex<HashMap<SocketAddr, AddrBehavior>>,
54
}
55

            
56
/// The `MockNetwork`'s view of a listener.
57
#[derive(Clone)]
58
struct ListenerEntry {
59
    /// A sender that need to be informed about connection attempts
60
    /// there.
61
    send: ConnSender,
62

            
63
    /// A notional TLS certificate for this listener.  If absent, the
64
    /// listener isn't a TLS listener.
65
    tls_cert: Option<Vec<u8>>,
66
}
67

            
68
/// A possible non-error behavior from an address
69
#[derive(Clone)]
70
enum AddrBehavior {
71
    /// There's a listener at this address, which would like to reply.
72
    Listener(ListenerEntry),
73
    /// All connections sent to this address will time out.
74
    Timeout,
75
}
76

            
77
/// A view of a single host's access to a MockNetwork.
78
///
79
/// Each simulated host has its own addresses that it's allowed to listen on,
80
/// and a reference to the network.
81
///
82
/// This type implements [`NetStreamProvider`] for [`SocketAddr`]
83
/// so that it can be used as a
84
/// drop-in replacement for testing code that uses the network.
85
///
86
/// # Limitations
87
///
88
/// There's no randomness here, so we can't simulate the weirdness of
89
/// real networks.
90
///
91
/// So far, there's no support for DNS or UDP.
92
///
93
/// We don't handle localhost specially, and we don't simulate providers
94
/// that can connect to some addresses but not all.
95
///
96
/// We don't do the right thing (block) if there is a listener that
97
/// never calls accept.
98
///
99
/// UDP is completely broken:
100
/// datagrams appear to be transmitted, but will never be received.
101
/// And local address assignment is not implemented
102
/// so [`.local_addr()`](UdpSocket::local_addr) can return `NONE`
103
// TODO MOCK UDP: Documentation does describe the brokennesses
104
///
105
/// We use a simple `u16` counter to decide what arbitrary port
106
/// numbers to use: Once that counter is exhausted, we will fail with
107
/// an assertion.  We don't do anything to prevent those arbitrary
108
/// ports from colliding with specified ports, other than declare that
109
/// you can't have two listeners on the same addr:port at the same
110
/// time.
111
///
112
/// We pretend to provide TLS, but there's no actual encryption or
113
/// authentication.
114
#[derive(Clone)]
115
pub struct MockNetProvider {
116
    /// Actual implementation of this host's view of the network.
117
    ///
118
    /// We have to use a separate type here and reference count it,
119
    /// since the `next_port` counter needs to be shared.
120
    inner: Arc<MockNetProviderInner>,
121
}
122

            
123
impl fmt::Debug for MockNetProvider {
124
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
125
        f.debug_struct("MockNetProvider").finish_non_exhaustive()
126
    }
127
}
128

            
129
/// Shared part of a MockNetworkProvider.
130
///
131
/// This is separate because providers need to implement Clone, but
132
/// `next_port` can't be cloned.
133
struct MockNetProviderInner {
134
    /// List of public addresses
135
    addrs: Vec<IpAddr>,
136
    /// Shared reference to the network.
137
    net: Arc<MockNetwork>,
138
    /// Next port number to hand out when we're asked to listen on
139
    /// port 0.
140
    ///
141
    /// See discussion of limitations on `listen()` implementation.
142
    next_port: AtomicU16,
143
}
144

            
145
/// A [`NetStreamListener`] implementation returned by a [`MockNetProvider`].
146
///
147
/// Represents listening on a public address for incoming TCP connections.
148
pub struct MockNetListener {
149
    /// The address that we're listening on.
150
    addr: SocketAddr,
151
    /// The incoming channel that tells us about new connections.
152
    // TODO: I'm not thrilled to have to use an AsyncMutex and a
153
    // std Mutex in the same module.
154
    receiver: AsyncMutex<ConnReceiver>,
155
}
156

            
157
/// A builder object used to configure a [`MockNetProvider`]
158
///
159
/// Returned by [`MockNetwork::builder()`].
160
pub struct ProviderBuilder {
161
    /// List of public addresses.
162
    addrs: Vec<IpAddr>,
163
    /// Shared reference to the network.
164
    net: Arc<MockNetwork>,
165
}
166

            
167
impl Default for MockNetProvider {
168
44494
    fn default() -> Self {
169
44494
        Arc::new(MockNetwork::default()).builder().provider()
170
44494
    }
171
}
172

            
173
impl MockNetwork {
174
    /// Make a new MockNetwork with no active listeners.
175
112
    pub fn new() -> Arc<Self> {
176
112
        Default::default()
177
112
    }
178

            
179
    /// Return a [`ProviderBuilder`] for creating a [`MockNetProvider`]
180
    ///
181
    /// # Examples
182
    ///
183
    /// ```
184
    /// # use tor_rtmock::net::*;
185
    /// # let mock_network = MockNetwork::new();
186
    /// let client_net = mock_network.builder()
187
    ///       .add_address("198.51.100.6".parse().unwrap())
188
    ///       .add_address("2001:db8::7".parse().unwrap())
189
    ///       .provider();
190
    /// ```
191
44716
    pub fn builder(self: &Arc<Self>) -> ProviderBuilder {
192
44716
        ProviderBuilder {
193
44716
            addrs: vec![],
194
44716
            net: Arc::clone(self),
195
44716
        }
196
44716
    }
197

            
198
    /// Add a "black hole" at the given address, where all traffic will time out.
199
46
    pub fn add_blackhole(&self, address: SocketAddr) -> IoResult<()> {
200
46
        let mut listener_map = self.listening.lock().expect("Poisoned lock for listener");
201
46
        if listener_map.contains_key(&address) {
202
            return Err(err(ErrorKind::AddrInUse));
203
46
        }
204
46
        listener_map.insert(address, AddrBehavior::Timeout);
205
46
        Ok(())
206
46
    }
207

            
208
    /// Tell the listener at `target_addr` (if any) about an incoming
209
    /// connection from `source_addr` at `peer_stream`.
210
    ///
211
    /// If the listener is a TLS listener, returns its certificate.
212
    /// **Note:** Callers should check whether the presence or absence of a certificate
213
    /// matches their expectations.
214
    ///
215
    /// Returns an error if there isn't any such listener.
216
968
    async fn send_connection(
217
968
        &self,
218
968
        source_addr: SocketAddr,
219
968
        target_addr: SocketAddr,
220
968
        peer_stream: LocalStream,
221
1012
    ) -> IoResult<Option<Vec<u8>>> {
222
968
        let entry = {
223
968
            let listener_map = self.listening.lock().expect("Poisoned lock for listener");
224
968
            listener_map.get(&target_addr).cloned()
225
        };
226
730
        match entry {
227
500
            Some(AddrBehavior::Listener(mut entry)) => {
228
500
                if entry.send.send((peer_stream, source_addr)).await.is_ok() {
229
500
                    return Ok(entry.tls_cert);
230
                }
231
                Err(err(ErrorKind::ConnectionRefused))
232
            }
233
230
            Some(AddrBehavior::Timeout) => futures::future::pending().await,
234
238
            None => Err(err(ErrorKind::ConnectionRefused)),
235
        }
236
738
    }
237

            
238
    /// Register a listener at `addr` and return the ConnReceiver
239
    /// that it should use for connections.
240
    ///
241
    /// If tls_cert is provided, then the listener is a TLS listener
242
    /// and any only TLS connection attempts should succeed.
243
    ///
244
    /// Returns an error if the address is already in use.
245
162
    fn add_listener(&self, addr: SocketAddr, tls_cert: Option<Vec<u8>>) -> IoResult<ConnReceiver> {
246
162
        let mut listener_map = self.listening.lock().expect("Poisoned lock for listener");
247
162
        if listener_map.contains_key(&addr) {
248
            // TODO: Maybe this should ignore dangling Weak references?
249
            return Err(err(ErrorKind::AddrInUse));
250
162
        }
251
162

            
252
162
        let (send, recv) = mpsc_channel(16);
253
162

            
254
162
        let entry = ListenerEntry { send, tls_cert };
255
162

            
256
162
        listener_map.insert(addr, AddrBehavior::Listener(entry));
257
162

            
258
162
        Ok(recv)
259
162
    }
260
}
261

            
262
impl ProviderBuilder {
263
    /// Add `addr` as a new address for the provider we're building.
264
270
    pub fn add_address(&mut self, addr: IpAddr) -> &mut Self {
265
270
        self.addrs.push(addr);
266
270
        self
267
270
    }
268
    /// Use this builder to return a new [`MockNetRuntime`] wrapping
269
    /// an existing `runtime`.
270
8
    pub fn runtime<R: Runtime>(&self, runtime: R) -> super::MockNetRuntime<R> {
271
8
        MockNetRuntime::new(runtime, self.provider())
272
8
    }
273
    /// Use this builder to return a new [`MockNetProvider`]
274
44716
    pub fn provider(&self) -> MockNetProvider {
275
44716
        let inner = MockNetProviderInner {
276
44716
            addrs: self.addrs.clone(),
277
44716
            net: Arc::clone(&self.net),
278
44716
            next_port: AtomicU16::new(1),
279
44716
        };
280
44716
        MockNetProvider {
281
44716
            inner: Arc::new(inner),
282
44716
        }
283
44716
    }
284
}
285

            
286
impl NetStreamListener for MockNetListener {
287
    type Stream = LocalStream;
288

            
289
    type Incoming = Self;
290

            
291
24
    fn local_addr(&self) -> IoResult<SocketAddr> {
292
24
        Ok(self.addr)
293
24
    }
294

            
295
70
    fn incoming(self) -> Self {
296
70
        self
297
70
    }
298
}
299

            
300
impl Stream for MockNetListener {
301
    type Item = IoResult<(LocalStream, SocketAddr)>;
302
86
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
303
86
        let mut recv = futures::ready!(self.receiver.lock().poll_unpin(cx));
304
86
        match recv.poll_next_unpin(cx) {
305
            Poll::Pending => Poll::Pending,
306
            Poll::Ready(None) => Poll::Ready(None),
307
86
            Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
308
        }
309
86
    }
310
}
311

            
312
/// A very poor imitation of a UDP socket
313
#[derive(Debug)]
314
#[non_exhaustive]
315
pub struct MockUdpSocket {
316
    /// This is uninhabited.
317
    ///
318
    /// To implement UDP support, implement `.bind()`, and abolish this field,
319
    /// replacing it with the actual implementation.
320
    void: Void,
321
}
322

            
323
#[async_trait]
324
impl UdpProvider for MockNetProvider {
325
    type UdpSocket = MockUdpSocket;
326

            
327
    async fn bind(&self, addr: &SocketAddr) -> IoResult<MockUdpSocket> {
328
        let _ = addr; // MockNetProvider UDP is not implemented
329
        Err(io::ErrorKind::Unsupported.into())
330
    }
331
}
332

            
333
#[allow(clippy::diverging_sub_expression)] // void::unimplemented + async_trait
#[async_trait]
impl UdpSocket for MockUdpSocket {
    // Every method below is statically unreachable: `MockUdpSocket` holds an
    // uninhabited `Void`, so no value of this type can exist to call them on.
    async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
        // This tuple idiom avoids unused variable warnings.
        // An alternative would be to write _buf, but then when this is implemented,
        // and the void::unreachable call removed, we actually *want* those warnings.
        void::unreachable((self.void, buf).0)
    }
    async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
        // Same tuple idiom as `recv`, covering both unused parameters.
        void::unreachable((self.void, buf, target).0)
    }
    fn local_addr(&self) -> IoResult<SocketAddr> {
        void::unreachable(self.void)
    }
}
349

            
350
impl MockNetProvider {
351
    /// If we have a local addresses that is in the same family as `other`,
352
    /// return it.
353
998
    fn get_addr_in_family(&self, other: &IpAddr) -> Option<IpAddr> {
354
998
        self.inner
355
998
            .addrs
356
998
            .iter()
357
1059
            .find(|a| a.is_ipv4() == other.is_ipv4())
358
998
            .copied()
359
998
    }
360

            
361
    /// Return an arbitrary port number that we haven't returned from
362
    /// this function before.
363
    ///
364
    /// # Panics
365
    ///
366
    /// Panics if there are no remaining ports that this function hasn't
367
    /// returned before.
368
980
    fn arbitrary_port(&self) -> u16 {
369
980
        let next = self.inner.next_port.fetch_add(1, Ordering::Relaxed);
370
980
        assert!(next != 0);
371
980
        next
372
980
    }
373

            
374
    /// Helper for connecting: Picks the socketaddr to use
375
    /// when told to connect to `addr`.
376
    ///
377
    /// The IP is one of our own IPs with the same family as `addr`.
378
    /// The port is a port that we haven't used as an arbitrary port
379
    /// before.
380
968
    fn get_origin_addr_for(&self, addr: &SocketAddr) -> IoResult<SocketAddr> {
381
968
        let my_addr = self
382
968
            .get_addr_in_family(&addr.ip())
383
968
            .ok_or_else(|| err(ErrorKind::AddrNotAvailable))?;
384
968
        Ok(SocketAddr::new(my_addr, self.arbitrary_port()))
385
968
    }
386

            
387
    /// Helper for binding a listener: Picks the socketaddr to use
388
    /// when told to bind to `addr`.
389
    ///
390
    /// If addr is `0.0.0.0` or `[::]`, then we pick one of our own
391
    /// addresses with the same family. Otherwise we fail unless `addr` is
392
    /// one of our own addresses.
393
    ///
394
    /// If port is 0, we pick a new arbitrary port we haven't used as
395
    /// an arbitrary port before.
396
178
    fn get_listener_addr(&self, spec: &SocketAddr) -> IoResult<SocketAddr> {
397
174
        let ipaddr = {
398
178
            let ip = spec.ip();
399
178
            if ip.is_unspecified() {
400
30
                self.get_addr_in_family(&ip)
401
30
                    .ok_or_else(|| err(ErrorKind::AddrNotAvailable))?
402
208
            } else if self.inner.addrs.iter().any(|a| a == &ip) {
403
144
                ip
404
            } else {
405
4
                return Err(err(ErrorKind::AddrNotAvailable));
406
            }
407
        };
408
174
        let port = {
409
174
            if spec.port() == 0 {
410
12
                self.arbitrary_port()
411
            } else {
412
162
                spec.port()
413
            }
414
        };
415

            
416
174
        Ok(SocketAddr::new(ipaddr, port))
417
178
    }
418

            
419
    /// Create a mock TLS listener with provided certificate.
420
    ///
421
    /// Note that no encryption or authentication is actually
422
    /// performed!  Other parties are simply told that their connections
423
    /// succeeded and were authenticated against the given certificate.
424
54
    pub fn listen_tls(&self, addr: &SocketAddr, tls_cert: Vec<u8>) -> IoResult<MockNetListener> {
425
54
        let addr = self.get_listener_addr(addr)?;
426

            
427
54
        let receiver = AsyncMutex::new(self.inner.net.add_listener(addr, Some(tls_cert))?);
428

            
429
54
        Ok(MockNetListener { addr, receiver })
430
54
    }
431
}
432

            
433
#[async_trait]
434
impl NetStreamProvider for MockNetProvider {
435
    type Stream = LocalStream;
436
    type Listener = MockNetListener;
437

            
438
968
    async fn connect(&self, addr: &SocketAddr) -> IoResult<LocalStream> {
439
968
        let my_addr = self.get_origin_addr_for(addr)?;
440
968
        let (mut mine, theirs) = stream_pair();
441

            
442
968
        let cert = self
443
968
            .inner
444
968
            .net
445
968
            .send_connection(my_addr, *addr, theirs)
446
968
            .await?;
447

            
448
500
        mine.tls_cert = cert;
449
500

            
450
500
        Ok(mine)
451
1706
    }
452

            
453
108
    async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::Listener> {
454
108
        let addr = self.get_listener_addr(addr)?;
455

            
456
108
        let receiver = AsyncMutex::new(self.inner.net.add_listener(addr, None)?);
457

            
458
108
        Ok(MockNetListener { addr, receiver })
459
216
    }
460
}
461

            
462
#[async_trait]
463
impl TlsProvider<LocalStream> for MockNetProvider {
464
    type Connector = MockTlsConnector;
465
    type TlsStream = MockTlsStream;
466

            
467
790
    fn tls_connector(&self) -> MockTlsConnector {
468
790
        MockTlsConnector {}
469
790
    }
470

            
471
    fn supports_keying_material_export(&self) -> bool {
472
        false
473
    }
474
}
475

            
476
/// The TLS connector used with [`MockNetProvider`].
///
/// It performs no actual TLS: wrapping a stream just reports success with
/// whatever certificate the remote mock listener was configured with.
#[derive(Clone)]
#[non_exhaustive]
pub struct MockTlsConnector;
483

            
484
/// Mock TLS connector for use with MockNetProvider.
485
///
486
/// Note that no TLS is actually performed here: connections are simply
487
/// told that they succeeded with a given certificate.
488
///
489
/// Note also that we only use this type for client-side connections
490
/// right now: Arti doesn't support being a real TLS Listener yet,
491
/// since we only handle Tor client operations.
492
pub struct MockTlsStream {
493
    /// The peer certificate that we are pretending our peer has.
494
    peer_cert: Option<Vec<u8>>,
495
    /// The underlying stream.
496
    stream: LocalStream,
497
}
498

            
499
#[async_trait]
500
impl TlsConnector<LocalStream> for MockTlsConnector {
501
    type Conn = MockTlsStream;
502

            
503
    async fn negotiate_unvalidated(
504
        &self,
505
        mut stream: LocalStream,
506
        _sni_hostname: &str,
507
54
    ) -> IoResult<MockTlsStream> {
508
54
        let peer_cert = stream.tls_cert.take();
509
54

            
510
54
        if peer_cert.is_none() {
511
            return Err(std::io::Error::new(
512
                std::io::ErrorKind::Other,
513
                "attempted to wrap non-TLS stream!",
514
            ));
515
54
        }
516
54

            
517
54
        Ok(MockTlsStream { peer_cert, stream })
518
108
    }
519
}
520

            
521
impl CertifiedConn for MockTlsStream {
522
54
    fn peer_certificate(&self) -> IoResult<Option<Vec<u8>>> {
523
54
        Ok(self.peer_cert.clone())
524
54
    }
525
    fn export_keying_material(
526
        &self,
527
        _len: usize,
528
        _label: &[u8],
529
        _context: Option<&[u8]>,
530
    ) -> IoResult<Vec<u8>> {
531
        Ok(Vec::new())
532
    }
533
}
534

            
535
impl AsyncRead for MockTlsStream {
536
760
    fn poll_read(
537
760
        mut self: Pin<&mut Self>,
538
760
        cx: &mut Context<'_>,
539
760
        buf: &mut [u8],
540
760
    ) -> Poll<IoResult<usize>> {
541
760
        Pin::new(&mut self.stream).poll_read(cx, buf)
542
760
    }
543
}
544
impl AsyncWrite for MockTlsStream {
545
192
    fn poll_write(
546
192
        mut self: Pin<&mut Self>,
547
192
        cx: &mut Context<'_>,
548
192
        buf: &[u8],
549
192
    ) -> Poll<IoResult<usize>> {
550
192
        Pin::new(&mut self.stream).poll_write(cx, buf)
551
192
    }
552
92
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
553
92
        Pin::new(&mut self.stream).poll_flush(cx)
554
92
    }
555
8
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
556
8
        Pin::new(&mut self.stream).poll_close(cx)
557
8
    }
558
}
559

            
560
impl StreamOps for MockTlsStream {
561
    fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
562
        Err(std::io::Error::new(
563
            std::io::ErrorKind::Unsupported,
564
            "not supported on non-StreamOps stream!",
565
        ))
566
    }
567

            
568
46
    fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
569
46
        Box::new(tor_rtcompat::NoOpStreamOpsHandle::default())
570
46
    }
571
}
572

            
573
/// Inner error type returned when a `MockNetwork` operation fails.
574
#[derive(Clone, Error, Debug)]
575
#[non_exhaustive]
576
pub enum MockNetError {
577
    /// General-purpose error.  The real information is in `ErrorKind`.
578
    #[error("Invalid operation on mock network")]
579
    BadOp,
580
}
581

            
582
/// Wrap `k` in a new [`std::io::Error`].
583
242
fn err(k: ErrorKind) -> IoError {
584
242
    IoError::new(k, MockNetError::BadOp)
585
242
}
586

            
587
#[cfg(all(test, not(miri)))] // miri cannot simulate the networking
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use tor_rtcompat::test_with_all_runtimes;

    /// Build two providers on one fresh network: the first owns
    /// `192.0.2.55`, the second owns `198.51.100.7`.
    fn client_pair() -> (MockNetProvider, MockNetProvider) {
        let net = MockNetwork::new();
        let client1 = net
            .builder()
            .add_address("192.0.2.55".parse().unwrap())
            .provider();
        let client2 = net
            .builder()
            .add_address("198.51.100.7".parse().unwrap())
            .provider();

        (client1, client2)
    }

    /// One provider connects to a listener on the other and sends data;
    /// connecting to an address with no listener must fail.
    #[test]
    fn end_to_end() {
        test_with_all_runtimes!(|_rt| async {
            let (client1, client2) = client_pair();
            let lis = client2.listen(&"0.0.0.0:99".parse().unwrap()).await?;
            let address = lis.local_addr()?;

            // Run the connecting side and the accepting side concurrently.
            let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
                async {
                    let mut conn = client1.connect(&address).await?;
                    conn.write_all(b"This is totally a network.").await?;
                    conn.close().await?;

                    // Nobody listening here...
                    let a2 = "192.0.2.200:99".parse().unwrap();
                    let cant_connect = client1.connect(&a2).await;
                    assert!(cant_connect.is_err());
                    Ok(())
                },
                async {
                    let (mut conn, a) = lis.incoming().next().await.expect("closed?")?;
                    assert_eq!(a.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                    let mut inp = Vec::new();
                    conn.read_to_end(&mut inp).await?;
                    assert_eq!(&inp[..], &b"This is totally a network."[..]);
                    Ok(())
                }
            );
            r1?;
            r2?;
            IoResult::Ok(())
        });
    }

    /// Exercise `get_listener_addr` with unspecified/explicit IPs, zero and
    /// explicit ports, both families, and addresses we don't own.
    #[test]
    fn pick_listener_addr() -> IoResult<()> {
        let net = MockNetwork::new();
        let ip4 = "192.0.2.55".parse().unwrap();
        let ip6 = "2001:db8::7".parse().unwrap();
        let client = net.builder().add_address(ip4).add_address(ip6).provider();

        // Successful cases
        let a1 = client.get_listener_addr(&"0.0.0.0:99".parse().unwrap())?;
        assert_eq!(a1.ip(), ip4);
        assert_eq!(a1.port(), 99);
        let a2 = client.get_listener_addr(&"192.0.2.55:100".parse().unwrap())?;
        assert_eq!(a2.ip(), ip4);
        assert_eq!(a2.port(), 100);
        let a3 = client.get_listener_addr(&"192.0.2.55:0".parse().unwrap())?;
        assert_eq!(a3.ip(), ip4);
        assert!(a3.port() != 0);
        let a4 = client.get_listener_addr(&"0.0.0.0:0".parse().unwrap())?;
        assert_eq!(a4.ip(), ip4);
        assert!(a4.port() != 0);
        // Two requests for port 0 must get distinct arbitrary ports.
        assert!(a4.port() != a3.port());
        let a5 = client.get_listener_addr(&"[::]:99".parse().unwrap())?;
        assert_eq!(a5.ip(), ip6);
        assert_eq!(a5.port(), 99);
        let a6 = client.get_listener_addr(&"[2001:db8::7]:100".parse().unwrap())?;
        assert_eq!(a6.ip(), ip6);
        assert_eq!(a6.port(), 100);

        // Failing cases
        let e1 = client.get_listener_addr(&"192.0.2.56:0".parse().unwrap());
        let e2 = client.get_listener_addr(&"[2001:db8::8]:0".parse().unwrap());
        assert!(e1.is_err());
        assert!(e2.is_err());

        IoResult::Ok(())
    }

    /// A listener's incoming stream yields several successive connections.
    #[test]
    fn listener_stream() {
        test_with_all_runtimes!(|_rt| async {
            let (client1, client2) = client_pair();

            let lis = client2.listen(&"0.0.0.0:99".parse().unwrap()).await?;
            let address = lis.local_addr()?;
            let mut incoming = lis.incoming();

            let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
                async {
                    for _ in 0..3_u8 {
                        let mut c = client1.connect(&address).await?;
                        c.close().await?;
                    }
                    Ok(())
                },
                async {
                    for _ in 0..3_u8 {
                        let (mut c, a) = incoming.next().await.unwrap()?;
                        let mut v = Vec::new();
                        let _ = c.read_to_end(&mut v).await?;
                        assert_eq!(a.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                    }
                    Ok(())
                }
            );
            r1?;
            r2?;
            IoResult::Ok(())
        });
    }

    /// Connect to a TLS listener, "negotiate" TLS, check the reported peer
    /// certificate, and exchange data in both directions.
    #[test]
    fn tls_basics() {
        // The providers are shared by every runtime this test runs under;
        // each run binds to port 0, so it gets a fresh arbitrary port.
        let (client1, client2) = client_pair();
        let cert = b"I am certified for something I assure you.";

        test_with_all_runtimes!(|_rt| async {
            let lis = client2
                .listen_tls(&"0.0.0.0:0".parse().unwrap(), cert[..].into())
                .unwrap();
            let address = lis.local_addr().unwrap();

            let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
                async {
                    let connector = client1.tls_connector();
                    let conn = client1.connect(&address).await?;
                    let mut conn = connector
                        .negotiate_unvalidated(conn, "zombo.example.com")
                        .await?;
                    assert_eq!(&conn.peer_certificate()?.unwrap()[..], &cert[..]);
                    conn.write_all(b"This is totally encrypted.").await?;
                    let mut v = Vec::new();
                    conn.read_to_end(&mut v).await?;
                    conn.close().await?;
                    assert_eq!(v[..], b"Yup, your secrets is safe"[..]);
                    Ok(())
                },
                async {
                    let (mut conn, a) = lis.incoming().next().await.expect("closed?")?;
                    assert_eq!(a.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                    // 26 == length of the client's message.
                    let mut inp = [0_u8; 26];
                    conn.read_exact(&mut inp[..]).await?;
                    assert_eq!(&inp[..], &b"This is totally encrypted."[..]);
                    conn.write_all(b"Yup, your secrets is safe").await?;
                    Ok(())
                }
            );
            r1?;
            r2?;
            IoResult::Ok(())
        });
    }
}