dns_resolver/
dns.rs

//! Houses the DNS-specific code: the structs that we pack the bytes into, and
//! the traits and implementations that convert between raw bytes and those
//! structs
4//!
5//! ### Disclaimer
6//! This is a very barebones DNS client implementation. It hard-codes a lot of
7//! values and is intended only for demonstration purposes on how even custom
8//! protocols over TCP can be tunnelled through Tor. It is not meant for any
9//! real production usage.
10use anyhow::Result;
11use std::fmt::Display;
12use thiserror::Error;
13use tracing::{debug, error};
14
15#[derive(Error, Debug)]
16#[error("Failed to parse bytes into struct!")]
17/// Generic error we return if we fail to parse bytes into the struct
18struct FromBytesError;
19
20#[derive(Error, Debug)]
21#[error("Invalid domain name passed")]
22/// Error we return if a bad domain name is passed
23pub struct DomainError;
24
/// The hard-coded DNS server every query is sent to, as a `(host, port)` pair
pub const DNS_SERVER: (&str, u16) = ("1.1.1.1", 53);

/// QTYPE placed in every question: `1` requests an A record (an IPv4 address)
const QTYPE: u16 = 1;
/// QCLASS placed in every question: `1` selects Internet addresses
const QCLASS: u16 = 1;
32
/// Serialize a struct into the raw bytes that are sent over the network
///
/// # Example
///
/// ```ignore
/// // `S` is some type that implements this trait
/// let s = S::new();
/// // Print the raw bytes as debug output
/// dbg!(s.as_bytes());
/// ```
pub trait AsBytes {
    /// Return the wire representation of `self` as a `Vec<u8>`
    ///
    /// Use this to turn typed values into raw bytes that can be written to
    /// the network.
    fn as_bytes(&self) -> Vec<u8>;
}
49
50/// Used to convert raw bytes representation into a Rust struct
51///
52/// Example:
53/// ```
54/// let mut buf: Vec<u8> = Vec::new();
55/// // Read the response from a stream
56/// stream.read_to_end(&mut buf).await.unwrap();
57/// // Interpret the response into a struct S
58/// let resp = S::from_bytes(&buf);
59/// ```
60///
61/// In the above code, `resp` is `Option<Box<S>>` type, so you will have to
62/// deal with the `None` value appropriately. This helps denote invalid
63/// situations, ie, parse failures
64///
65/// You will have to interpret each byte and convert it into each field
66/// of your struct yourself when implementing this trait.
67pub trait FromBytes {
68    /// Convert two u8's into a u16
69    ///
70    /// It is just a thin wrapper over [u16::from_be_bytes()]
71    fn u8_to_u16(upper: u8, lower: u8) -> u16 {
72        let bytes = [upper, lower];
73        u16::from_be_bytes(bytes)
74    }
75    /// Convert four u8's contained in a slice into a u32
76    ///
77    /// It is just a thin wrapper over [u32::from_be_bytes()] but also deals
78    /// with converting &\[u8\] (u8 slice) into [u8; 4] (a fixed size array of u8)
79    fn u8_to_u32(bytes_slice: &[u8]) -> Result<u32> {
80        let bytes: [u8; 4] = bytes_slice.try_into()?;
81        Ok(u32::from_be_bytes(bytes))
82    }
83    /// Try converting given bytes into the struct
84    ///
85    /// Returns an `Option<Box>` of the struct which implements
86    /// this trait to help denote parsing failures
87    fn from_bytes(bytes: &[u8]) -> Result<Box<Self>>;
88}
89
/// Size of a value once it is serialized onto the wire
///
/// This is the length of the on-the-wire encoding, not the in-memory size of
/// the Rust struct.
trait Len {
    /// Number of bytes this value takes up in the byte stream
    fn len(&self) -> usize;
}
100
/// DNS message header, shared by queries and responses
///
/// The values used when building a query are chosen from the client's point
/// of view (see [build_query]).
// TODO: For server we will have to interpret given values
struct Header {
    /// Random 16-bit number used to identify the DNS request
    identification: u16,
    /// All header flags, packed into a single 16-bit word
    ///
    /// See RFC 1035 for details; the layout of this word is:
    ///
    /// ```text
    ///   0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
    /// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
    /// |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
    /// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
    /// ```
    ///
    /// Queries built by [build_query] set this to `0x0100`.
    ///
    /// TODO: don't rely on cryptic packed bits
    packed_second_row: u16,
    /// QDCOUNT: number of questions in the message
    ///
    /// Queries we build ask about exactly one hostname, so this is 1.
    qdcount: u16,
    /// ANCOUNT: number of answers in the message
    ///
    /// Zero in a query; for a response it is hopefully at least 1.
    ancount: u16,
    /// NSCOUNT, see RFC 1035 section 4.1.1; zero in the queries we build
    nscount: u16,
    /// ARCOUNT, see RFC 1035 section 4.1.1; zero in the queries we build
    arcount: u16,
}
134
135// Ugly, repetitive code to convert all six 16-bit fields into Vec<u8>
136impl AsBytes for Header {
137    fn as_bytes(&self) -> Vec<u8> {
138        let mut v: Vec<u8> = Vec::with_capacity(14);
139        // These 2 bytes store size of the rest of the payload (including header)
140        // Right now it denotes 51 byte size packet, excluding these 2 bytes
141        // We will change this when we know the size of Query
142        v.push(0x00);
143        v.push(0x33);
144        // Just break u16 into [u8, u8] array and copy into vector
145        v.extend_from_slice(&u16::to_be_bytes(self.identification));
146        v.extend_from_slice(&u16::to_be_bytes(self.packed_second_row));
147        v.extend_from_slice(&u16::to_be_bytes(self.qdcount));
148        v.extend_from_slice(&u16::to_be_bytes(self.ancount));
149        v.extend_from_slice(&u16::to_be_bytes(self.nscount));
150        v.extend_from_slice(&u16::to_be_bytes(self.arcount));
151        v
152    }
153}
154
155impl Display for Header {
156    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157        writeln!(f, "ID: 0x{:x}", self.identification)?;
158        writeln!(f, "Flags: 0x{:x}", self.packed_second_row)?;
159        writeln!(f, "QDCOUNT: 0x{:x}", self.qdcount)?;
160        writeln!(f, "ANCOUNT: 0x{:x}", self.ancount)?;
161        writeln!(f, "NSCOUNT: 0x{:x}", self.nscount)?;
162        writeln!(f, "ARCOUNT: 0x{:x}", self.arcount)?;
163        Ok(())
164    }
165}
166
167impl FromBytes for Header {
168    fn from_bytes(bytes: &[u8]) -> Result<Box<Self>> {
169        debug!("Parsing the header");
170        let packed_second_row = Header::u8_to_u16(bytes[2], bytes[3]);
171        // 0x8180 denotes we have a response to a standard query,
172        // that isn't truncated, and has recursion requested to a server
173        // that can do recursion, with some bits reserved for future use
174        // and some that are not relevant for our purposes
175        if packed_second_row == 0x8180 {
176            debug!("Correct flags set in response");
177        } else {
178            error!(
179                "Incorrect flags set in response, we got {}",
180                packed_second_row
181            );
182            return Err(FromBytesError.into());
183        }
184        // These offsets were determined by looking at RFC 1035
185        Ok(Box::new(Header {
186            identification: Header::u8_to_u16(bytes[0], bytes[1]),
187            packed_second_row,
188            qdcount: Header::u8_to_u16(bytes[4], bytes[5]),
189            ancount: Header::u8_to_u16(bytes[6], bytes[7]),
190            nscount: Header::u8_to_u16(bytes[8], bytes[9]),
191            arcount: Header::u8_to_u16(bytes[10], bytes[11]),
192        }))
193    }
194}
195
196/// The actual query we will send to a DNS server
197///
198/// For now A records are fetched only
199// TODO: add support for different records to be fetched
200pub struct Query {
201    /// Header of the DNS packet, see [Header] for more info
202    header: Header,
203    /// The domain name, stored as a `Vec<u8>`
204    ///
205    /// When we call [Query::from_bytes()], `qname` is automatically
206    /// converted into string stored in a `Vec<u8>` instead of the raw
207    /// byte format used for `qname`
208    qname: Vec<u8>, // domain name
209    /// Denotes the type of record to get.
210    ///
211    /// Here we set to 1 to get an A record, ie, IPv4
212    qtype: u16, // set to 0x0001 for A records
213    /// Denotes the class of the record
214    ///
215    /// Here we set to 1 to get an Internet address
216    qclass: u16, // set to 1 for Internet addresses
217}
218
219impl AsBytes for Query {
220    fn as_bytes(&self) -> Vec<u8> {
221        let mut v: Vec<u8> = Vec::new();
222        let header_bytes = self.header.as_bytes();
223        v.extend(header_bytes);
224        v.extend(&self.qname);
225        v.extend_from_slice(&u16::to_be_bytes(self.qtype));
226        v.extend_from_slice(&u16::to_be_bytes(self.qclass));
227        // Now that the packet is ready, we can calculate size and set that in
228        // first two octets
229        // Subtract 2 since these first 2 bits are never counted when reporting
230        // length like this
231        let len_bits = u16::to_be_bytes((v.len() - 2) as u16);
232        v[0] = len_bits[0];
233        v[1] = len_bits[1];
234        v
235    }
236}
237
238impl Len for Query {
239    fn len(&self) -> usize {
240        // extra 1 is for compensating for how we
241        // use one byte more to store length of domain name
242        12 + 1 + self.qname.len() + 2 + 2
243    }
244}
245
246impl FromBytes for Query {
247    // FIXME: the name struct isn't stored as it was sent over the wire
248    fn from_bytes(bytes: &[u8]) -> Result<Box<Self>> {
249        let header = *Header::from_bytes(&bytes[..12])?;
250        if bytes.len() < 12 {
251            error!("Mismatch between expected number of bytes and given number of bytes!");
252            return Err(FromBytesError.into());
253        }
254        // Parse name
255        let mut name = String::new();
256        // 12 represents size of Header, which we have already parsed, or errored out of
257        let mut lastnamebyte = 12;
258        loop {
259            // bytes[lastnamebytes] denotes the prefix length, we read that many bytes into name
260            let start = lastnamebyte + 1;
261            let end = start + bytes[lastnamebyte] as usize;
262            name.extend(std::str::from_utf8(&bytes[start..end]));
263            lastnamebyte = end;
264            if lastnamebyte >= bytes.len() || bytes[lastnamebyte] == 0 {
265                // End of domain name, proceed to parse further fields
266                debug!("Reached end of name, moving on to parse other fields");
267                lastnamebyte += 1;
268                break;
269            }
270            name.push('.');
271        }
272        // These offsets were determined by looking at RFC 1035
273        Ok(Box::new(Self {
274            header,
275            qname: name.as_bytes().to_vec(),
276            qtype: Query::u8_to_u16(bytes[lastnamebyte], bytes[lastnamebyte + 1]),
277            qclass: Query::u8_to_u16(bytes[lastnamebyte + 2], bytes[lastnamebyte + 3]),
278        }))
279    }
280}
281
/// One resource record (RR) parsed from a response
///
/// The NAME that precedes each record on the wire is skipped during parsing
/// and is not stored here.
struct ResourceRecord {
    /// Record type, with the same meaning as [Query::qtype]
    rtype: u16,
    /// Record class, with the same meaning as [Query::qclass]
    class: u16,
    /// How many seconds this result may be cached for
    ///
    /// Once the TTL has passed, a fresh request has to be made since the
    /// answer is no longer guaranteed to be correct.
    ttl: u32,
    /// RDLENGTH: length of the RDATA field in bytes
    ///
    /// Only IPv4 addresses are requested, so this is expected to be 4.
    rdlength: u16,
    /// RDATA: the actual answer, an IPv4 address in our case
    rdata: [u8; 4],
}
307
308impl Len for ResourceRecord {
309    // return number of bytes it consumes
310    fn len(&self) -> usize {
311        let mut size = 0;
312        size += 2; // name, even though we don't store it here
313        size += 2; // rtype
314        size += 2; // class
315        size += 4; // ttl
316        size += 2; // rdlength
317        size += 4; // rdata
318        size
319    }
320}
321
322impl FromBytes for ResourceRecord {
323    fn from_bytes(bytes: &[u8]) -> Result<Box<Self>> {
324        let lastnamebyte = 1;
325        let mut rdata = [0u8; 4];
326        if bytes.len() < 15 {
327            return Err(FromBytesError.into());
328        }
329        // Copy over IP address into rdata
330        rdata.copy_from_slice(&bytes[lastnamebyte + 10..lastnamebyte + 14]);
331        // These offsets were determined by looking at RFC 1035
332        Ok(Box::new(Self {
333            rtype: ResourceRecord::u8_to_u16(bytes[lastnamebyte], bytes[lastnamebyte + 1]),
334            class: ResourceRecord::u8_to_u16(bytes[lastnamebyte + 2], bytes[lastnamebyte + 3]),
335            ttl: ResourceRecord::u8_to_u32(&bytes[lastnamebyte + 4..lastnamebyte + 8])?,
336            rdlength: Response::u8_to_u16(bytes[lastnamebyte + 8], bytes[lastnamebyte + 9]),
337            rdata,
338        }))
339    }
340}
341
342impl Display for ResourceRecord {
343    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344        writeln!(f, "RR record type: 0x{:x}", self.rtype)?;
345        writeln!(f, "RR class: 0x{:x}", self.class)?;
346        writeln!(f, "TTL: {}", self.ttl)?;
347        writeln!(f, "RDLENGTH: 0x{:x}", self.rdlength)?;
348        writeln!(
349            f,
350            "IP address: {}.{}.{}.{}",
351            self.rdata[0], self.rdata[1], self.rdata[2], self.rdata[3]
352        )?;
353        Ok(())
354    }
355}
356
357/// Stores the response in easy to interpret manner
358///
359/// A Response is made up of the query given to the server and a bunch of
360/// Resource Records (RR). Each RR will include the resource type, class, and
361/// name. For the A records we're requesting, we will get an A record, of Internet class,
362/// ie an IPv4 address
363pub struct Response {
364    /// The Query part of the response we obtain from the server
365    query: Query,
366    /// A collection of resource records all parsed neatly and kept separately
367    /// for easy iteration
368    rr: Vec<ResourceRecord>,
369}
370
371impl FromBytes for Response {
372    // Try to construct Response from raw byte data from network
373    // We will also try to check if a valid DNS response has been sent back to us
374    fn from_bytes(bytes: &[u8]) -> Result<Box<Self>> {
375        debug!("Parsing response into struct");
376        // Check message length
377        let l = bytes.len();
378        let messagelen = Response::u8_to_u16(bytes[0], bytes[1]);
379        if messagelen == (l - 2) as u16 {
380            debug!("Appear to have gotten good message from server");
381        } else {
382            error!(
383                "Expected and observed message length don't match: {} and {} respectively",
384                l - 2,
385                messagelen
386            );
387        }
388        // Start index at 2 to skip over message length bytes
389        let mut index = 2;
390        let query = *Query::from_bytes(&bytes[index..])?;
391        index += query.len() + 2; // TODO: needs explanation why it works
392        let mut rrvec: Vec<ResourceRecord> = Vec::new();
393        while index < l {
394            match ResourceRecord::from_bytes(&bytes[index..]) {
395                Ok(rr) => {
396                    index += rr.len();
397                    rrvec.push(*rr);
398                }
399                Err(_) => break,
400            }
401        }
402        Ok(Box::new(Response { query, rr: rrvec }))
403    }
404}
405
406impl Display for Response {
407    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408        writeln!(f, "{}", self.query.header)?;
409        writeln!(
410            f,
411            "Name: {}",
412            String::from_utf8(self.query.qname.to_owned()).unwrap()
413        )?;
414        writeln!(f, "Res type: 0x{:x}", self.query.qtype)?;
415        writeln!(f, "Class: 0x{:x}", self.query.qclass)?;
416        for record in self.rr.iter() {
417            writeln!(f)?;
418            writeln!(f, "{}", record)?;
419        }
420        Ok(())
421    }
422}
423
424/// Craft the actual query for a particular domain and returns a Query object
425///
426/// The query is made for an A record of type Internet, ie, a normal IPv4 address
427/// should be returned from the DNS server.
428///
429/// Convert this Query into bytes to be sent over the network by calling [Query::as_bytes()]
430pub fn build_query(domain: &str) -> Result<Query, DomainError> {
431    // TODO: generate identification randomly
432    let header = Header {
433        identification: 0x304e, // chosen by random dice roll, secure
434        packed_second_row: 0x0100,
435        qdcount: 0x0001,
436        ancount: 0x0000,
437        nscount: 0x0000,
438        arcount: 0x0000,
439    };
440    let mut qname: Vec<u8> = Vec::new();
441    let split_domain: Vec<&str> = domain.split('.').collect();
442    for part in split_domain {
443        if part.is_empty() {
444            return Err(DomainError);
445        }
446        let l = part.len() as u8;
447        if l != 0 {
448            qname.push(l);
449            qname.extend_from_slice(part.as_bytes());
450        }
451    }
452    qname.push(0x00); // Denote that hostname has ended by pushing 0x00
453    debug!("Crafted query successfully!");
454    Ok(Query {
455        header,
456        qname,
457        qtype: QTYPE,
458        qclass: QCLASS,
459    })
460}