1
//! Code for turning safelogging on and off.
2
//!
3
//! By default, safelogging is on.  There are two ways to turn it off: Globally
4
//! (with [`disable_safe_logging`]) and locally (with
5
//! [`with_safe_logging_suppressed`]).
6

            
7
use crate::{Error, Result};
8
use fluid_let::fluid_let;
9
use std::sync::atomic::{AtomicIsize, Ordering};
10

            
11
/// A global atomic used to track locking guards for enabling and disabling
12
/// safe-logging.
13
///
14
/// The value of this atomic is less than 0 if we have enabled unsafe logging.
15
/// greater than 0 if we have enabled safe logging, and 0 if nobody cares.
16
static LOGGING_STATE: AtomicIsize = AtomicIsize::new(0);
17

            
18
fluid_let!(
19
    /// A dynamic variable used to temporarily disable safe-logging.
20
    static SAFE_LOGGING_SUPPRESSED_IN_THREAD: bool
21
);
22

            
23
/// Returns true if we are displaying sensitive values, false otherwise.
24
120851
pub(crate) fn unsafe_logging_enabled() -> bool {
25
120851
    LOGGING_STATE.load(Ordering::Relaxed) < 0
26
105810
        || SAFE_LOGGING_SUPPRESSED_IN_THREAD.get(|v| v == Some(&true))
27
120851
}
28

            
29
/// Run a given function with the regular `safelog` functionality suppressed.
30
///
31
/// The provided function, and everything it calls, will display
32
/// [`Sensitive`](crate::Sensitive) values as if they were not sensitive.
33
///
34
/// # Examples
35
///
36
/// ```
37
/// use safelog::{Sensitive, with_safe_logging_suppressed};
38
///
39
/// let string = Sensitive::new("swordfish");
40
///
41
/// // Ordinarily, the string isn't displayed as normal
42
/// assert_eq!(format!("The value is {}", string),
43
///            "The value is [scrubbed]");
44
///
45
/// // But you can override that:
46
/// assert_eq!(
47
///     with_safe_logging_suppressed(|| format!("The value is {}", string)),
48
///     "The value is swordfish"
49
/// );
50
/// ```
51
20036
pub fn with_safe_logging_suppressed<F, V>(func: F) -> V
52
20036
where
53
20036
    F: FnOnce() -> V,
54
20036
{
55
20036
    // This sets the value of the variable to Some(true) temporarily, for as
56
20036
    // long as `func` is being called.  It uses thread-local variables
57
20036
    // internally.
58
20036
    SAFE_LOGGING_SUPPRESSED_IN_THREAD.set(true, func)
59
20036
}
60

            
61
/// Enum to describe what kind of a [`Guard`] we've created.
62
#[derive(Debug, Copy, Clone)]
63
enum GuardKind {
64
    /// We are forcing safe-logging to be enabled, so that nobody
65
    /// can turn it off with `disable_safe_logging`
66
    Safe,
67
    /// We have are turning safe-logging off with `disable_safe_logging`.
68
    Unsafe,
69
}
70

            
71
/// A guard object used to enforce safe logging, or turn it off.
72
///
73
/// For as long as this object exists, the chosen behavior will be enforced.
74
//
75
// TODO: Should there be different types for "keep safe logging on" and "turn
76
// safe logging off"?  Having the same type makes it easier to write code that
77
// does stuff like this:
78
//
79
//     let g = if cfg.safe {
80
//         enforce_safe_logging()
81
//     } else {
82
//         disable_safe_logging()
83
//     };
84
#[derive(Debug)]
85
#[must_use = "If you drop the guard immediately, it won't do anything."]
86
pub struct Guard {
87
    /// What kind of guard is this?
88
    kind: GuardKind,
89
}
90

            
91
impl GuardKind {
92
    /// Return an error if `val` (as a value of `LOGGING_STATE`) indicates that
93
    /// intended kind of guard cannot be created.
94
100328
    fn check(&self, val: isize) -> Result<()> {
95
100328
        match self {
96
            GuardKind::Safe => {
97
30384
                if val < 0 {
98
20272
                    return Err(Error::AlreadyUnsafe);
99
10112
                }
100
            }
101
            GuardKind::Unsafe => {
102
69944
                if val > 0 {
103
19732
                    return Err(Error::AlreadySafe);
104
50212
                }
105
            }
106
        }
107
60324
        Ok(())
108
100328
    }
109
    /// Return the value by which `LOGGING_STATE` should change while a guard of
110
    /// this type exists.
111
159228
    fn increment(&self) -> isize {
112
159228
        match self {
113
39612
            GuardKind::Safe => 1,
114
119616
            GuardKind::Unsafe => -1,
115
        }
116
159228
    }
117
}
118

            
119
impl Guard {
120
    /// Helper: Create a guard of a given kind.
121
99616
    fn new(kind: GuardKind) -> Result<Self> {
122
99616
        let inc = kind.increment();
123
        loop {
124
            // Find the current value of LOGGING_STATE and see if this guard can
125
            // be created.
126
100328
            let old_val = LOGGING_STATE.load(Ordering::SeqCst);
127
100328
            // Exit if this guard can't be created.
128
100328
            kind.check(old_val)?;
129
            // Otherwise, try changing LOGGING_STATE to the new value that it
130
            // _should_ have when this guard exists.
131
60324
            let new_val = match old_val.checked_add(inc) {
132
60324
                Some(v) => v,
133
                None => return Err(Error::Overflow),
134
            };
135
59612
            if let Ok(v) =
136
60324
                LOGGING_STATE.compare_exchange(old_val, new_val, Ordering::SeqCst, Ordering::SeqCst)
137
            {
138
                // Great, we set the value to what it should be; we're done.
139
59612
                debug_assert_eq!(v, old_val);
140
59612
                return Ok(Self { kind });
141
712
            }
142
            // Otherwise, somebody else altered this value concurrently: try
143
            // again.
144
        }
145
99616
    }
146
}
147

            
148
impl Drop for Guard {
149
59612
    fn drop(&mut self) {
150
59612
        let inc = self.kind.increment();
151
59612
        LOGGING_STATE.fetch_sub(inc, Ordering::SeqCst);
152
59612
    }
153
}
154

            
155
/// Create a new [`Guard`] to prevent anyone else from disabling safe logging.
156
///
157
/// Until the resulting `Guard` is dropped, any attempts to call
158
/// `disable_safe_logging` will give an error.  This guard does _not_ affect
159
/// calls to [`with_safe_logging_suppressed`].
160
///
161
/// This call will return an error if safe logging is _already_ disabled.
162
///
163
/// Note that this function is called "enforce", not "enable", since safe
164
/// logging is enabled by default.  Its purpose is to make sure that nothing
165
/// _else_ has called disable_safe_logging().
166
29942
pub fn enforce_safe_logging() -> Result<Guard> {
167
29942
    Guard::new(GuardKind::Safe)
168
29942
}
169

            
170
/// Create a new [`Guard`] to disable safe logging.
171
///
172
/// Until the resulting `Guard` is dropped, all [`Sensitive`](crate::Sensitive)
173
/// values will be displayed as if they were not sensitive.
174
///
175
/// This call will return an error if safe logging has been enforced with
176
/// [`enforce_safe_logging`].
177
69674
pub fn disable_safe_logging() -> Result<Guard> {
178
69674
    Guard::new(GuardKind::Unsafe)
179
69674
}
180

            
181
#[cfg(test)]
182
mod test {
183
    // @@ begin test lint list maintained by maint/add_warning @@
184
    #![allow(clippy::bool_assert_comparison)]
185
    #![allow(clippy::clone_on_copy)]
186
    #![allow(clippy::dbg_macro)]
187
    #![allow(clippy::mixed_attributes_style)]
188
    #![allow(clippy::print_stderr)]
189
    #![allow(clippy::print_stdout)]
190
    #![allow(clippy::single_char_pattern)]
191
    #![allow(clippy::unwrap_used)]
192
    #![allow(clippy::unchecked_duration_subtraction)]
193
    #![allow(clippy::useless_vec)]
194
    #![allow(clippy::needless_pass_by_value)]
195
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
196
    use super::*;
197
    // We use "serial_test" to make sure that our tests here run one at a time,
198
    // since they modify global state.
199
    use serial_test::serial;
200

            
201
    #[test]
202
    #[serial]
203
    fn guards() {
204
        // Try operations with logging guards turned on and off, in a single
205
        // thread.
206
        assert!(!unsafe_logging_enabled());
207
        let g1 = enforce_safe_logging().unwrap();
208
        let g2 = enforce_safe_logging().unwrap();
209

            
210
        assert!(!unsafe_logging_enabled());
211

            
212
        let e = disable_safe_logging();
213
        assert!(matches!(e, Err(Error::AlreadySafe)));
214
        assert!(!unsafe_logging_enabled());
215

            
216
        drop(g1);
217
        drop(g2);
218
        let _g3 = disable_safe_logging().unwrap();
219
        assert!(unsafe_logging_enabled());
220
        let e = enforce_safe_logging();
221
        assert!(matches!(e, Err(Error::AlreadyUnsafe)));
222
        assert!(unsafe_logging_enabled());
223
        let _g4 = disable_safe_logging().unwrap();
224

            
225
        assert!(unsafe_logging_enabled());
226
    }
227

            
228
    #[test]
229
    #[serial]
230
    fn suppress() {
231
        // Try out `with_safe_logging_suppressed` and make sure it does what we want
232
        // regardless of the initial state of logging.
233
        {
234
            let _g = enforce_safe_logging().unwrap();
235
            with_safe_logging_suppressed(|| assert!(unsafe_logging_enabled()));
236
            assert!(!unsafe_logging_enabled());
237
        }
238

            
239
        {
240
            assert!(!unsafe_logging_enabled());
241
            with_safe_logging_suppressed(|| assert!(unsafe_logging_enabled()));
242
            assert!(!unsafe_logging_enabled());
243
        }
244

            
245
        {
246
            let _g = disable_safe_logging().unwrap();
247
            assert!(unsafe_logging_enabled());
248
            with_safe_logging_suppressed(|| assert!(unsafe_logging_enabled()));
249
        }
250
    }
251

            
252
    #[test]
253
    #[serial]
254
    fn interfere_1() {
255
        // Make sure that two threads trying to enforce and disable safe logging
256
        // can interfere with each other, but will never enter an incorrect
257
        // state.
258
        use std::thread::{spawn, yield_now};
259

            
260
        let thread1 = spawn(|| {
261
            for _ in 0..10_000 {
262
                if let Ok(_g) = enforce_safe_logging() {
263
                    assert!(!unsafe_logging_enabled());
264
                    yield_now();
265
                    assert!(disable_safe_logging().is_err());
266
                }
267
                yield_now();
268
            }
269
        });
270

            
271
        let thread2 = spawn(|| {
272
            for _ in 0..10_000 {
273
                if let Ok(_g) = disable_safe_logging() {
274
                    assert!(unsafe_logging_enabled());
275
                    yield_now();
276
                    assert!(enforce_safe_logging().is_err());
277
                }
278
                yield_now();
279
            }
280
        });
281

            
282
        thread1.join().unwrap();
283
        thread2.join().unwrap();
284
    }
285

            
286
    #[test]
287
    #[serial]
288
    fn interfere_2() {
289
        // Make sure that two threads trying to disable safe logging don't
290
        // interfere.
291
        use std::thread::{spawn, yield_now};
292

            
293
        let thread1 = spawn(|| {
294
            for _ in 0..10_000 {
295
                let g = disable_safe_logging().unwrap();
296
                assert!(unsafe_logging_enabled());
297
                yield_now();
298
                drop(g);
299
                yield_now();
300
            }
301
        });
302

            
303
        let thread2 = spawn(|| {
304
            for _ in 0..10_000 {
305
                let g = disable_safe_logging().unwrap();
306
                assert!(unsafe_logging_enabled());
307
                yield_now();
308
                drop(g);
309
                yield_now();
310
            }
311
        });
312

            
313
        thread1.join().unwrap();
314
        thread2.join().unwrap();
315
    }
316

            
317
    #[test]
318
    #[serial]
319
    fn interfere_3() {
320
        // Make sure that `with_safe_logging_suppressed` only applies to the
321
        // current thread.
322
        use std::thread::{spawn, yield_now};
323

            
324
        let thread1 = spawn(|| {
325
            for _ in 0..10_000 {
326
                assert!(!unsafe_logging_enabled());
327
                yield_now();
328
            }
329
        });
330

            
331
        let thread2 = spawn(|| {
332
            for _ in 0..10_000 {
333
                assert!(!unsafe_logging_enabled());
334
                with_safe_logging_suppressed(|| {
335
                    assert!(unsafe_logging_enabled());
336
                    yield_now();
337
                });
338
            }
339
        });
340

            
341
        thread1.join().unwrap();
342
        thread2.join().unwrap();
343
    }
344
}