1
//! Declare a type for streams that do hostname lookups
2

            
3
use crate::memquota::StreamAccount;
4
use crate::stream::StreamReceiver;
5
use crate::{Error, Result};
6

            
7
use futures::StreamExt;
8
use tor_cell::relaycell::msg::Resolved;
9
use tor_cell::relaycell::RelayCmd;
10
use tor_cell::restricted_msg;
11

            
12
use super::AnyCmdChecker;
13

            
14
/// A ResolveStream represents a pending DNS request made with a RESOLVE
15
/// cell.
16
pub struct ResolveStream {
17
    /// The underlying RawCellStream.
18
    s: StreamReceiver,
19

            
20
    /// The memory quota account that should be used for this "stream"'s data
21
    ///
22
    /// Exists to keep the account alive
23
    _memquota: StreamAccount,
24
}
25

            
26
restricted_msg! {
    /// An allowable reply for a RESOLVE message.
    ///
    /// Only END and RESOLVED are accepted here; `ResolveCmdChecker` reports
    /// a stream protocol error for any other relay command on a resolve
    /// stream.
    enum ResolveResponseMsg : RelayMsg {
        // An END message; `ResolveStream::read_msg` turns it into
        // `Error::EndReceived` with the END's reason.
        End,
        // The answer to our lookup; `ResolveStream::read_msg` returns it.
        Resolved,
    }
}
33

            
34
impl ResolveStream {
35
    /// Wrap a RawCellStream into a ResolveStream.
36
    ///
37
    /// Call only after sending a RESOLVE cell.
38
    pub(crate) fn new(s: StreamReceiver, memquota: StreamAccount) -> Self {
39
        ResolveStream {
40
            s,
41
            _memquota: memquota,
42
        }
43
    }
44

            
45
    /// Read a message from this stream telling us the answer to our
46
    /// name lookup request.
47
    pub async fn read_msg(&mut self) -> Result<Resolved> {
48
        use ResolveResponseMsg::*;
49
        let cell = match self.s.next().await {
50
            Some(cell) => cell?,
51
            None => return Err(Error::NotConnected),
52
        };
53
        let msg = match cell.decode::<ResolveResponseMsg>() {
54
            Ok(cell) => cell.into_msg(),
55
            Err(e) => {
56
                self.s.protocol_error();
57
                return Err(Error::from_bytes_err(e, "response on a resolve stream"));
58
            }
59
        };
60
        match msg {
61
            End(e) => Err(Error::EndReceived(e.reason())),
62
            Resolved(r) => Ok(r),
63
        }
64
    }
65
}
66

            
67
/// A `CmdChecker` that validates the relay commands arriving on an
/// outbound resolve stream.
///
/// The checker holds no state of its own.
#[derive(Default, Debug)]
pub(crate) struct ResolveCmdChecker {}
71

            
72
impl super::CmdChecker for ResolveCmdChecker {
73
    fn check_msg(
74
        &mut self,
75
        msg: &tor_cell::relaycell::UnparsedRelayMsg,
76
    ) -> Result<super::StreamStatus> {
77
        use super::StreamStatus::Closed;
78
        match msg.cmd() {
79
            RelayCmd::RESOLVED => Ok(Closed),
80
            RelayCmd::END => Ok(Closed),
81
            _ => Err(Error::StreamProto(format!(
82
                "Unexpected {} on resolve stream",
83
                msg.cmd()
84
            ))),
85
        }
86
    }
87

            
88
    fn consume_checked_msg(&mut self, msg: tor_cell::relaycell::UnparsedRelayMsg) -> Result<()> {
89
        let _ = msg
90
            .decode::<ResolveResponseMsg>()
91
            .map_err(|err| Error::from_bytes_err(err, "message on resolve stream."))?;
92
        Ok(())
93
    }
94
}
95

            
96
impl ResolveCmdChecker {
97
    /// Return a new boxed `DataCmdChecker` in a state suitable for a newly
98
    /// constructed connection.
99
    pub(crate) fn new_any() -> AnyCmdChecker {
100
        Box::<Self>::default()
101
    }
102
}