1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344
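//! Code for turning safe logging on and off.
//!
//! By default, safe logging is on. There are two ways to turn it off: globally
//! (with [`disable_safe_logging`]) and locally, for the current thread only
//! (with [`with_safe_logging_suppressed`]). To prevent anything from turning it
//! off globally, use [`enforce_safe_logging`].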
use crate::{Error, Result};
use fluid_let::fluid_let;
use std::sync::atomic::{AtomicIsize, Ordering};
/// A global atomic used to track the guards that enforce or disable
/// safe-logging.
///
/// The value of this atomic is less than 0 if we have enabled unsafe logging,
/// greater than 0 if we have enabled safe logging, and 0 if nobody cares.
/// Each guard adds its own +1 (enforce) or -1 (disable) while it exists, so
/// the absolute value counts the outstanding guards of the current kind.
static LOGGING_STATE: AtomicIsize = AtomicIsize::new(0);
// Thread-local "dynamic variable": `with_safe_logging_suppressed` sets it to
// `true` for the duration of a closure, and `unsafe_logging_enabled` treats a
// value of `true` as "show sensitive values", regardless of `LOGGING_STATE`.
// Because it is thread-local, setting it on one thread has no effect on others.
fluid_let!(
/// A dynamic variable used to temporarily disable safe-logging.
static SAFE_LOGGING_SUPPRESSED_IN_THREAD: bool
);
/// Report whether sensitive values should currently be shown unredacted.
///
/// This is the case when at least one [`disable_safe_logging`] guard is alive
/// (the global counter is negative), or when the calling thread is running
/// inside [`with_safe_logging_suppressed`].
pub(crate) fn unsafe_logging_enabled() -> bool {
    // Global override first; only consult the thread-local flag if needed.
    if LOGGING_STATE.load(Ordering::Relaxed) < 0 {
        return true;
    }
    SAFE_LOGGING_SUPPRESSED_IN_THREAD.get(|flag| flag.copied().unwrap_or(false))
}
/// Call `func` with this crate's normal redaction behavior switched off.
///
/// While `func` runs (including everything it calls on the same thread),
/// [`Sensitive`](crate::Sensitive) values are formatted as if they were not
/// sensitive. The suppression affects only the calling thread and ends when
/// `func` returns. Whatever `func` returns is handed back to the caller.
///
/// Unlike [`disable_safe_logging`], this is not blocked by
/// [`enforce_safe_logging`].
///
/// # Examples
///
/// ```
/// use safelog::{Sensitive, with_safe_logging_suppressed};
///
/// let string = Sensitive::new("swordfish");
///
/// // By default, the value is scrubbed when formatted:
/// assert_eq!(format!("The value is {}", string),
///            "The value is [scrubbed]");
///
/// // Inside `with_safe_logging_suppressed`, it is shown as-is:
/// assert_eq!(
///     with_safe_logging_suppressed(|| format!("The value is {}", string)),
///     "The value is swordfish"
/// );
/// ```
pub fn with_safe_logging_suppressed<F, V>(func: F) -> V
where
    F: FnOnce() -> V,
{
    // `fluid_let` stores this in a thread-local slot. It reads as `Some(true)`
    // only while `func` is executing, and the previous value is restored after.
    let suppressed = true;
    SAFE_LOGGING_SUPPRESSED_IN_THREAD.set(suppressed, func)
}
/// Enum to describe what kind of a [`Guard`] we've created.
#[derive(Debug, Copy, Clone)]
enum GuardKind {
    /// We are forcing safe-logging to be enabled, so that nobody
    /// can turn it off with `disable_safe_logging`.
    Safe,
    /// We are turning safe-logging off with `disable_safe_logging`.
    Unsafe,
}
/// A guard object used to enforce safe logging, or turn it off.
///
/// For as long as this object exists, the chosen behavior will be enforced.
/// Guards are obtained from [`enforce_safe_logging`] or
/// [`disable_safe_logging`]; dropping a guard reverses the change it made to
/// the global logging state.
//
// TODO: Should there be different types for "keep safe logging on" and "turn
// safe logging off"? Having the same type makes it easier to write code that
// does stuff like this:
//
// let g = if cfg.safe {
// enforce_safe_logging()
// } else {
// disable_safe_logging()
// };
#[derive(Debug)]
#[must_use = "If you drop the guard immediately, it won't do anything."]
pub struct Guard {
/// What kind of guard is this?
kind: GuardKind,
}
impl GuardKind {
/// Return an error if `val` (as a value of `LOGGING_STATE`) indicates that
/// intended kind of guard cannot be created.
fn check(&self, val: isize) -> Result<()> {
match self {
GuardKind::Safe => {
if val < 0 {
return Err(Error::AlreadyUnsafe);
}
}
GuardKind::Unsafe => {
if val > 0 {
return Err(Error::AlreadySafe);
}
}
}
Ok(())
}
/// Return the value by which `LOGGING_STATE` should change while a guard of
/// this type exists.
fn increment(&self) -> isize {
match self {
GuardKind::Safe => 1,
GuardKind::Unsafe => -1,
}
}
}
impl Guard {
/// Helper: Create a guard of a given kind.
fn new(kind: GuardKind) -> Result<Self> {
let inc = kind.increment();
loop {
// Find the current value of LOGGING_STATE and see if this guard can
// be created.
let old_val = LOGGING_STATE.load(Ordering::SeqCst);
// Exit if this guard can't be created.
kind.check(old_val)?;
// Otherwise, try changing LOGGING_STATE to the new value that it
// _should_ have when this guard exists.
let new_val = match old_val.checked_add(inc) {
Some(v) => v,
None => return Err(Error::Overflow),
};
if let Ok(v) =
LOGGING_STATE.compare_exchange(old_val, new_val, Ordering::SeqCst, Ordering::SeqCst)
{
// Great, we set the value to what it should be; we're done.
debug_assert_eq!(v, old_val);
return Ok(Self { kind });
}
// Otherwise, somebody else altered this value concurrently: try
// again.
}
}
}
impl Drop for Guard {
    fn drop(&mut self) {
        // Reverse exactly the adjustment `Guard::new` applied to the counter.
        let undo = -self.kind.increment();
        LOGGING_STATE.fetch_add(undo, Ordering::SeqCst);
    }
}
/// Obtain a [`Guard`] that stops anyone else from turning safe logging off.
///
/// While the returned `Guard` is alive, every call to `disable_safe_logging`
/// fails. Calls to [`with_safe_logging_suppressed`] are *not* blocked by this
/// guard.
///
/// The name says "enforce" rather than "enable" because safe logging is
/// already on by default. The point is to make sure nothing _else_ calls
/// disable_safe_logging() while the guard is held.
///
/// # Errors
///
/// Returns [`Error::AlreadyUnsafe`] if safe logging is _already_ disabled
/// (a guard from `disable_safe_logging` is alive), or [`Error::Overflow`] if
/// the internal guard counter would overflow.
pub fn enforce_safe_logging() -> Result<Guard> {
    let kind = GuardKind::Safe;
    Guard::new(kind)
}
/// Obtain a [`Guard`] that turns safe logging off globally.
///
/// While the returned `Guard` is alive, all [`Sensitive`](crate::Sensitive)
/// values are displayed as if they were not sensitive, on every thread.
///
/// # Errors
///
/// Returns [`Error::AlreadySafe`] if safe logging has been enforced with
/// [`enforce_safe_logging`] and that guard is still alive, or
/// [`Error::Overflow`] if the internal guard counter would overflow.
pub fn disable_safe_logging() -> Result<Guard> {
    let kind = GuardKind::Unsafe;
    Guard::new(kind)
}
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    // Every test below reads or changes the process-wide `LOGGING_STATE`, so
    // `serial_test` keeps them from running at the same time.
    use serial_test::serial;
    use std::thread::{spawn, yield_now};

    /// How many times each worker thread in the `interfere_*` tests loops.
    const ROUNDS: usize = 10_000;

    #[test]
    #[serial]
    fn guards() {
        // Walk through the guard rules from a single thread.
        assert!(!unsafe_logging_enabled());

        // "Safe" guards stack, and block `disable_safe_logging` while alive.
        let safe_1 = enforce_safe_logging().unwrap();
        let safe_2 = enforce_safe_logging().unwrap();
        assert!(!unsafe_logging_enabled());
        let refused = disable_safe_logging();
        assert!(matches!(refused, Err(Error::AlreadySafe)));
        assert!(!unsafe_logging_enabled());
        drop(safe_1);
        drop(safe_2);

        // Once those are gone, "unsafe" guards stack too, and block enforcing.
        let _unsafe_1 = disable_safe_logging().unwrap();
        assert!(unsafe_logging_enabled());
        let refused = enforce_safe_logging();
        assert!(matches!(refused, Err(Error::AlreadyUnsafe)));
        assert!(unsafe_logging_enabled());
        let _unsafe_2 = disable_safe_logging().unwrap();
        assert!(unsafe_logging_enabled());
    }

    #[test]
    #[serial]
    fn suppress() {
        // `with_safe_logging_suppressed` must reveal values whatever the
        // global state is, and must not leak past its closure.
        {
            // Safe logging enforced.
            let _guard = enforce_safe_logging().unwrap();
            with_safe_logging_suppressed(|| assert!(unsafe_logging_enabled()));
            assert!(!unsafe_logging_enabled());
        }
        {
            // Default state: no guards held.
            assert!(!unsafe_logging_enabled());
            with_safe_logging_suppressed(|| assert!(unsafe_logging_enabled()));
            assert!(!unsafe_logging_enabled());
        }
        {
            // Safe logging globally disabled.
            let _guard = disable_safe_logging().unwrap();
            assert!(unsafe_logging_enabled());
            with_safe_logging_suppressed(|| assert!(unsafe_logging_enabled()));
        }
    }

    #[test]
    #[serial]
    fn interfere_1() {
        // One thread keeps enforcing safe logging while another keeps
        // disabling it. They may refuse each other, but whichever holds a
        // guard must always observe a consistent state.
        let enforcer = spawn(|| {
            for _ in 0..ROUNDS {
                if let Ok(_guard) = enforce_safe_logging() {
                    assert!(!unsafe_logging_enabled());
                    yield_now();
                    assert!(disable_safe_logging().is_err());
                }
                yield_now();
            }
        });
        let disabler = spawn(|| {
            for _ in 0..ROUNDS {
                if let Ok(_guard) = disable_safe_logging() {
                    assert!(unsafe_logging_enabled());
                    yield_now();
                    assert!(enforce_safe_logging().is_err());
                }
                yield_now();
            }
        });
        enforcer.join().unwrap();
        disabler.join().unwrap();
    }

    /// Repeatedly take and release an unsafe-logging guard, checking that
    /// logging is unsafe while the guard is held.
    fn churn_disable_guard() {
        for _ in 0..ROUNDS {
            let guard = disable_safe_logging().unwrap();
            assert!(unsafe_logging_enabled());
            yield_now();
            drop(guard);
            yield_now();
        }
    }

    #[test]
    #[serial]
    fn interfere_2() {
        // Two threads disabling safe logging at once must never get in each
        // other's way: every request is granted.
        let first = spawn(churn_disable_guard);
        let second = spawn(churn_disable_guard);
        first.join().unwrap();
        second.join().unwrap();
    }

    #[test]
    #[serial]
    fn interfere_3() {
        // Suppression on one thread must be invisible to every other thread.
        let bystander = spawn(|| {
            for _ in 0..ROUNDS {
                assert!(!unsafe_logging_enabled());
                yield_now();
            }
        });
        let suppressor = spawn(|| {
            for _ in 0..ROUNDS {
                assert!(!unsafe_logging_enabled());
                with_safe_logging_suppressed(|| {
                    assert!(unsafe_logging_enabled());
                    yield_now();
                });
            }
        });
        bystander.join().unwrap();
        suppressor.join().unwrap();
    }
}