//! Houses the DNS-specific code, including the structs that we pack the bytes
//! into and suitable traits and implementations to convert between bytes
//! and structs
//!
//! ### Disclaimer
//! This is a very barebones DNS client implementation. It hard-codes a lot of
//! values and is intended only for demonstration purposes on how even custom
//! protocols over TCP can be tunnelled through Tor. It is not meant for any
//! real production usage.
10
use anyhow::Result;
11
use std::fmt::Display;
12
use thiserror::Error;
13
use tracing::{debug, error};
14

            
15
#[derive(Error, Debug)]
16
#[error("Failed to parse bytes into struct!")]
17
/// Generic error we return if we fail to parse bytes into the struct
18
struct FromBytesError;
19

            
20
#[derive(Error, Debug)]
21
#[error("Invalid domain name passed")]
22
/// Error we return if a bad domain name is passed
23
pub struct DomainError;
24

            
25
/// The one DNS server this client talks to, as a `(host, port)` pair
///
/// The host is kept as a string so it can be handed directly to whatever
/// opens the connection.
pub const DNS_SERVER: (&str, u16) = ("1.1.1.1", 53);
27

            
28
/// QTYPE placed in every query we build; 1 asks for an A record (IPv4)
const QTYPE: u16 = 0x0001;
/// QCLASS placed in every query we build; 1 asks for an Internet address
const QCLASS: u16 = 0x0001;
32

            
33
/// Serialises a struct into the raw bytes that go out over the network
///
/// Example:
/// ```ignore
/// // `S` is any struct that implements this trait
/// let s = S::new();
/// // Dump the raw bytes as debug output
/// dbg!(s.as_bytes());
/// ```
pub trait AsBytes {
    /// Produce a `Vec<u8>` carrying the same information as the struct
    ///
    /// Use this to turn typed values into the byte stream that is written
    /// to the network.
    fn as_bytes(&self) -> Vec<u8>;
}
49

            
50
/// Used to convert raw bytes representation into a Rust struct
51
///
52
/// Example:
53
/// ```
54
/// let mut buf: Vec<u8> = Vec::new();
55
/// // Read the response from a stream
56
/// stream.read_to_end(&mut buf).await.unwrap();
57
/// // Interpret the response into a struct S
58
/// let resp = S::from_bytes(&buf);
59
/// ```
60
///
61
/// In the above code, `resp` is `Option<Box<S>>` type, so you will have to
62
/// deal with the `None` value appropriately. This helps denote invalid
63
/// situations, ie, parse failures
64
///
65
/// You will have to interpret each byte and convert it into each field
66
/// of your struct yourself when implementing this trait.
67
pub trait FromBytes {
68
    /// Convert two u8's into a u16
69
    ///
70
    /// It is just a thin wrapper over [u16::from_be_bytes()]
71
    fn u8_to_u16(upper: u8, lower: u8) -> u16 {
72
        let bytes = [upper, lower];
73
        u16::from_be_bytes(bytes)
74
    }
75
    /// Convert four u8's contained in a slice into a u32
76
    ///
77
    /// It is just a thin wrapper over [u32::from_be_bytes()] but also deals
78
    /// with converting &\[u8\] (u8 slice) into [u8; 4] (a fixed size array of u8)
79
    fn u8_to_u32(bytes_slice: &[u8]) -> Result<u32> {
80
        let bytes: [u8; 4] = bytes_slice.try_into()?;
81
        Ok(u32::from_be_bytes(bytes))
82
    }
83
    /// Try converting given bytes into the struct
84
    ///
85
    /// Returns an `Option<Box>` of the struct which implements
86
    /// this trait to help denote parsing failures
87
    fn from_bytes(bytes: &[u8]) -> Result<Box<Self>>;
88
}
89

            
90
/// Tells how many bytes a struct takes up in the byte stream
///
/// This is unrelated to the in-memory size of the Rust struct; it only
/// describes how long the struct is once it is laid out on the wire
trait Len {
    /// Number of bytes this struct takes up in the byte stream
    fn len(&self) -> usize;
}
100

            
101
/// The DNS header shared by queries and responses
///
/// Defaults are picked from the client's point of view
// TODO: For server we will have to interpret given values
struct Header {
    /// 16 bit identifier of the DNS request
    identification: u16,
    /// Several flag fields packed into a single 16 bit number
    ///
    /// See RFC 1035 for the details; the layout of this row is:
    ///
    /// ```text
    ///   0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
    /// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
    /// |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
    /// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
    /// ```
    ///
    /// TODO: don't rely on cryptic packed bits
    packed_second_row: u16, // 0x100 in the queries we send
    /// How many questions the packet carries
    ///
    /// Always 1 for us, since each query asks about a single hostname
    qdcount: u16,
    /// How many answers the packet carries
    ///
    /// Zero in a query; in a response it is hopefully >= 1
    ancount: u16, // a client has no answers to give
    /// Refer to RFC 1035 section 4.1.1, NSCOUNT
    nscount: u16, // 0 in the queries we send
    /// Refer to RFC 1035 section 4.1.1, ARCOUNT
    arcount: u16, // 0 in the queries we send
}
134

            
135
// Ugly, repetitive code to convert all six 16-bit fields into Vec<u8>
136
impl AsBytes for Header {
137
    fn as_bytes(&self) -> Vec<u8> {
138
        let mut v: Vec<u8> = Vec::with_capacity(14);
139
        // These 2 bytes store size of the rest of the payload (including header)
140
        // Right now it denotes 51 byte size packet, excluding these 2 bytes
141
        // We will change this when we know the size of Query
142
        v.push(0x00);
143
        v.push(0x33);
144
        // Just break u16 into [u8, u8] array and copy into vector
145
        v.extend_from_slice(&u16::to_be_bytes(self.identification));
146
        v.extend_from_slice(&u16::to_be_bytes(self.packed_second_row));
147
        v.extend_from_slice(&u16::to_be_bytes(self.qdcount));
148
        v.extend_from_slice(&u16::to_be_bytes(self.ancount));
149
        v.extend_from_slice(&u16::to_be_bytes(self.nscount));
150
        v.extend_from_slice(&u16::to_be_bytes(self.arcount));
151
        v
152
    }
153
}
154

            
155
impl Display for Header {
156
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157
        writeln!(f, "ID: 0x{:x}", self.identification)?;
158
        writeln!(f, "Flags: 0x{:x}", self.packed_second_row)?;
159
        writeln!(f, "QDCOUNT: 0x{:x}", self.qdcount)?;
160
        writeln!(f, "ANCOUNT: 0x{:x}", self.ancount)?;
161
        writeln!(f, "NSCOUNT: 0x{:x}", self.nscount)?;
162
        writeln!(f, "ARCOUNT: 0x{:x}", self.arcount)?;
163
        Ok(())
164
    }
165
}
166

            
167
impl FromBytes for Header {
168
    fn from_bytes(bytes: &[u8]) -> Result<Box<Self>> {
169
        debug!("Parsing the header");
170
        let packed_second_row = Header::u8_to_u16(bytes[2], bytes[3]);
171
        // 0x8180 denotes we have a response to a standard query,
172
        // that isn't truncated, and has recursion requested to a server
173
        // that can do recursion, with some bits reserved for future use
174
        // and some that are not relevant for our purposes
175
        if packed_second_row == 0x8180 {
176
            debug!("Correct flags set in response");
177
        } else {
178
            error!(
179
                "Incorrect flags set in response, we got {}",
180
                packed_second_row
181
            );
182
            return Err(FromBytesError.into());
183
        }
184
        // These offsets were determined by looking at RFC 1035
185
        Ok(Box::new(Header {
186
            identification: Header::u8_to_u16(bytes[0], bytes[1]),
187
            packed_second_row,
188
            qdcount: Header::u8_to_u16(bytes[4], bytes[5]),
189
            ancount: Header::u8_to_u16(bytes[6], bytes[7]),
190
            nscount: Header::u8_to_u16(bytes[8], bytes[9]),
191
            arcount: Header::u8_to_u16(bytes[10], bytes[11]),
192
        }))
193
    }
194
}
195

            
196
/// The actual query we will send to a DNS server
197
///
198
/// For now A records are fetched only
199
// TODO: add support for different records to be fetched
200
pub struct Query {
201
    /// Header of the DNS packet, see [Header] for more info
202
    header: Header,
203
    /// The domain name, stored as a `Vec<u8>`
204
    ///
205
    /// When we call [Query::from_bytes()], `qname` is automatically
206
    /// converted into string stored in a `Vec<u8>` instead of the raw
207
    /// byte format used for `qname`
208
    qname: Vec<u8>, // domain name
209
    /// Denotes the type of record to get.
210
    ///
211
    /// Here we set to 1 to get an A record, ie, IPv4
212
    qtype: u16, // set to 0x0001 for A records
213
    /// Denotes the class of the record
214
    ///
215
    /// Here we set to 1 to get an Internet address
216
    qclass: u16, // set to 1 for Internet addresses
217
}
218

            
219
impl AsBytes for Query {
    /// Lay out header, name, QTYPE and QCLASS, then patch the length prefix
    fn as_bytes(&self) -> Vec<u8> {
        // The header's bytes already start with the two length octets
        let mut packet = self.header.as_bytes();
        packet.extend_from_slice(&self.qname);
        packet.extend_from_slice(&self.qtype.to_be_bytes());
        packet.extend_from_slice(&self.qclass.to_be_bytes());
        // With the packet complete, its size is known and can be written
        // into the first two octets. Those two octets are not counted in
        // the reported length, hence the subtraction.
        let payload_len = (packet.len() - 2) as u16;
        packet[..2].copy_from_slice(&payload_len.to_be_bytes());
        packet
    }
}
237

            
238
impl Len for Query {
    /// Header (12) + one length byte + name + QTYPE (2) + QCLASS (2)
    fn len(&self) -> usize {
        let header = 12;
        // compensates for the one extra byte used to store the
        // length of the domain name
        let length_byte = 1;
        let qtype = 2;
        let qclass = 2;
        header + length_byte + self.qname.len() + qtype + qclass
    }
}
245

            
246
impl FromBytes for Query {
247
    // FIXME: the name struct isn't stored as it was sent over the wire
248
    fn from_bytes(bytes: &[u8]) -> Result<Box<Self>> {
249
        let header = *Header::from_bytes(&bytes[..12])?;
250
        if bytes.len() < 12 {
251
            error!("Mismatch between expected number of bytes and given number of bytes!");
252
            return Err(FromBytesError.into());
253
        }
254
        // Parse name
255
        let mut name = String::new();
256
        // 12 represents size of Header, which we have already parsed, or errored out of
257
        let mut lastnamebyte = 12;
258
        loop {
259
            // bytes[lastnamebytes] denotes the prefix length, we read that many bytes into name
260
            let start = lastnamebyte + 1;
261
            let end = start + bytes[lastnamebyte] as usize;
262
            name.extend(std::str::from_utf8(&bytes[start..end]));
263
            lastnamebyte = end;
264
            if lastnamebyte >= bytes.len() || bytes[lastnamebyte] == 0 {
265
                // End of domain name, proceed to parse further fields
266
                debug!("Reached end of name, moving on to parse other fields");
267
                lastnamebyte += 1;
268
                break;
269
            }
270
            name.push('.');
271
        }
272
        // These offsets were determined by looking at RFC 1035
273
        Ok(Box::new(Self {
274
            header,
275
            qname: name.as_bytes().to_vec(),
276
            qtype: Query::u8_to_u16(bytes[lastnamebyte], bytes[lastnamebyte + 1]),
277
            qclass: Query::u8_to_u16(bytes[lastnamebyte + 2], bytes[lastnamebyte + 3]),
278
        }))
279
    }
280
}
281

            
282
/// One resource record (RR) taken from a response
struct ResourceRecord {
    /// The type of this record
    ///
    /// It is similar to [Query::qtype]
    rtype: u16,
    /// The class of this record
    ///
    /// It is similar to [Query::qclass]
    class: u16,
    /// For how many seconds the result may be cached
    ///
    /// Once the TTL has run out the answer is no longer guaranteed to be
    /// correct, and a fresh request has to be made
    ttl: u32,
    /// Length of the data in `rdata`
    ///
    /// We only ever request IPv4 addresses in this implementation, so the
    /// value will be 4.
    rdlength: u16,
    /// The answer itself
    ///
    /// In our case that is an IPv4 address
    rdata: [u8; 4],
}
307

            
308
impl Len for ResourceRecord {
    // number of bytes one record consumes in the byte stream
    fn len(&self) -> usize {
        let name = 2; // counted even though we don't store it here
        let rtype = 2;
        let class = 2;
        let ttl = 4;
        let rdlength = 2;
        let rdata = 4;
        name + rtype + class + ttl + rdlength + rdata
    }
}
321

            
322
impl FromBytes for ResourceRecord {
323
    fn from_bytes(bytes: &[u8]) -> Result<Box<Self>> {
324
        let lastnamebyte = 1;
325
        let mut rdata = [0u8; 4];
326
        if bytes.len() < 15 {
327
            return Err(FromBytesError.into());
328
        }
329
        // Copy over IP address into rdata
330
        rdata.copy_from_slice(&bytes[lastnamebyte + 10..lastnamebyte + 14]);
331
        // These offsets were determined by looking at RFC 1035
332
        Ok(Box::new(Self {
333
            rtype: ResourceRecord::u8_to_u16(bytes[lastnamebyte], bytes[lastnamebyte + 1]),
334
            class: ResourceRecord::u8_to_u16(bytes[lastnamebyte + 2], bytes[lastnamebyte + 3]),
335
            ttl: ResourceRecord::u8_to_u32(&bytes[lastnamebyte + 4..lastnamebyte + 8])?,
336
            rdlength: Response::u8_to_u16(bytes[lastnamebyte + 8], bytes[lastnamebyte + 9]),
337
            rdata,
338
        }))
339
    }
340
}
341

            
342
impl Display for ResourceRecord {
343
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344
        writeln!(f, "RR record type: 0x{:x}", self.rtype)?;
345
        writeln!(f, "RR class: 0x{:x}", self.class)?;
346
        writeln!(f, "TTL: {}", self.ttl)?;
347
        writeln!(f, "RDLENGTH: 0x{:x}", self.rdlength)?;
348
        writeln!(
349
            f,
350
            "IP address: {}.{}.{}.{}",
351
            self.rdata[0], self.rdata[1], self.rdata[2], self.rdata[3]
352
        )?;
353
        Ok(())
354
    }
355
}
356

            
357
/// Stores the response in easy to interpret manner
358
///
359
/// A Response is made up of the query given to the server and a bunch of
360
/// Resource Records (RR). Each RR will include the resource type, class, and
361
/// name. For the A records we're requesting, we will get an A record, of Internet class,
362
/// ie an IPv4 address
363
pub struct Response {
364
    /// The Query part of the response we obtain from the server
365
    query: Query,
366
    /// A collection of resource records all parsed neatly and kept separately
367
    /// for easy iteration
368
    rr: Vec<ResourceRecord>,
369
}
370

            
371
impl FromBytes for Response {
372
    // Try to construct Response from raw byte data from network
373
    // We will also try to check if a valid DNS response has been sent back to us
374
    fn from_bytes(bytes: &[u8]) -> Result<Box<Self>> {
375
        debug!("Parsing response into struct");
376
        // Check message length
377
        let l = bytes.len();
378
        let messagelen = Response::u8_to_u16(bytes[0], bytes[1]);
379
        if messagelen == (l - 2) as u16 {
380
            debug!("Appear to have gotten good message from server");
381
        } else {
382
            error!(
383
                "Expected and observed message length don't match: {} and {} respectively",
384
                l - 2,
385
                messagelen
386
            );
387
        }
388
        // Start index at 2 to skip over message length bytes
389
        let mut index = 2;
390
        let query = *Query::from_bytes(&bytes[index..])?;
391
        index += query.len() + 2; // TODO: needs explanation why it works
392
        let mut rrvec: Vec<ResourceRecord> = Vec::new();
393
        while index < l {
394
            match ResourceRecord::from_bytes(&bytes[index..]) {
395
                Ok(rr) => {
396
                    index += rr.len();
397
                    rrvec.push(*rr);
398
                }
399
                Err(_) => break,
400
            }
401
        }
402
        Ok(Box::new(Response { query, rr: rrvec }))
403
    }
404
}
405

            
406
impl Display for Response {
407
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408
        writeln!(f, "{}", self.query.header)?;
409
        writeln!(
410
            f,
411
            "Name: {}",
412
            String::from_utf8(self.query.qname.to_owned()).unwrap()
413
        )?;
414
        writeln!(f, "Res type: 0x{:x}", self.query.qtype)?;
415
        writeln!(f, "Class: 0x{:x}", self.query.qclass)?;
416
        for record in self.rr.iter() {
417
            writeln!(f)?;
418
            writeln!(f, "{}", record)?;
419
        }
420
        Ok(())
421
    }
422
}
423

            
424
/// Craft the actual query for a particular domain and returns a Query object
425
///
426
/// The query is made for an A record of type Internet, ie, a normal IPv4 address
427
/// should be returned from the DNS server.
428
///
429
/// Convert this Query into bytes to be sent over the network by calling [Query::as_bytes()]
430
pub fn build_query(domain: &str) -> Result<Query, DomainError> {
431
    // TODO: generate identification randomly
432
    let header = Header {
433
        identification: 0x304e, // chosen by random dice roll, secure
434
        packed_second_row: 0x0100,
435
        qdcount: 0x0001,
436
        ancount: 0x0000,
437
        nscount: 0x0000,
438
        arcount: 0x0000,
439
    };
440
    let mut qname: Vec<u8> = Vec::new();
441
    let split_domain: Vec<&str> = domain.split('.').collect();
442
    for part in split_domain {
443
        if part.is_empty() {
444
            return Err(DomainError);
445
        }
446
        let l = part.len() as u8;
447
        if l != 0 {
448
            qname.push(l);
449
            qname.extend_from_slice(part.as_bytes());
450
        }
451
    }
452
    qname.push(0x00); // Denote that hostname has ended by pushing 0x00
453
    debug!("Crafted query successfully!");
454
    Ok(Query {
455
        header,
456
        qname,
457
        qtype: QTYPE,
458
        qclass: QCLASS,
459
    })
460
}