1
//! Implementation for using `native_tls`
2

            
3
use crate::traits::{CertifiedConn, StreamOps, TlsConnector, TlsProvider};
4

            
5
use async_trait::async_trait;
6
use futures::{AsyncRead, AsyncWrite};
7
use native_tls_crate as native_tls;
8
use std::io::{Error as IoError, Result as IoResult};
9

            
10
/// A [`TlsProvider`] backed by the `native_tls` crate.
///
/// The connectors it produces can wrap any reasonable stream type that
/// implements `AsyncRead` + `AsyncWrite`.
#[cfg_attr(
    docsrs,
    doc(cfg(all(feature = "native-tls", any(feature = "tokio", feature = "async-std"))))
)]
// `Debug` is derived so that this public, field-less type can be logged and
// embedded in other `Debug` types like any other public type.
#[derive(Debug, Default, Clone)]
#[non_exhaustive]
pub struct NativeTlsProvider {}
20

            
21
impl<S> CertifiedConn for async_native_tls::TlsStream<S>
22
where
23
    S: AsyncRead + AsyncWrite + Unpin,
24
{
25
6
    fn peer_certificate(&self) -> IoResult<Option<Vec<u8>>> {
26
6
        let cert = self.peer_certificate();
27
6
        match cert {
28
6
            Ok(Some(c)) => {
29
6
                let der = c
30
6
                    .to_der()
31
6
                    .map_err(|e| IoError::new(std::io::ErrorKind::Other, e))?;
32
6
                Ok(Some(der))
33
            }
34
            Ok(None) => Ok(None),
35
            Err(e) => Err(IoError::new(std::io::ErrorKind::Other, e)),
36
        }
37
6
    }
38

            
39
    fn export_keying_material(
40
        &self,
41
        _len: usize,
42
        _label: &[u8],
43
        _context: Option<&[u8]>,
44
    ) -> IoResult<Vec<u8>> {
45
        Err(std::io::Error::new(
46
            std::io::ErrorKind::Unsupported,
47
            tor_error::bad_api_usage!("native-tls does not support exporting keying material"),
48
        ))
49
    }
50
}
51

            
52
impl<S: AsyncRead + AsyncWrite + StreamOps + Unpin> StreamOps for async_native_tls::TlsStream<S> {
53
    fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
54
        self.get_ref().set_tcp_notsent_lowat(notsent_lowat)
55
    }
56

            
57
    fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
58
        self.get_ref().new_handle()
59
    }
60
}
61

            
62
/// An implementation of [`TlsConnector`] built with `native_tls`.
63
pub struct NativeTlsConnector<S> {
64
    /// The inner connector object.
65
    connector: async_native_tls::TlsConnector,
66
    /// Phantom data to ensure proper variance.
67
    _phantom: std::marker::PhantomData<fn(S) -> S>,
68
}
69

            
70
#[async_trait]
71
impl<S> TlsConnector<S> for NativeTlsConnector<S>
72
where
73
    S: AsyncRead + AsyncWrite + StreamOps + Unpin + Send + 'static,
74
{
75
    type Conn = async_native_tls::TlsStream<S>;
76

            
77
6
    async fn negotiate_unvalidated(&self, stream: S, sni_hostname: &str) -> IoResult<Self::Conn> {
78
6
        let conn = self
79
6
            .connector
80
6
            .connect(sni_hostname, stream)
81
6
            .await
82
6
            .map_err(|e| IoError::new(std::io::ErrorKind::Other, e))?;
83
6
        Ok(conn)
84
12
    }
85
}
86

            
87
impl<S> TlsProvider<S> for NativeTlsProvider
88
where
89
    S: AsyncRead + AsyncWrite + StreamOps + Unpin + Send + 'static,
90
{
91
    type Connector = NativeTlsConnector<S>;
92

            
93
    type TlsStream = async_native_tls::TlsStream<S>;
94

            
95
16
    fn tls_connector(&self) -> Self::Connector {
96
16
        let mut builder = native_tls::TlsConnector::builder();
97
16
        // These function names are scary, but they just mean that we
98
16
        // aren't checking whether the signer of this cert
99
16
        // participates in the web PKI, and we aren't checking the
100
16
        // hostname in the cert.
101
16
        builder
102
16
            .danger_accept_invalid_certs(true)
103
16
            .danger_accept_invalid_hostnames(true);
104
16

            
105
16
        // We don't participate in the web PKI, so there is no reason for us to load the standard
106
16
        // list of CAs and CRLs. This can save us an megabyte or two.
107
16
        builder.disable_built_in_roots(true);
108
16

            
109
16
        let connector = builder.into();
110
16

            
111
16
        NativeTlsConnector {
112
16
            connector,
113
16
            _phantom: std::marker::PhantomData,
114
16
        }
115
16
    }
116

            
117
    fn supports_keying_material_export(&self) -> bool {
118
        false
119
    }
120
}