1
//! A crate for performing GeoIP lookups using the Tor GeoIP database.
2

            
3
// @@ begin lint list maintained by maint/add_warning @@
4
#![allow(renamed_and_removed_lints)] // @@REMOVE_WHEN(ci_arti_stable)
5
#![allow(unknown_lints)] // @@REMOVE_WHEN(ci_arti_nightly)
6
#![warn(missing_docs)]
7
#![warn(noop_method_call)]
8
#![warn(unreachable_pub)]
9
#![warn(clippy::all)]
10
#![deny(clippy::await_holding_lock)]
11
#![deny(clippy::cargo_common_metadata)]
12
#![deny(clippy::cast_lossless)]
13
#![deny(clippy::checked_conversions)]
14
#![warn(clippy::cognitive_complexity)]
15
#![deny(clippy::debug_assert_with_mut_call)]
16
#![deny(clippy::exhaustive_enums)]
17
#![deny(clippy::exhaustive_structs)]
18
#![deny(clippy::expl_impl_clone_on_copy)]
19
#![deny(clippy::fallible_impl_from)]
20
#![deny(clippy::implicit_clone)]
21
#![deny(clippy::large_stack_arrays)]
22
#![warn(clippy::manual_ok_or)]
23
#![deny(clippy::missing_docs_in_private_items)]
24
#![warn(clippy::needless_borrow)]
25
#![warn(clippy::needless_pass_by_value)]
26
#![warn(clippy::option_option)]
27
#![deny(clippy::print_stderr)]
28
#![deny(clippy::print_stdout)]
29
#![warn(clippy::rc_buffer)]
30
#![deny(clippy::ref_option_ref)]
31
#![warn(clippy::semicolon_if_nothing_returned)]
32
#![warn(clippy::trait_duplication_in_bounds)]
33
#![deny(clippy::unchecked_duration_subtraction)]
34
#![deny(clippy::unnecessary_wraps)]
35
#![warn(clippy::unseparated_literal_suffix)]
36
#![deny(clippy::unwrap_used)]
37
#![deny(clippy::mod_module_files)]
38
#![allow(clippy::let_unit_value)] // This can reasonably be done for explicitness
39
#![allow(clippy::uninlined_format_args)]
40
#![allow(clippy::significant_drop_in_scrutinee)] // arti/-/merge_requests/588/#note_2812945
41
#![allow(clippy::result_large_err)] // temporary workaround for arti#587
42
#![allow(clippy::needless_raw_string_hashes)] // complained-about code is fine, often best
43
#![allow(clippy::needless_lifetimes)] // See arti#1765
44
#![allow(mismatched_lifetime_syntaxes)] // temporary workaround for arti#2060
45
//! <!-- @@ end lint list maintained by maint/add_warning @@ -->
46

            
47
// TODO #1645 (either remove this, or decide to have it everywhere)
48
#![cfg_attr(not(all(feature = "full")), allow(unused))]
49

            
50
pub use crate::err::Error;
51
use rangemap::RangeInclusiveMap;
52
use std::fmt::{Debug, Display, Formatter};
53
use std::net::{IpAddr, Ipv6Addr};
54
use std::num::{NonZeroU32, NonZeroU8, TryFromIntError};
55
use std::str::FromStr;
56
use std::sync::{Arc, OnceLock};
57

            
58
mod err;
59

            
60
/// The geoip v4 database text, embedded into the binary at compile time.
///
/// FIXME(eta): This does use a few megabytes of binary size, which is less than ideal.
///             It would be better to parse it at compile time or something.
#[cfg(feature = "embedded-db")]
static EMBEDDED_DB_V4: &str = include_str!("../data/geoip");

/// The geoip v6 database text, embedded into the binary at compile time.
#[cfg(feature = "embedded-db")]
static EMBEDDED_DB_V6: &str = include_str!("../data/geoip6");

/// The parsed form of the embedded databases, filled in on first use and
/// shared from then on.
#[cfg(feature = "embedded-db")]
static EMBEDDED_DB_PARSED: OnceLock<Arc<GeoipDb>> = OnceLock::new();
74

            
75
/// A two-letter country code.
///
/// This type holds a purported "ISO 3166-1 alpha-2" country code, for example
/// "IT" for Italy or "UY" for Uruguay.
///
/// The sentinel value `??`, which we use to mean "country unknown", is never
/// a valid `CountryCode`; use [`OptionCc`] if you need to represent it. Beyond
/// that, no attempt is made to verify that the code names a real country: the
/// only requirement is that it is a pair of printing ASCII characters.
///
/// Note that the geoip databases included with Arti will only include real
/// countries; we do not include the pseudo-countries `A1` through `An` for
/// "anonymous proxies", since doing so would mean putting nearly all Tor relays
/// into one of those countries.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct CountryCode {
    /// The two characters of the code (printable ASCII, stored uppercase).
    ///
    /// The special value `??` never appears here, since it is not a country;
    /// `OptionCc` is the type to use when that value must be representable.
    ///
    /// Using `NonZeroU8` lets an `Option<CountryCode>` fit in 2 bytes, which
    /// helps with alignment and storage.
    inner: [NonZeroU8; 2],
}
100

            
101
impl CountryCode {
102
    /// Make a new `CountryCode`.
103
14070228
    fn new(cc_orig: &str) -> Result<Self, Error> {
104
        /// Try to convert an array of 2 bytes into an array of 2 nonzero bytes.
105
        #[inline]
106
14041266
        fn try_cvt_to_nz(inp: [u8; 2]) -> Result<[NonZeroU8; 2], TryFromIntError> {
107
14041266
            // I have confirmed that the asm here is reasonably efficient.
108
14041266
            Ok([inp[0].try_into()?, inp[1].try_into()?])
109
14041266
        }
110

            
111
14070228
        let cc = cc_orig.to_ascii_uppercase();
112

            
113
14070228
        let cc: [u8; 2] = cc
114
14070228
            .as_bytes()
115
14070228
            .try_into()
116
14070231
            .map_err(|_| Error::BadCountryCode(cc))?;
117

            
118
28726711
        if !cc.iter().all(|b| b.is_ascii() && !b.is_ascii_control()) {
119
6
            return Err(Error::BadCountryCode(cc_orig.to_owned()));
120
14070216
        }
121
14070216

            
122
14070216
        if &cc == b"??" {
123
28950
            return Err(Error::NowhereNotSupported);
124
14041266
        }
125
14041266

            
126
14041266
        Ok(Self {
127
14041266
            inner: try_cvt_to_nz(cc).map_err(|_| Error::BadCountryCode(cc_orig.to_owned()))?,
128
        })
129
14070228
    }
130

            
131
    /// Get the actual country code.
132
    ///
133
    /// This just calls `.as_ref()`.
134
    pub fn get(&self) -> &str {
135
        self.as_ref()
136
    }
137
}
138

            
139
impl Display for CountryCode {
140
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
141
        write!(f, "{}", self.as_ref())
142
    }
143
}
144

            
145
impl Debug for CountryCode {
146
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
147
        write!(f, "CountryCode(\"{}\")", self.as_ref())
148
    }
149
}
150

            
151
impl AsRef<str> for CountryCode {
152
100
    fn as_ref(&self) -> &str {
153
        /// Convert a reference to an array of 2 nonzero bytes to a reference to
154
        /// an array of 2 bytes.
155
        #[inline]
156
100
        fn cvt_ref(inp: &[NonZeroU8; 2]) -> &[u8; 2] {
157
100
            // SAFETY: Every NonZeroU8 has a layout and bit validity that is
158
100
            // also a valid u8.  The layout of arrays is also guaranteed.
159
100
            //
160
100
            // (We don't use try_into here because we need to return a str that
161
100
            // points to a reference to self.)
162
100
            let ptr = inp.as_ptr() as *const u8;
163
100
            let slice = unsafe { std::slice::from_raw_parts(ptr, inp.len()) };
164
100
            slice
165
100
                .try_into()
166
100
                .expect("the resulting slice should have the correct length!")
167
100
        }
168

            
169
        // This shouldn't ever panic, since we shouldn't feed non-utf8 country
170
        // codes in.
171
        //
172
        // In theory we could use from_utf8_unchecked, but that's probably not
173
        // needed.
174
100
        std::str::from_utf8(cvt_ref(&self.inner)).expect("invalid country code in CountryCode")
175
100
    }
176
}
177

            
178
impl FromStr for CountryCode {
179
    type Err = Error;
180

            
181
32
    fn from_str(s: &str) -> Result<Self, Self::Err> {
182
32
        CountryCode::new(s)
183
32
    }
184
}
185

            
186
/// Wrapper for an `Option<`[`CountryCode`]`>` that encodes `None` as `??`.
187
///
188
/// Used so that we can implement foreign traits.
189
#[derive(
190
    Copy, Clone, Debug, Eq, PartialEq, derive_more::Into, derive_more::From, derive_more::AsRef,
191
)]
192
#[allow(clippy::exhaustive_structs)]
193
pub struct OptionCc(pub Option<CountryCode>);
194

            
195
impl FromStr for OptionCc {
196
    type Err = Error;
197

            
198
14070196
    fn from_str(s: &str) -> Result<Self, Self::Err> {
199
14070196
        match CountryCode::new(s) {
200
28948
            Err(Error::NowhereNotSupported) => Ok(None.into()),
201
            Err(e) => Err(e),
202
14041248
            Ok(cc) => Ok(Some(cc).into()),
203
        }
204
14070196
    }
205
}
206

            
207
impl Display for OptionCc {
208
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
209
        match self.0 {
210
            Some(cc) => write!(f, "{}", cc),
211
            None => write!(f, "??"),
212
        }
213
    }
214
}
215

            
216
/// A country code / ASN definition.
217
///
218
/// Type lifted from `geoip-db-tool` in the C-tor source.
219
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
220
struct NetDefn {
221
    /// The country code.
222
    ///
223
    /// We translate the value "??" into None.
224
    cc: Option<CountryCode>,
225
    /// The ASN, if we have one. We translate the value "0" into None.
226
    asn: Option<NonZeroU32>,
227
}
228

            
229
impl NetDefn {
230
    /// Make a new `NetDefn`.
231
14070192
    fn new(cc: &str, asn: Option<u32>) -> Result<Self, Error> {
232
14070192
        let asn = NonZeroU32::new(asn.unwrap_or(0));
233
14070192
        let cc = cc.parse::<OptionCc>()?.into();
234
14070192

            
235
14070192
        Ok(Self { cc, asn })
236
14070192
    }
237

            
238
    /// Return the country code.
239
240
    fn country_code(&self) -> Option<&CountryCode> {
240
240
        self.cc.as_ref()
241
240
    }
242

            
243
    /// Return the ASN, if there is one.
244
    fn asn(&self) -> Option<u32> {
245
        self.asn.as_ref().map(|x| x.get())
246
    }
247
}
248

            
249
/// A database of IP addresses to country codes.
250
#[derive(Clone, Eq, PartialEq, Debug)]
251
pub struct GeoipDb {
252
    /// The IPv4 subset of the database, with v4 addresses stored as 32-bit integers.
253
    map_v4: RangeInclusiveMap<u32, NetDefn>,
254
    /// The IPv6 subset of the database, with v6 addresses stored as 128-bit integers.
255
    map_v6: RangeInclusiveMap<u128, NetDefn>,
256
}
257

            
258
impl GeoipDb {
259
    /// Make a new `GeoipDb` using a compiled-in copy of the GeoIP database.
260
    ///
261
    /// The returned instance of the database is shared with `Arc` across all invocations of this
262
    /// function in the same program.
263
    #[cfg(feature = "embedded-db")]
264
140
    pub fn new_embedded() -> Arc<Self> {
265
142
        Arc::clone(EMBEDDED_DB_PARSED.get_or_init(|| {
266
48
            Arc::new(
267
48
                // It's reasonable to assume the one we embedded is fine -- we'll test it in CI, etc.
268
48
                Self::new_from_legacy_format(EMBEDDED_DB_V4, EMBEDDED_DB_V6)
269
48
                    .expect("failed to parse embedded geoip database"),
270
48
            )
271
142
        }))
272
140
    }
273

            
274
    /// Make a new `GeoipDb` using provided copies of the v4 and v6 database, in Tor legacy format.
275
96
    pub fn new_from_legacy_format(db_v4: &str, db_v6: &str) -> Result<Self, Error> {
276
96
        let mut ret = GeoipDb {
277
96
            map_v4: Default::default(),
278
96
            map_v6: Default::default(),
279
96
        };
280

            
281
8079654
        for line in db_v4.lines() {
282
8079654
            if line.starts_with('#') {
283
816
                continue;
284
8078838
            }
285
8078838
            let line = line.trim();
286
8078838
            if line.is_empty() {
287
4
                continue;
288
8078834
            }
289
8078834
            let mut split = line.split(',');
290
8078834
            let from = split
291
8078834
                .next()
292
8078834
                .ok_or(Error::BadFormat("empty line somehow?"))?
293
8078834
                .parse::<u32>()?;
294
8078834
            let to = split
295
8078834
                .next()
296
8078834
                .ok_or(Error::BadFormat("line with insufficient commas"))?
297
8078834
                .parse::<u32>()?;
298
8078834
            let cc = split
299
8078834
                .next()
300
8078834
                .ok_or(Error::BadFormat("line with insufficient commas"))?;
301
8078834
            let asn = split.next().map(|x| x.parse::<u32>()).transpose()?;
302

            
303
8078834
            let defn = NetDefn::new(cc, asn)?;
304

            
305
8078834
            ret.map_v4.insert(from..=to, defn);
306
        }
307

            
308
        // This is slightly copypasta, but probably less readable to merge into one thing.
309
5992270
        for line in db_v6.lines() {
310
5992270
            if line.starts_with('#') {
311
816
                continue;
312
5991454
            }
313
5991454
            let line = line.trim();
314
5991454
            if line.is_empty() {
315
96
                continue;
316
5991358
            }
317
5991358
            let mut split = line.split(',');
318
5991358
            let from = split
319
5991358
                .next()
320
5991358
                .ok_or(Error::BadFormat("empty line somehow?"))?
321
5991358
                .parse::<Ipv6Addr>()?;
322
5991358
            let to = split
323
5991358
                .next()
324
5991358
                .ok_or(Error::BadFormat("line with insufficient commas"))?
325
5991358
                .parse::<Ipv6Addr>()?;
326
5991358
            let cc = split
327
5991358
                .next()
328
5991358
                .ok_or(Error::BadFormat("line with insufficient commas"))?;
329
5991358
            let asn = split.next().map(|x| x.parse::<u32>()).transpose()?;
330

            
331
5991358
            let defn = NetDefn::new(cc, asn)?;
332

            
333
5991358
            ret.map_v6.insert(from.into()..=to.into(), defn);
334
        }
335

            
336
96
        Ok(ret)
337
96
    }
338

            
339
    /// Get the `NetDefn` for an IP address.
340
3004
    fn lookup_defn(&self, ip: IpAddr) -> Option<&NetDefn> {
341
3004
        match ip {
342
2490
            IpAddr::V4(v4) => self.map_v4.get(&v4.into()),
343
514
            IpAddr::V6(v6) => self.map_v6.get(&v6.into()),
344
        }
345
3004
    }
346

            
347
    /// Get a 2-letter country code for the given IP address, if this data is available.
348
3004
    pub fn lookup_country_code(&self, ip: IpAddr) -> Option<&CountryCode> {
349
3014
        self.lookup_defn(ip).and_then(|x| x.country_code())
350
3004
    }
351

            
352
    /// Determine a 2-letter country code for a host with multiple IP addresses.
353
    ///
354
    /// This looks up all of the IP addresses with `lookup_country_code`. If the lookups
355
    /// return different countries, `None` is returned. IP addresses that fail to resolve
356
    /// into a country are ignored if some of the other addresses do resolve successfully.
357
710
    pub fn lookup_country_code_multi<I>(&self, ips: I) -> Option<&CountryCode>
358
710
    where
359
710
        I: IntoIterator<Item = IpAddr>,
360
710
    {
361
710
        let mut ret = None;
362

            
363
1698
        for ip in ips {
364
990
            if let Some(cc) = self.lookup_country_code(ip) {
365
                // If we already have a return value and it's different, then return None;
366
                // a server can't be in two different countries.
367
10
                if ret.is_some() && ret != Some(cc) {
368
2
                    return None;
369
8
                }
370
8

            
371
8
                ret = Some(cc);
372
980
            }
373
        }
374

            
375
708
        ret
376
710
    }
377

            
378
    /// Return the ASN the IP address is in, if this data is available.
379
    pub fn lookup_asn(&self, ip: IpAddr) -> Option<u32> {
380
        self.lookup_defn(ip)?.asn()
381
    }
382
}
383

            
384
/// A (representation of a) host on the network which may have a known country code.
385
pub trait HasCountryCode {
386
    /// Return the country code in which this server is most likely located.
387
    ///
388
    /// This is usually implemented by simple GeoIP lookup on the addresses provided by `HasAddrs`.
389
    /// It follows that the server might not actually be in the returned country, but this is a
390
    /// halfway decent estimate for what other servers might guess the server's location to be
391
    /// (and thus useful for e.g. getting around simple geo-blocks, or having webpages return
392
    /// the correct localised versions).
393
    ///
394
    /// Returning `None` signifies that no country code information is available. (Conflicting
395
    /// GeoIP lookup results might also cause `None` to be returned.)
396
    fn country_code(&self) -> Option<CountryCode>;
397
}
398

            
399
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use std::net::Ipv4Addr;

    // NOTE(eta): this test takes a whole 1.6 seconds in *non-release* mode
    #[test]
    #[cfg(feature = "embedded-db")]
    fn embedded_db() {
        let db = GeoipDb::new_embedded();

        // One well-known v4 address and one well-known v6 address.
        let v4: IpAddr = Ipv4Addr::new(8, 8, 8, 8).into();
        let v6: IpAddr = "2001:4860:4860::8888".parse().unwrap();

        assert_eq!(db.lookup_country_code(v4).map(|cc| cc.as_ref()), Some("US"));
        assert_eq!(db.lookup_country_code(v6).map(|cc| cc.as_ref()), Some("US"));
    }

    #[test]
    fn basic_lookups() {
        let src_v4 = r#"
        16909056,16909311,GB
        "#;
        let src_v6 = r#"
        fe80::,fe81::,US
        dead:beef::,dead:ffff::,??
        "#;
        let db = GeoipDb::new_from_legacy_format(src_v4, src_v6).unwrap();
        let parse_ip = |s: &str| -> IpAddr { s.parse().unwrap() };

        // Inside the single v4 range.
        assert_eq!(
            db.lookup_country_code(Ipv4Addr::new(1, 2, 3, 4).into())
                .map(|cc| cc.as_ref()),
            Some("GB")
        );
        // Outside every v4 range.
        assert_eq!(
            db.lookup_country_code(Ipv4Addr::new(1, 1, 1, 1).into()),
            None
        );

        // Inside the first v6 range.
        assert_eq!(
            db.lookup_country_code(parse_ip("fe80::dead:beef"))
                .map(|cc| cc.as_ref()),
            Some("US")
        );
        // Just past the end of the first v6 range.
        assert_eq!(db.lookup_country_code(parse_ip("fe81::dead:beef")), None);
        // Inside a range whose country is `??`.
        assert_eq!(db.lookup_country_code(parse_ip("dead:beef::1")), None);
    }

    #[test]
    fn cc_parse() -> Result<(), Error> {
        // Pairs that must parse to equal values: first two are real countries,
        // last two are not real as of this writing, but still representable.
        for (a, b) in [("us", "US"), ("UY", "UY"), ("A7", "a7"), ("xz", "xz")] {
            assert_eq!(CountryCode::from_str(a)?, CountryCode::from_str(b)?);
        }

        // The first three can't be converted to two bytes; the last three can,
        // but are still not printable ascii.
        for bad in ["z", "🐻‍❄️", "Sheboygan", "\r\n", "\0\0", "¡"] {
            assert!(
                matches!(CountryCode::from_str(bad), Err(Error::BadCountryCode(_))),
                "unexpected result for {:?}",
                bad
            );
        }

        // Not a country.
        assert!(matches!(
            CountryCode::from_str("??"),
            Err(Error::NowhereNotSupported)
        ));

        Ok(())
    }

    #[test]
    fn opt_cc_parse() -> Result<(), Error> {
        let OptionCc(parsed) = OptionCc::from_str("BR")?;
        assert_eq!(parsed, Some(CountryCode::from_str("br")?));

        let OptionCc(nowhere) = OptionCc::from_str("??")?;
        assert!(nowhere.is_none());

        Ok(())
    }
}