1
//! Implements address policies, based on a series of accept/reject
2
//! rules.
3

            
4
use std::fmt::Display;
5
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
6
use std::str::FromStr;
7

            
8
use super::{PolicyError, PortRange};
9

            
10
/// A sequence of rules that are applied to an address:port until one
11
/// matches.
12
///
13
/// Each rule is of the form "accept PATTERN" or "reject PATTERN",
14
/// where every pattern describes a set of addresses and ports.
15
/// Address sets are given as a prefix of 0-128 bits that the address
16
/// must have; port sets are given as a low-bound and high-bound that
17
/// the target port might lie between.
18
///
19
/// Relays use this type for defining their own policies, and for
20
/// publishing their IPv4 policies.  Clients instead use
21
/// [super::portpolicy::PortPolicy] objects to view a summary of the
22
/// relays' declared policies.
23
///
24
/// An example IPv4 policy might be:
25
///
26
/// ```ignore
27
///  reject *:25
28
///  reject 127.0.0.0/8:*
29
///  reject 192.168.0.0/16:*
30
///  accept *:80
31
///  accept *:443
32
///  accept *:9000-65535
33
///  reject *:*
34
/// ```
35
#[derive(Clone, Debug, Default)]
36
pub struct AddrPolicy {
37
    /// A list of rules to apply to find out whether an address is
38
    /// contained by this policy.
39
    ///
40
    /// The rules apply in order; the first one to match determines
41
    /// whether the address is accepted or rejected.
42
    rules: Vec<AddrPolicyRule>,
43
}
44

            
45
/// The verdict a policy rule gives: it either permits or forbids the
/// address:port combinations that its pattern matches.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
// A rule can only accept or reject, so this enum is deliberately exhaustive.
#[allow(clippy::exhaustive_enums)]
pub enum RuleKind {
    /// Matching address:port combinations are permitted.
    Accept,
    /// Matching address:port combinations are refused.
    Reject,
}
55

            
56
impl AddrPolicy {
57
    /// Apply this policy to an address:port combination
58
    ///
59
    /// We do this by applying each rule in sequence, until one
60
    /// matches.
61
    ///
62
    /// Returns None if no rule matches.
63
12
    pub fn allows(&self, addr: &IpAddr, port: u16) -> Option<RuleKind> {
64
12
        self.rules
65
12
            .iter()
66
32
            .find(|rule| rule.pattern.matches(addr, port))
67
17
            .map(|AddrPolicyRule { kind, .. }| *kind)
68
12
    }
69

            
70
    /// As allows, but accept a SocketAddr.
71
12
    pub fn allows_sockaddr(&self, addr: &SocketAddr) -> Option<RuleKind> {
72
12
        self.allows(&addr.ip(), addr.port())
73
12
    }
74

            
75
    /// Create a new AddrPolicy that matches nothing.
76
1774
    pub fn new() -> Self {
77
1774
        AddrPolicy::default()
78
1774
    }
79

            
80
    /// Add a new rule to this policy.
81
    ///
82
    /// The newly added rule is applied _after_ all previous rules.
83
    /// It matches all addresses and ports covered by AddrPortPattern.
84
    ///
85
    /// If accept is true, the rule is to accept addresses that match;
86
    /// if accept is false, the rule rejects such addresses.
87
1782
    pub fn push(&mut self, kind: RuleKind, pattern: AddrPortPattern) {
88
1782
        self.rules.push(AddrPolicyRule { kind, pattern });
89
1782
    }
90
}
91

            
92
/// A single rule in an address policy.
93
///
94
/// Contains a pattern and what to do with things that match it.
95
#[derive(Clone, Debug)]
96
struct AddrPolicyRule {
97
    /// What do we do with items that match the pattern?
98
    kind: RuleKind,
99
    /// What pattern are we trying to match?
100
    pattern: AddrPortPattern,
101
}
102

            
103
/*
104
impl Display for AddrPolicyRule {
105
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106
        let cmd = match self.kind {
107
            RuleKind::Accept => "accept",
108
            RuleKind::Reject => "reject",
109
        };
110
        write!(f, "{} {}", cmd, self.pattern)
111
    }
112
}
113
*/
114

            
115
/// A pattern that may or may not match an address and port.
116
///
117
/// Each AddrPortPattern has an IP pattern, which matches a set of
118
/// addresses by prefix, and a port pattern, which matches a range of
119
/// ports.
120
///
121
/// # Example
122
///
123
/// ```
124
/// use tor_netdoc::types::policy::AddrPortPattern;
125
/// use std::net::{IpAddr,Ipv4Addr};
126
/// let localhost = IpAddr::V4(Ipv4Addr::new(127,3,4,5));
127
/// let not_localhost = IpAddr::V4(Ipv4Addr::new(192,0,2,16));
128
/// let pat: AddrPortPattern = "127.0.0.0/8:*".parse().unwrap();
129
///
130
/// assert!(pat.matches(&localhost, 22));
131
/// assert!(! pat.matches(&not_localhost, 22));
132
/// ```
133
#[derive(
134
6
    Clone, Debug, Eq, PartialEq, serde_with::SerializeDisplay, serde_with::DeserializeFromStr,
135
)]
136
pub struct AddrPortPattern {
137
    /// A pattern to match somewhere between zero and all IP addresses.
138
    pattern: IpPattern,
139
    /// A pattern to match a range of ports.
140
    ports: PortRange,
141
}
142

            
143
impl AddrPortPattern {
144
    /// Return an AddrPortPattern matching all targets.
145
13524
    pub fn new_all() -> Self {
146
13524
        Self {
147
13524
            pattern: IpPattern::Star,
148
13524
            ports: PortRange::new_all(),
149
13524
        }
150
13524
    }
151

            
152
    /// Return true iff this pattern matches a given address and port.
153
10306
    pub fn matches(&self, addr: &IpAddr, port: u16) -> bool {
154
10306
        self.pattern.matches(addr) && self.ports.contains(port)
155
10306
    }
156
    /// As matches, but accept a SocketAddr.
157
10280
    pub fn matches_sockaddr(&self, addr: &SocketAddr) -> bool {
158
10280
        self.matches(&addr.ip(), addr.port())
159
10280
    }
160
}
161

            
162
impl Display for AddrPortPattern {
163
20
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164
20
        if self.ports.is_all() {
165
2
            write!(f, "{}:*", self.pattern)
166
        } else {
167
18
            write!(f, "{}:{}", self.pattern, self.ports)
168
        }
169
20
    }
170
}
171

            
172
impl FromStr for AddrPortPattern {
173
    type Err = PolicyError;
174
2250
    fn from_str(s: &str) -> Result<Self, PolicyError> {
175
2250
        let last_colon = s.rfind(':').ok_or(PolicyError::InvalidPolicy)?;
176
2250
        let pattern: IpPattern = s[..last_colon].parse()?;
177
2242
        let ports_s = &s[last_colon + 1..];
178
2242
        let ports: PortRange = if ports_s == "*" {
179
1988
            PortRange::new_all()
180
        } else {
181
254
            ports_s.parse()?
182
        };
183

            
184
2240
        Ok(AddrPortPattern { pattern, ports })
185
2250
    }
186
}
187

            
188
/// A pattern selecting one or more IP addresses.
//
// TODO(nickm): Display and FromStr currently cannot tell V4Star, V6Star, and
// Star apart: Display writes all three as "*", and FromStr reads "*" only as
// Star.  If we decide we need a syntax for "every IPv4 address" other than
// "0.0.0.0/0", we'll have to revisit this.  (C tor accepts '*', '*4', and
// '*6'.)
#[derive(Clone, Debug, Eq, PartialEq)]
enum IpPattern {
    /// Every address, IPv4 or IPv6.
    Star,
    /// Every IPv4 address.
    V4Star,
    /// Every IPv6 address.
    V6Star,
    /// Every IPv4 address sharing its first `u8` bits with the given address.
    V4(Ipv4Addr, u8),
    /// Every IPv6 address sharing its first `u8` bits with the given address.
    V6(Ipv6Addr, u8),
}
207

            
208
impl IpPattern {
209
    /// Construct an IpPattern that matches the first `mask` bits of `addr`.
210
332
    fn from_addr_and_mask(addr: IpAddr, mask: u8) -> Result<Self, PolicyError> {
211
332
        match (addr, mask) {
212
4
            (IpAddr::V4(_), 0) => Ok(IpPattern::V4Star),
213
2
            (IpAddr::V6(_), 0) => Ok(IpPattern::V6Star),
214
314
            (IpAddr::V4(a), m) if m <= 32 => Ok(IpPattern::V4(a, m)),
215
12
            (IpAddr::V6(a), m) if m <= 128 => Ok(IpPattern::V6(a, m)),
216
4
            (_, _) => Err(PolicyError::InvalidMask),
217
        }
218
332
    }
219
    /// Return true iff `addr` is matched by this pattern.
220
10306
    fn matches(&self, addr: &IpAddr) -> bool {
221
10306
        match (self, addr) {
222
446
            (IpPattern::Star, _) => true,
223
2
            (IpPattern::V4Star, IpAddr::V4(_)) => true,
224
2
            (IpPattern::V6Star, IpAddr::V6(_)) => true,
225
9838
            (IpPattern::V4(pat, mask), IpAddr::V4(addr)) => {
226
9838
                let p1 = u32::from_be_bytes(pat.octets());
227
9838
                let p2 = u32::from_be_bytes(addr.octets());
228
9838
                let shift = 32 - mask;
229
9838
                (p1 >> shift) == (p2 >> shift)
230
            }
231
10
            (IpPattern::V6(pat, mask), IpAddr::V6(addr)) => {
232
10
                let p1 = u128::from_be_bytes(pat.octets());
233
10
                let p2 = u128::from_be_bytes(addr.octets());
234
10
                let shift = 128 - mask;
235
10
                (p1 >> shift) == (p2 >> shift)
236
            }
237
8
            (_, _) => false,
238
        }
239
10306
    }
240
}
241

            
242
impl Display for IpPattern {
243
20
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244
        use IpPattern::*;
245
20
        match self {
246
6
            Star | V4Star | V6Star => write!(f, "*"),
247
4
            V4(a, 32) => write!(f, "{}", a),
248
4
            V4(a, m) => write!(f, "{}/{}", a, m),
249
4
            V6(a, 128) => write!(f, "[{}]", a),
250
2
            V6(a, m) => write!(f, "[{}]/{}", a, m),
251
        }
252
20
    }
253
}
254

            
255
/// Helper: try to parse a plain ipv4 address, or an IPv6 address
256
/// wrapped in brackets.
257
336
fn parse_addr(mut s: &str) -> Result<IpAddr, PolicyError> {
258
336
    let bracketed = s.starts_with('[') && s.ends_with(']');
259
336
    if bracketed {
260
16
        s = &s[1..s.len() - 1];
261
320
    }
262
337
    let addr: IpAddr = s.parse().map_err(|_| PolicyError::InvalidAddress)?;
263
334
    if addr.is_ipv6() != bracketed {
264
2
        return Err(PolicyError::InvalidAddress);
265
332
    }
266
332
    Ok(addr)
267
336
}
268

            
269
impl FromStr for IpPattern {
270
    type Err = PolicyError;
271
2250
    fn from_str(s: &str) -> Result<Self, PolicyError> {
272
2250
        let (ip_s, mask_s) = match s.find('/') {
273
328
            Some(slash_idx) => (&s[..slash_idx], Some(&s[slash_idx + 1..])),
274
1922
            None => (s, None),
275
        };
276
2250
        match (ip_s, mask_s) {
277
2250
            ("*", Some(_)) => Err(PolicyError::MaskWithStar),
278
1914
            ("*", None) => Ok(IpPattern::Star),
279
328
            (s, Some(m)) => {
280
328
                let a: IpAddr = parse_addr(s)?;
281
326
                let m: u8 = m.parse().map_err(|_| PolicyError::InvalidMask)?;
282
326
                IpPattern::from_addr_and_mask(a, m)
283
            }
284
8
            (s, None) => {
285
8
                let a: IpAddr = parse_addr(s)?;
286
6
                let m = if a.is_ipv4() { 32 } else { 128 };
287
6
                IpPattern::from_addr_and_mask(a, m)
288
            }
289
        }
290
2250
    }
291
}
292

            
293
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    use std::net::SocketAddr;

    /// Parsing a pattern and displaying it again yields its canonical form.
    #[test]
    fn test_roundtrip_rules() {
        let cases = [
            ("127.0.0.2/32:77-10000", "127.0.0.2:77-10000"),
            ("127.0.0.2/32:*", "127.0.0.2:*"),
            ("127.0.0.0/16:9-100", "127.0.0.0/16:9-100"),
            ("127.0.0.0/0:443", "*:443"),
            ("*:443", "*:443"),
            ("[::1]:443", "[::1]:443"),
            ("[ffaa::]/16:80", "[ffaa::]/16:80"),
            ("[ffaa::77]/128:80", "[ffaa::77]:80"),
        ];
        for &(input, canonical) in &cases {
            let pattern: AddrPortPattern = input.parse().unwrap();
            assert_eq!(pattern.to_string(), canonical);
        }
    }

    /// Malformed patterns are rejected.
    #[test]
    fn test_bad_rules() {
        let malformed = [
            "marzipan:80",
            "1.2.3.4:90-80",
            "1.2.3.4/100:8888",
            "[1.2.3.4]/16:80",
            "[::1]/130:8888",
        ];
        for text in &malformed {
            assert!(
                text.parse::<AddrPortPattern>().is_err(),
                "{:?} should not parse",
                text
            );
        }
    }

    /// Each pattern matches exactly the socket addresses it should.
    #[test]
    fn test_rule_matches() {
        fn check(pattern: &str, yes: &[&str], no: &[&str]) {
            let pat: AddrPortPattern = pattern.parse().unwrap();
            for (targets, expected) in [(yes, true), (no, false)] {
                for target in targets {
                    let sa: SocketAddr = target.parse().unwrap();
                    assert_eq!(
                        pat.matches_sockaddr(&sa),
                        expected,
                        "{} vs {}",
                        pattern,
                        target
                    );
                }
            }
        }

        check(
            "1.2.3.4/16:80",
            &["1.2.3.4:80", "1.2.44.55:80"],
            &["9.9.9.9:80", "1.3.3.4:80", "1.2.3.4:81"],
        );
        check(
            "*:443-8000",
            &["1.2.3.4:443", "[::1]:500"],
            &["9.0.0.0:80", "[::1]:80"],
        );
        check(
            "[face::]/8:80",
            &["[fab0::7]:80"],
            &["[dd00::]:80", "[face::7]:443"],
        );

        check("0.0.0.0/0:*", &["127.0.0.1:80"], &["[f00b::]:80"]);
        check("[::]/0:*", &["[f00b::]:80"], &["127.0.0.1:80"]);
    }

    /// The first matching rule decides; no match yields `None`.
    #[test]
    fn test_policy_matches() -> Result<(), PolicyError> {
        let policy = {
            let mut p = AddrPolicy::default();
            p.push(RuleKind::Accept, "*:443".parse()?);
            p.push(RuleKind::Accept, "[::1]:80".parse()?);
            p.push(RuleKind::Reject, "*:80".parse()?);
            p
        };

        let expectations = [
            ("[::6]:443", Some(RuleKind::Accept)),
            ("127.0.0.1:443", Some(RuleKind::Accept)),
            ("[::1]:80", Some(RuleKind::Accept)),
            ("[::2]:80", Some(RuleKind::Reject)),
            ("127.0.0.1:80", Some(RuleKind::Reject)),
            ("127.0.0.1:66", None),
        ];
        for &(target, want) in &expectations {
            let sa: SocketAddr = target.parse().unwrap();
            assert_eq!(policy.allows_sockaddr(&sa), want, "{}", target);
        }
        Ok(())
    }

    /// Patterns survive a serde round trip and decode from their string form.
    #[test]
    fn serde() {
        #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Eq, PartialEq)]
        struct X {
            p1: AddrPortPattern,
            p2: AddrPortPattern,
        }

        let original = X {
            p1: "127.0.0.1/8:9-10".parse().unwrap(),
            p2: "*:80".parse().unwrap(),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let from_roundtrip: X = serde_json::from_str(&encoded).unwrap();
        let from_literal: X =
            serde_json::from_str(r#"{"p1":"127.0.0.1/8:9-10","p2":"*:80"}"#).unwrap();
        assert_eq!(&from_roundtrip, &from_literal);
        assert_eq!(&from_roundtrip, &original);
    }
}