1
//! Declare a type for streams that do hostname lookups
2

            
3
use crate::memquota::StreamAccount;
4
use crate::stream::StreamReader;
5
use crate::{Error, Result};
6
use tor_cell::relaycell::msg::Resolved;
7
use tor_cell::relaycell::RelayCmd;
8
use tor_cell::restricted_msg;
9

            
10
use super::AnyCmdChecker;
11

            
12
/// A ResolveStream represents a pending DNS request made with a RESOLVE
13
/// cell.
14
pub struct ResolveStream {
15
    /// The underlying RawCellStream.
16
    s: StreamReader,
17

            
18
    /// The memory quota account that should be used for this "stream"'s data
19
    ///
20
    /// Exists to keep the account alive
21
    _memquota: StreamAccount,
22
}
23

            
24
restricted_msg! {
    /// An allowable reply for a RESOLVE message.
    ///
    /// Only END and RESOLVED messages are accepted; decoding any other
    /// message as this type fails.
    enum ResolveResponseMsg : RelayMsg {
        End,
        Resolved,
    }
}
31

            
32
impl ResolveStream {
33
    /// Wrap a RawCellStream into a ResolveStream.
34
    ///
35
    /// Call only after sending a RESOLVE cell.
36
    pub(crate) fn new(s: StreamReader, memquota: StreamAccount) -> Self {
37
        ResolveStream {
38
            s,
39
            _memquota: memquota,
40
        }
41
    }
42

            
43
    /// Read a message from this stream telling us the answer to our
44
    /// name lookup request.
45
    pub async fn read_msg(&mut self) -> Result<Resolved> {
46
        use ResolveResponseMsg::*;
47
        let cell = self.s.recv().await?;
48
        let msg = match cell.decode::<ResolveResponseMsg>() {
49
            Ok(cell) => cell.into_msg(),
50
            Err(e) => {
51
                self.s.protocol_error();
52
                return Err(Error::from_bytes_err(e, "response on a resolve stream"));
53
            }
54
        };
55
        match msg {
56
            End(e) => Err(Error::EndReceived(e.reason())),
57
            Resolved(r) => Ok(r),
58
        }
59
    }
60
}
61

            
62
/// A `CmdChecker` that enforces correctness for incoming commands on an
/// outbound resolve stream.
///
/// The checker carries no state of its own.
#[derive(Debug, Default)]
pub(crate) struct ResolveCmdChecker {}
66

            
67
impl super::CmdChecker for ResolveCmdChecker {
68
    fn check_msg(
69
        &mut self,
70
        msg: &tor_cell::relaycell::UnparsedRelayMsg,
71
    ) -> Result<super::StreamStatus> {
72
        use super::StreamStatus::Closed;
73
        match msg.cmd() {
74
            RelayCmd::RESOLVED => Ok(Closed),
75
            RelayCmd::END => Ok(Closed),
76
            _ => Err(Error::StreamProto(format!(
77
                "Unexpected {} on resolve stream",
78
                msg.cmd()
79
            ))),
80
        }
81
    }
82

            
83
    fn consume_checked_msg(&mut self, msg: tor_cell::relaycell::UnparsedRelayMsg) -> Result<()> {
84
        let _ = msg
85
            .decode::<ResolveResponseMsg>()
86
            .map_err(|err| Error::from_bytes_err(err, "message on resolve stream."))?;
87
        Ok(())
88
    }
89
}
90

            
91
impl ResolveCmdChecker {
92
    /// Return a new boxed `DataCmdChecker` in a state suitable for a newly
93
    /// constructed connection.
94
    pub(crate) fn new_any() -> AnyCmdChecker {
95
        Box::<Self>::default()
96
    }
97
}