//! `ByteQty`: a quantity of memory used, measured in bytes.
//
// The closest crate to this on crates.io is `bytesize`.
// But it has serious bugs, including confusion about KiB vs KB,
// and isn't maintained.
//
// There is also `humansize`, but that only does printing.
8

            
9
#![allow(clippy::comparison_to_empty)] // unit == "" etc. is much clearer
10

            
11
use derive_more::{Deref, DerefMut, From, Into};
12
use itertools::Itertools;
13
use thiserror::Error;
14

            
15
#[cfg(feature = "serde")]
16
use serde::{Deserialize, Serialize};
17

            
18
use std::fmt::{self, Display};
19
use std::str::FromStr;
20

            
21
use InvalidByteQty as IBQ;
22

            
23
/// Quantity of memory used, measured in bytes.
24
///
25
/// Like `usize` but `FromStr` and `Display`s in a more friendly and less precise way
26
///
27
/// Parses from (with or without the internal space):
28
///  * `<amount>` (implicitly, bytes)
29
///  * `<amount> B`
30
///  * `<amount> KiB`/`MiB`/`GiB`/`TiB` (binary, 1024-based units)
31
///  * `<amount> KB`/`MB`/`GB`/`TB` (decimal, 1000-based units)
32
///
33
/// Displays to approximately 3 significant figures,
34
/// preferring binary (1024-based) multipliers.
35
/// (There is no facility for adjusting the format.)
36
#[derive(Debug, Clone, Copy, Hash, Default, Eq, PartialEq, Ord, PartialOrd)] //
37
#[derive(From, Into, Deref, DerefMut)]
38
#[cfg_attr(
39
    feature = "serde",
40
46
    derive(Serialize, Deserialize),
41
    serde(into = "usize", try_from = "ByteQtySerde")
42
)]
43
#[allow(clippy::exhaustive_structs)] // this is a behavioural newtype wrapper
44
pub struct ByteQty(pub usize);
45

            
46
/// Error parsing (or deserialising) a [`ByteQty`]
47
#[derive(Error, Copy, Clone, Debug, Eq, PartialEq, Hash)]
48
pub enum InvalidByteQty {
49
    /// Value bigger than `usize::MAX`
50
    #[error(
51
        "size/quantity outside range supported on this system (max is {} B)",
52
        usize::MAX
53
    )]
54
    Overflow,
55
    /// Unknown unit
56
    #[error(
57
        "size/quantity specified unknown unit; supported are {}",
58
        SupportedUnits
59
    )]
60
    UnknownUnit,
61
    /// Unknown unit, probably because the B at the end was missing
62
    ///
63
    /// We insist on the `B` so that all our units end in `B` or `iB`.
64
    #[error(
65
        "size/quantity specified unknown unit - we require the `B`; supported units are {}",
66
        SupportedUnits
67
    )]
68
    UnknownUnitMissingB,
69
    /// Bad syntax
70
    #[error("size/quantity specified string in bad syntax")]
71
    BadSyntax,
72
    /// Negative value
73
    #[error("size/quantity cannot be negative")]
74
    Negative,
75
    /// NaN
76
    #[error("size/quantity cannot be obtained from a floating point NaN")]
77
    NaN,
78
    /// BadValue
79
    #[error("bad type for size/quantity (only numbers, and strings to parse, are supported)")]
80
    BadValue,
81
}
82

            
83
//---------- units (definitions) ----------
84

            
85
/// Binary (1024-based) units used when displaying a [`ByteQty`], smallest first
///
/// `Display` uses the first entry that yields a small enough mantissa
/// (falling back to the last), so these must stay in increasing order.
/// These units are also accepted when parsing.
const DISPLAY_UNITS: &[(&str, u64)] = &[
    ("B", 1),
    ("KiB", 1024),
    ("MiB", 1024 * 1024),
    ("GiB", 1024 * 1024 * 1024),
    ("TiB", 1024 * 1024 * 1024 * 1024),
];
93

            
94
/// Units recognised only when parsing a [`ByteQty`], never used for display
///
/// The decimal (1000-based) units, plus the empty unit, which means a bare
/// number of bytes.
const PARSE_UNITS: &[(&str, u64)] = &[
    ("", 1),
    ("KB", 1_000),
    ("MB", 1_000_000),
    ("GB", 1_000_000_000),
    ("TB", 1_000_000_000_000),
];
102

            
103
/// Units that are used when parsing *and* when printing
104
const ALL_UNITS: &[&[(&str, u64)]] = &[
105
    //
106
    DISPLAY_UNITS,
107
    PARSE_UNITS,
108
];
109

            
110
//---------- inherent methods ----------
111

            
112
impl ByteQty {
113
    /// Maximum for the type
114
    pub const MAX: ByteQty = ByteQty(usize::MAX);
115

            
116
    /// Return the value as a plain number, a `usize`
117
    ///
118
    /// Provided so call sites don't need to write an opaque `.0` everywhere,
119
    /// even though that would be fine.
120
320
    pub const fn as_usize(self) -> usize {
121
320
        self.0
122
320
    }
123
}
124

            
125
//---------- printing ----------
126

            
127
impl Display for ByteQty {
128
11052
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
129
11052
        let v = self.0 as f64;
130
11052

            
131
11052
        // Find the first entry which is big enough that the mantissa will be <999.5,
132
11052
        // ie where it won't print as 4 decimal digits after the point.
133
11052
        // Or, if that doesn't work, we'll use the last entry which is the largest.
134
11052

            
135
11052
        let (unit, mantissa) = DISPLAY_UNITS
136
11052
            .iter()
137
11052
            .copied()
138
31220
            .filter(|(unit, _)| *unit != "")
139
31220
            .map(|(unit, multiplier)| (unit, v / multiplier as f64))
140
31220
            .find_or_last(|(_, mantissa)| *mantissa < 999.5)
141
11052
            .expect("DISPLAY_UNITS Is empty?!");
142

            
143
        // Select a precision so that we'll print about 3 significant figures.
144
        // We can't do this precisely, so we err on the side of slightly
145
        // fewer SF with mantissae starting with 9.
146

            
147
11052
        let after_decimal = if mantissa < 9. {
148
1288
            2
149
9764
        } else if mantissa < 99. {
150
9604
            1
151
        } else {
152
160
            0
153
        };
154

            
155
11052
        write!(f, "{mantissa:.*} {unit}", after_decimal)
156
11052
    }
157
}
158

            
159
//---------- incoming conversions ----------
160

            
161
// We don't provide Into<u64> or Into<f64> because they're actually quite faffsome
162
// due to all the corner cases.  We only provide these two, because we need them
163
// ourselves for parsing and deserialisation.
164

            
165
impl TryFrom<u64> for ByteQty {
166
    type Error = InvalidByteQty;
167
664
    fn try_from(v: u64) -> Result<ByteQty, IBQ> {
168
664
        let v = v.try_into().map_err(|_| IBQ::Overflow)?;
169
664
        Ok(ByteQty(v))
170
664
    }
171
}
172

            
173
impl TryFrom<f64> for ByteQty {
174
    type Error = InvalidByteQty;
175
30
    fn try_from(f: f64) -> Result<ByteQty, IBQ> {
176
30
        if f.is_nan() {
177
2
            Err(IBQ::NaN)
178
28
        } else if f > (usize::MAX as f64) {
179
4
            Err(IBQ::Overflow)
180
24
        } else if f >= 0. {
181
22
            Ok(ByteQty(f as usize))
182
        } else {
183
2
            Err(IBQ::Negative)
184
        }
185
30
    }
186
}
187

            
188
/// Helper for deserializing [`ByteQty`]
///
/// Deserialised as `untagged`, so serde tries the variants in declaration
/// order and uses the first one that matches the input.
#[cfg(feature = "serde")]
#[derive(Deserialize)]
#[serde(untagged)]
enum ByteQtySerde {
    /// An unsigned integer: an exact number of bytes
    U(u64),
    /// A string, to be parsed with `ByteQty`'s `FromStr`
    S(String),
    /// A floating point number of bytes
    F(f64),
    /// Anything else; always rejected, with `InvalidByteQty::BadValue`
    Bad(serde::de::IgnoredAny),
}
202
#[cfg(feature = "serde")]
impl TryFrom<ByteQtySerde> for ByteQty {
    type Error = InvalidByteQty;

    /// Turn whatever serde produced into a `ByteQty`
    ///
    /// Numbers go through the ordinary numeric conversions, strings through
    /// the `FromStr` parser; any other type is `BadValue`.
    fn try_from(qs: ByteQtySerde) -> Result<ByteQty, IBQ> {
        use ByteQtySerde as Q;

        match qs {
            Q::U(bytes) => ByteQty::try_from(bytes),
            Q::S(text) => text.parse::<ByteQty>(),
            Q::F(float) => ByteQty::try_from(float),
            Q::Bad(_) => Err(IBQ::BadValue),
        }
    }
}
214

            
215
//---------- FromStr ----------
216

            
217
impl FromStr for ByteQty {
218
    type Err = InvalidByteQty;
219
678
    fn from_str(s: &str) -> Result<Self, IBQ> {
220
678
        let s = s.trim();
221

            
222
678
        let last_digit = s
223
3377
            .rfind(|c: char| c.is_ascii_digit())
224
678
            .ok_or(IBQ::BadSyntax)?;
225

            
226
        // last_digit points to an ASCII digit so +1 is right to skip it
227
676
        let (mantissa, unit) = s.split_at(last_digit + 1);
228
676

            
229
676
        let unit = unit.trim_start(); // remove any whitespace in the middle
230
676

            
231
676
        // defer unknown unit errors until we've done the rest of the parsing
232
676
        let multiplier: Result<u64, _> = ALL_UNITS
233
676
            .iter()
234
676
            .copied()
235
676
            .flatten()
236
2296
            .find(|(s, _)| *s == unit)
237
699
            .map(|(_, m)| *m)
238
679
            .ok_or_else(|| {
239
6
                if unit.ends_with('B') {
240
2
                    IBQ::UnknownUnit
241
                } else {
242
4
                    IBQ::UnknownUnitMissingB
243
                }
244
679
            });
245

            
246
        // We try this via u64 (so we give byte-precise answers if possible)
247
        // and via f64 (so we can support fractions).
248
        //
249
        // (Byte-precise amounts aren't important here in tor-memquota,
250
        // but this code seems like it may end up elsewhere.)
251
676
        if let Ok::<u64, _>(mantissa) = mantissa.parse() {
252
658
            let multiplier = multiplier?;
253
656
            (|| {
254
656
                mantissa
255
656
                    .checked_mul(multiplier)? //
256
656
                    .try_into()
257
656
                    .ok()
258
656
            })()
259
656
            .ok_or(IBQ::Overflow)
260
18
        } else if let Ok::<f64, _>(mantissa) = mantissa.parse() {
261
14
            let value = mantissa * (multiplier? as f64);
262
12
            value.try_into()
263
        } else {
264
4
            Err(IBQ::BadSyntax)
265
        }
266
678
    }
267
}
268

            
269
/// Helper to format the list of supported units into `IBQ::UnknownUnit`
270
struct SupportedUnits;
271

            
272
impl Display for SupportedUnits {
273
    #[allow(unstable_name_collisions)] // Itertools::intersperse vs std's;  rust-lang/rust#48919
274
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
275
        for s in ALL_UNITS
276
            .iter()
277
            .copied()
278
            .flatten()
279
            .copied()
280
            .map(|(unit, _multiplier)| unit)
281
            .filter(|unit| !unit.is_empty())
282
            .intersperse("/")
283
        {
284
            Display::fmt(s, f)?;
285
        }
286
        Ok(())
287
    }
288
}
289

            
290
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;

    /// Display picks the expected unit and precision, and the displayed
    /// text parses back to a value that displays identically.
    #[test]
    fn display_qty() {
        let cases: [(usize, &str); 3] = [
            (10 * 1024, "10.0 KiB"),
            (1024 * 1024, "1.00 MiB"),
            (1000 * 1024 * 1024, "0.98 GiB"),
        ];

        for (bytes, text) in cases {
            assert_eq!(ByteQty(bytes).to_string(), text, "{text:?}");
            let reparsed = text.parse::<ByteQty>().expect(text);
            assert_eq!(reparsed.to_string(), text, "{text:?}");
        }
    }

    /// Parsing accepts the documented forms, and rejects bad input with
    /// the appropriate error.
    #[test]
    fn parse_qty() {
        let accepted: [(&str, usize); 7] = [
            ("1", 1),
            ("1B", 1),
            ("1KB", 1000),
            ("1 KB", 1000),
            ("1 KiB", 1024),
            ("1.0 KiB", 1024),
            (".00195312499909050529 TiB", 2147483647),
        ];
        for (text, bytes) in accepted {
            assert_eq!(text.parse::<ByteQty>(), Ok(ByteQty(bytes)), "{text:?}");
        }

        let rejected: [(&str, IBQ); 6] = [
            ("1 2 K", IBQ::BadSyntax),
            ("1.2 K", IBQ::UnknownUnitMissingB),
            ("no digits", IBQ::BadSyntax),
            ("1 2 KB", IBQ::BadSyntax),
            ("1 mB", IBQ::UnknownUnit),
            ("1.0e100 TiB", IBQ::Overflow),
        ];
        for (text, error) in rejected {
            assert_eq!(text.parse::<ByteQty>(), Err(error), "{text:?}");
        }
    }

    /// Fallible numeric conversions from `f64` and `u64`.
    #[test]
    fn convert() {
        fn expect_result<T>(input: T, expected: Result<ByteQty, IBQ>)
        where
            T: TryInto<ByteQty, Error = IBQ>,
        {
            assert_eq!(input.try_into(), expected);
        }
        fn expect_bytes<T>(input: T, bytes: usize)
        where
            T: TryInto<ByteQty, Error = IBQ>,
        {
            expect_result(input, Ok(ByteQty(bytes)));
        }

        // f64
        expect_bytes(0.0_f64, 0);
        expect_bytes(1.0_f64, 1);
        expect_bytes(f64::from(u32::MAX), u32::MAX as usize);
        expect_bytes(-0.0_f64, 0);

        expect_result(-0.01_f64, Err(IBQ::Negative));
        expect_result(1.0e100_f64, Err(IBQ::Overflow));
        expect_result(f64::NAN, Err(IBQ::NaN));

        // u64
        expect_bytes(0_u64, 0);
        expect_bytes(u64::from(u32::MAX), u32::MAX as usize);
        // we can't easily test the u64 overflow case without getting arch-specific
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_deser() {
        // Use serde_value so we can try all the exciting things in the serde model
        use serde_value::Value as SV;

        let check = |sv: SV, expected: Result<ByteQty, IBQ>| {
            let got: Result<ByteQty, String> =
                sv.clone().deserialize_into().map_err(|e| e.to_string());
            assert_eq!(got, expected.map_err(|e| e.to_string()), "{sv:?}");
        };

        let accepted = [
            (SV::U8(1), 1),
            (SV::String("1".to_owned()), 1),
            (SV::String("1 KiB".to_owned()), 1024),
            (SV::I32(i32::MAX), i32::MAX as usize),
            (SV::F32(1.0), 1),
            (SV::F64(f64::from(u32::MAX)), u32::MAX as usize),
            (SV::Bytes("1".to_string().into()), 1),
        ];
        for (sv, bytes) in accepted {
            check(sv, Ok(ByteQty(bytes)));
        }

        let rejected = [
            SV::Bool(false),
            SV::Char('1'),
            SV::Unit,
            SV::Option(None),
            SV::Option(Some(Box::new(SV::String("1".to_owned())))),
            SV::Newtype(Box::new(SV::String("1".to_owned()))),
            SV::Seq(vec![]),
            SV::Map(Default::default()),
        ];
        for sv in rejected {
            check(sv, Err(IBQ::BadValue));
        }
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_ser() {
        // Use serde_json so we don't have to worry about how precisely
        // serde decides to encode a usize (eg is it u32 or u64 or what).
        let encoded = serde_json::to_value(ByteQty(1)).unwrap();
        assert_eq!(encoded, serde_json::json!(1));
    }
}