1
//! TLS provider implementation built on the `native_tls` crate.
2

            
3
use crate::traits::{CertifiedConn, StreamOps, TlsConnector, TlsProvider};
4

            
5
use async_trait::async_trait;
6
use futures::{AsyncRead, AsyncWrite};
7
use native_tls_crate as native_tls;
8
use std::io::{Error as IoError, Result as IoResult};
9

            
10
/// [`TlsProvider`] implementation backed by the `native_tls` crate.
///
/// Any reasonable stream type implementing `AsyncRead` + `AsyncWrite` can be wrapped by it.
#[cfg_attr(
    docsrs,
    doc(cfg(all(
        feature = "native-tls",
        any(feature = "tokio", feature = "async-std", feature = "smol")
    )))
)]
#[derive(Clone, Default)]
#[non_exhaustive]
pub struct NativeTlsProvider {}
23

            
24
impl<S> CertifiedConn for async_native_tls::TlsStream<S>
25
where
26
    S: AsyncRead + AsyncWrite + Unpin,
27
{
28
8
    fn peer_certificate(&self) -> IoResult<Option<Vec<u8>>> {
29
8
        let cert = self.peer_certificate();
30
8
        match cert {
31
8
            Ok(Some(c)) => {
32
8
                let der = c.to_der().map_err(IoError::other)?;
33
8
                Ok(Some(der))
34
            }
35
            Ok(None) => Ok(None),
36
            Err(e) => Err(IoError::other(e)),
37
        }
38
8
    }
39

            
40
    fn export_keying_material(
41
        &self,
42
        _len: usize,
43
        _label: &[u8],
44
        _context: Option<&[u8]>,
45
    ) -> IoResult<Vec<u8>> {
46
        Err(std::io::Error::new(
47
            std::io::ErrorKind::Unsupported,
48
            tor_error::bad_api_usage!("native-tls does not support exporting keying material"),
49
        ))
50
    }
51
}
52

            
53
impl<S: AsyncRead + AsyncWrite + StreamOps + Unpin> StreamOps for async_native_tls::TlsStream<S> {
54
    fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
55
        self.get_ref().set_tcp_notsent_lowat(notsent_lowat)
56
    }
57

            
58
    fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
59
        self.get_ref().new_handle()
60
    }
61
}
62

            
63
/// An implementation of [`TlsConnector`] built with `native_tls`.
64
pub struct NativeTlsConnector<S> {
65
    /// The inner connector object.
66
    connector: async_native_tls::TlsConnector,
67
    /// Phantom data to ensure proper variance.
68
    _phantom: std::marker::PhantomData<fn(S) -> S>,
69
}
70

            
71
#[async_trait]
72
impl<S> TlsConnector<S> for NativeTlsConnector<S>
73
where
74
    S: AsyncRead + AsyncWrite + StreamOps + Unpin + Send + 'static,
75
{
76
    type Conn = async_native_tls::TlsStream<S>;
77

            
78
8
    async fn negotiate_unvalidated(&self, stream: S, sni_hostname: &str) -> IoResult<Self::Conn> {
79
        let conn = self
80
            .connector
81
            .connect(sni_hostname, stream)
82
            .await
83
            .map_err(IoError::other)?;
84
        Ok(conn)
85
8
    }
86
}
87

            
88
impl<S> TlsProvider<S> for NativeTlsProvider
89
where
90
    S: AsyncRead + AsyncWrite + StreamOps + Unpin + Send + 'static,
91
{
92
    type Connector = NativeTlsConnector<S>;
93

            
94
    type TlsStream = async_native_tls::TlsStream<S>;
95

            
96
24
    fn tls_connector(&self) -> Self::Connector {
97
24
        let mut builder = native_tls::TlsConnector::builder();
98
        // These function names are scary, but they just mean that we
99
        // aren't checking whether the signer of this cert
100
        // participates in the web PKI, and we aren't checking the
101
        // hostname in the cert.
102
24
        builder
103
24
            .danger_accept_invalid_certs(true)
104
24
            .danger_accept_invalid_hostnames(true);
105

            
106
        // We don't participate in the web PKI, so there is no reason for us to load the standard
107
        // list of CAs and CRLs. This can save us an megabyte or two.
108
24
        builder.disable_built_in_roots(true);
109

            
110
24
        let connector = builder.into();
111

            
112
24
        NativeTlsConnector {
113
24
            connector,
114
24
            _phantom: std::marker::PhantomData,
115
24
        }
116
24
    }
117

            
118
    fn supports_keying_material_export(&self) -> bool {
119
        false
120
    }
121
}