"""Util functions to manage sbws configuration files."""
import logging
import logging.config
import os
from configparser import (
ConfigParser,
ExtendedInterpolation,
InterpolationMissingOptionError,
)
from string import Template
from tempfile import NamedTemporaryFile
from urllib.parse import urlparse
from sbws.globals import (
DEFAULT_CONFIG_PATH,
DEFAULT_LOG_CONFIG_PATH,
DIRAUTH_NICKNAMES,
SUPERVISED_RUN_DPATH,
SUPERVISED_USER_CONFIG_PATH,
USER_CONFIG_PATH,
)
from sbws.util.iso3166 import ISO_3166_ALPHA_2
# Letters (both cases) and digits accepted in nicknames by
# ``_validate_nickname``.
_ALPHANUM = "".join(
    (
        "abcdefghijklmnopqrstuvwxyz",
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
        "0123456789",
    )
)
# Punctuation also accepted in nicknames; quote characters are excluded.
_SYMBOLS_NO_QUOTES = "!@#$%^&*()-_=+\\|[]{}:;/?.,<>"
# Upper-case hexadecimal digits, the alphabet used by
# ``_validate_fingerprint``.
_HEX = "0123456789ABCDEF"
# Accepted values of the logging/*level options.
_LOG_LEVELS = ["debug", "info", "warning", "error", "critical"]
log = logging.getLogger(__name__)
def _expand_path(path):
    """Return ``path`` with environment variables and ``~`` expanded.

    Environment variables are substituted first. Then a leading ``~`` (or
    ``~user``) is replaced by the matching home directory.

    The configuration parser uses ``ExtendedInterpolation``, so a literal
    ``$`` in a configuration value must be written as ``$$``. For example,
    ``$$XDG_RUNTIME_DIR/foo.bar`` reaches this function as
    ``$XDG_RUNTIME_DIR/foo.bar``.
    """
    with_vars = os.path.expandvars(path)
    return os.path.expanduser(with_vars)
def _extend_config(conf, fname):
    """Read the configuration file ``fname`` into ``conf`` and return ``conf``.

    Options read from ``fname`` are merged into ``conf``. Options with the
    same name that were read earlier are replaced.
    """
    log.debug("Reading config file %s", fname)
    with open(fname, "rt") as config_file:
        # ``source`` makes parse errors mention the file name.
        conf.read_file(config_file, source=fname)
    return conf
def _get_default_config():
    """Build a ConfigParser populated from ``DEFAULT_CONFIG_PATH``.

    The parser uses ``${section:option}`` extended interpolation. It also
    gains a ``getpath()`` accessor backed by ``_expand_path``.
    """
    parser = ConfigParser(
        interpolation=ExtendedInterpolation(),
        converters={"path": _expand_path},
    )
    return _extend_config(parser, DEFAULT_CONFIG_PATH)
def _obtain_user_conf_path():
    """Return the path of the user configuration file.

    ``SUPERVISED_USER_CONFIG_PATH`` is returned when the ``SUPERVISED``
    environment variable is exactly ``"1"``. Otherwise the function returns
    ``USER_CONFIG_PATH``.
    """
    supervised = os.environ.get("SUPERVISED") == "1"
    return SUPERVISED_USER_CONFIG_PATH if supervised else USER_CONFIG_PATH
def _get_user_config(args, conf=None):
    """Extend ``conf`` with the user's configuration file, if one exists.

    An explicit ``args.config`` path takes precedence over the default user
    configuration path (see ``_obtain_user_conf_path``). If the selected file
    does not exist, ``conf`` is returned unchanged.

    :param args: parsed command line arguments; only ``args.config`` is read.
    :param conf: ConfigParser to extend. A fresh parser is created when this
        is falsy.
    :returns: the (possibly extended) ConfigParser.
    """
    if not conf:
        conf = ConfigParser(
            interpolation=ExtendedInterpolation(),
            converters={"path": _expand_path},
        )
    explicit_path = args.config
    if not explicit_path:
        user_config_path = _obtain_user_conf_path()
        if not os.path.isfile(user_config_path):
            log.debug("No user config found, using defaults.")
            return conf
        return _extend_config(conf, user_config_path)
    if os.path.isfile(explicit_path):
        return _extend_config(conf, explicit_path)
    # XXX: The logger is not configured at this stage, so this warning is
    # printed directly. Per #40110, nothing else is printed before logging is
    # set up, so output only appears when there is something to warn about.
    print(
        "Configuration file %s not found, using defaults." % explicit_path
    )
    return conf
def _get_default_logging_config(conf=None):
    """Extend ``conf`` with the default logging configuration.

    :param conf: ConfigParser to extend. When it is falsy, a new parser with
        extended interpolation and the ``path`` converter is created.
    :returns: the parser after reading ``DEFAULT_LOG_CONFIG_PATH``.
    """
    parser = conf
    if not parser:
        parser = ConfigParser(
            interpolation=ExtendedInterpolation(),
            converters={"path": _expand_path},
        )
    return _extend_config(parser, DEFAULT_LOG_CONFIG_PATH)
def get_config(args):
    """Return a ConfigParser that combines all sbws configuration files.

    Files are read in this order, with later ones overriding earlier ones:

    1. the default configuration,
    2. the default logging configuration,
    3. the user configuration (see ``_get_user_config``).

    :param args: parsed command line arguments, forwarded to
        ``_get_user_config``.
    :returns: the combined ConfigParser.
    """
    # The stray ``[docs]`` statement that preceded this function (a Sphinx
    # artifact) has been removed: at import time it evaluated the undefined
    # name ``docs`` and raised NameError.
    conf = _get_default_config()
    conf = _get_default_logging_config(conf=conf)
    conf = _get_user_config(args, conf=conf)
    return conf
def _can_log_to_file(conf):
    """Tell whether sbws should be able to log to a file.

    If file logging is not possible, callers can fall back to logging to
    stdout.

    :returns: ``(True, "")`` when ``paths/log_dname`` can be resolved.
        Otherwise returns ``(False, error)``, where ``error`` is the
        ``InterpolationMissingOptionError`` that was raised.
    """
    # paths.log_dname is interpolated from other options (by default from
    # paths.sbws_home). Those may not be resolvable yet while sbws is first
    # being initialized, and that is treated as "cannot log to file".
    try:
        conf.getpath("paths", "log_dname")
    except InterpolationMissingOptionError as err:
        return False, err
    else:
        return True, ""
def validate_config(conf):
    """Check ``conf`` for bad values or bad combinations of values.

    Every section validator runs, so all problems are reported at once.
    Some validators also adjust ``conf`` in place before checking it
    (``tor/run_dpath`` and ``paths/sbws_home``).

    :returns: ``(True, [])`` when no problem was found. Otherwise returns
        ``(False, errors)`` with one human-readable message per problem.
    """
    # The stray ``[docs]`` statement that preceded this function (a Sphinx
    # artifact) has been removed: at import time it evaluated the undefined
    # name ``docs`` and raised NameError.
    validators = (
        _validate_general,
        _validate_cleanup,
        _validate_scanner,
        _validate_tor,
        _validate_paths,
        _validate_destinations,
        _validate_relayprioritizer,
        _validate_logging,
    )
    errors = []
    for validator in validators:
        errors.extend(validator(conf))
    return not errors, errors
def _validate_cleanup(conf):
    """Validate the ``cleanup`` section and return a list of error messages.

    Every option is a number of days and must be an integer >= 1. Unknown
    and missing keys are also reported.
    """
    sec = "cleanup"
    err_tmpl = Template("$sec/$key ($val): $e")
    ints = {
        key: {"minimum": 1, "maximum": None}
        for key in (
            "data_files_compress_after_days",
            "data_files_delete_after_days",
            "v3bw_files_compress_after_days",
            "v3bw_files_delete_after_days",
        )
    }
    key_errors = _validate_section_keys(conf, sec, list(ints), err_tmpl)
    int_errors = _validate_section_ints(conf, sec, ints, err_tmpl)
    return key_errors + int_errors
def _validate_general(conf):
    """Validate the ``general`` section and return a list of error messages.

    The section is checked as follows:

    - ``data_period`` and ``circuit_timeout`` must be integers >= 1.
    - ``http_timeout`` must be a float >= 0.0.
    - the ``reset_bw_ipv*_changes`` options must be booleans.
    - unknown and missing keys are reported.
    """
    sec = "general"
    err_tmpl = Template("$sec/$key ($val): $e")
    ints = {
        "data_period": {"minimum": 1, "maximum": None},
        "circuit_timeout": {"minimum": 1, "maximum": None},
    }
    floats = {"http_timeout": {"minimum": 0.0, "maximum": None}}
    bools = {"reset_bw_ipv4_changes": {}, "reset_bw_ipv6_changes": {}}
    known_keys = list(ints) + list(floats) + list(bools)
    errors = _validate_section_keys(conf, sec, known_keys, err_tmpl)
    errors += _validate_section_ints(conf, sec, ints, err_tmpl)
    errors += _validate_section_floats(conf, sec, floats, err_tmpl)
    errors += _validate_section_bools(conf, sec, bools, err_tmpl)
    return errors
def _obtain_sbws_home(conf):
    """Drop the trailing ``.sbws`` from the default home directory.

    If ``paths/sbws_home`` expands to ``/var/lib/sbws/.sbws``, it is replaced
    in ``conf`` by its parent, ``/var/lib/sbws``. Any other value is left
    untouched.
    """
    default_home = "/var/lib/sbws/.sbws"
    if conf.getpath("paths", "sbws_home") != default_home:
        return
    # No need for .sbws when this is the default home.
    conf["paths"]["sbws_home"] = os.path.dirname(default_home)
def _obtain_run_dpath(conf):
    """Point ``tor/run_dpath`` at the runtime directory, when one is known.

    The value is chosen as follows:

    - with ``SUPERVISED=1`` (sbws run by a system service) it becomes
      ``SUPERVISED_RUN_DPATH``;
    - otherwise, if ``XDG_RUNTIME_DIR`` is set, it becomes
      ``$XDG_RUNTIME_DIR/sbws/tor``;
    - if neither applies, ``conf`` is not modified.
    """
    xdg = os.environ.get("XDG_RUNTIME_DIR")
    if os.environ.get("SUPERVISED") == "1":
        run_dpath = SUPERVISED_RUN_DPATH
    elif xdg is not None:
        run_dpath = os.path.join(xdg, "sbws", "tor")
    else:
        return
    conf["tor"]["run_dpath"] = run_dpath
def _validate_paths(conf):
    """Validate the ``paths`` section and return a list of error messages.

    Before the check, ``_obtain_sbws_home`` may rewrite ``paths/sbws_home``
    in place. Only key names are checked, not their values, and
    ``sbws_home`` may be absent.
    """
    _obtain_sbws_home(conf)
    known_keys = [
        "datadir",
        "sbws_home",
        "v3bw_fname",
        "v3bw_dname",
        "state_fname",
        "log_dname",
    ]
    return _validate_section_keys(
        conf,
        "paths",
        known_keys,
        Template("$sec/$key ($val): $e"),
        allow_missing=["sbws_home"],
    )
def _validate_country(conf, sec, key, err_tmpl):
    """Check that option ``key`` of section ``sec`` is an ISO 3166 code.

    :param conf: ConfigParser that holds section ``sec``.
    :param sec: name of the section to check.
    :param key: name of the option that holds the country code.
    :param err_tmpl: ``string.Template`` with ``$sec``, ``$key``, ``$val``
        and ``$e`` placeholders, used to format error messages.
    :returns: a list with at most one error message.
    """
    value = conf[sec].get(key, None)
    if value is None:
        return [
            err_tmpl.substitute(
                sec=sec,
                key=key,
                val=None,
                e="Missing country in configuration file.",
            )
        ]
    # Validate the option named by ``key``. Looking up a fixed option name
    # here would ignore ``key`` and could raise KeyError.
    if value not in ISO_3166_ALPHA_2:
        return [
            err_tmpl.substitute(
                sec=sec,
                key=key,
                val=value,
                e="Not a valid ISO 3166 alpha-2 country code.",
            )
        ]
    return []
def _validate_dirauth_nickname(conf, sec, key, err_tmpl):
    """Check that option ``key`` of section ``sec`` is a known dirauth name.

    :param conf: ConfigParser that holds section ``sec``.
    :param sec: name of the section to check.
    :param key: name of the option that holds the nickname.
    :param err_tmpl: ``string.Template`` with ``$sec``, ``$key``, ``$val``
        and ``$e`` placeholders, used to format error messages.
    :returns: a list with at most one error message.
    """
    value = conf[sec].get(key, None)
    if value is None:
        return [
            err_tmpl.substitute(
                sec=sec,
                key=key,
                val=None,
                e="Missing dirauth_nickname in configuration file.",
            )
        ]
    # Validate the option named by ``key``. Looking up a fixed option name
    # here would ignore ``key`` and could raise KeyError.
    if value not in DIRAUTH_NICKNAMES:
        return [
            err_tmpl.substitute(
                sec=sec,
                key=key,
                val=value,
                e="Not a valid dirauth_nickname.",
            )
        ]
    return []
def _validate_scanner(conf):
    """Validate the ``scanner`` section and return a list of error messages.

    Checks, in this order: key names, integer bounds, float bounds, the
    ``nickname`` string, the ``country`` code and the ``dirauth_nickname``.
    The ``upload`` key is only checked for presence.
    """
    sec = "scanner"
    err_tmpl = Template("$sec/$key ($val): $e")
    ints = {
        "num_rtts": {"minimum": 0, "maximum": 100},
        "num_downloads": {"minimum": 1, "maximum": 100},
        "initial_read_request": {"minimum": 1, "maximum": None},
        "measurement_threads": {"minimum": 1, "maximum": None},
        "min_download_size": {"minimum": 1, "maximum": None},
        "max_download_size": {"minimum": 1, "maximum": None},
        "http_post_initial_bytes": {"minimum": 1, "maximum": None},
    }
    floats = {
        "download_toofast": {"minimum": 0.001, "maximum": None},
        "download_min": {"minimum": 0.001, "maximum": None},
        "download_target": {"minimum": 0.001, "maximum": None},
        "download_max": {"minimum": 0.001, "maximum": None},
    }
    other_keys = ["nickname", "country", "dirauth_nickname", "upload"]
    known_keys = list(ints) + list(floats) + other_keys
    errors = _validate_section_keys(conf, sec, known_keys, err_tmpl)
    errors += _validate_section_ints(conf, sec, ints, err_tmpl)
    errors += _validate_section_floats(conf, sec, floats, err_tmpl)
    section = conf[sec]
    nickname_ok, nickname_err = _validate_nickname(section, "nickname")
    if not nickname_ok:
        errors.append(
            err_tmpl.substitute(
                sec=sec,
                key="nickname",
                val=section["nickname"],
                e=nickname_err,
            )
        )
    errors += _validate_country(conf, sec, "country", err_tmpl)
    errors += _validate_dirauth_nickname(
        conf, sec, "dirauth_nickname", err_tmpl
    )
    return errors
def _validate_tor(conf):
    """Validate the ``tor`` section and return a list of error messages.

    Before the check, ``_obtain_run_dpath`` may set ``tor/run_dpath`` in
    place. Only key names are checked, not their values.
    """
    _obtain_run_dpath(conf)
    known_keys = [
        "datadir",
        "run_dpath",
        "control_socket",
        "pid",
        "log",
        "external_control_port",
        "extra_lines",
    ]
    return _validate_section_keys(
        conf, "tor", known_keys, Template("$sec/$key ($val): $e")
    )
def _validate_relayprioritizer(conf):
    """Validate the ``relayprioritizer`` section and return error messages.

    The section is checked as follows:

    - ``min_relays`` must be an integer >= 1.
    - ``fraction_relays`` must be a float in ``[0.0, 1.0]``.
    - ``measure_authorities`` must be a boolean.
    - unknown and missing keys are reported.
    """
    sec = "relayprioritizer"
    err_tmpl = Template("$sec/$key ($val): $e")
    ints = {"min_relays": {"minimum": 1, "maximum": None}}
    floats = {"fraction_relays": {"minimum": 0.0, "maximum": 1.0}}
    bools = {"measure_authorities": {}}
    errors = _validate_section_keys(
        conf, sec, list(ints) + list(floats) + list(bools), err_tmpl
    )
    checks = (
        (_validate_section_ints, ints),
        (_validate_section_floats, floats),
        (_validate_section_bools, bools),
    )
    for check, spec in checks:
        errors.extend(check(conf, sec, spec, err_tmpl))
    return errors
def _validate_logging(conf):
    """Validate the ``logging`` section and return a list of error messages.

    The section is checked as follows:

    - key names (unknown and missing keys are reported);
    - the boolean ``to_*`` switches;
    - the level options, which must be one of ``_LOG_LEVELS``;
    - the ``to_file_*`` integer options, which must be >= 0.

    ``to_file_when`` and the format options are only checked for presence.
    """
    errors = []
    sec = "logging"
    err_tmpl = Template("$sec/$key ($val): $e")
    enums = {
        "level": {"choices": _LOG_LEVELS},
        "to_file_level": {"choices": _LOG_LEVELS},
        "to_stdout_level": {"choices": _LOG_LEVELS},
        "to_syslog_level": {"choices": _LOG_LEVELS},
    }
    bools = {
        "to_file": {},
        "to_stdout": {},
        "to_syslog": {},
    }
    ints = {
        "to_file_max_bytes": {"minimum": 0, "maximum": None},
        "to_file_num_backups": {"minimum": 0, "maximum": None},
        "to_file_interval": {"minimum": 0, "maximum": None},
    }
    unvalidated = [
        "to_file_when",
        "format",
        "to_file_format",
        "to_stdout_format",
        "to_syslog_format",
    ]
    all_valid_keys = (
        list(bools.keys())
        + list(enums.keys())
        + list(ints.keys())
        + unvalidated
    )
    errors.extend(_validate_section_keys(conf, sec, all_valid_keys, err_tmpl))
    errors.extend(_validate_section_bools(conf, sec, bools, err_tmpl))
    errors.extend(_validate_section_enums(conf, sec, enums, err_tmpl))
    # The bounds declared in ``ints`` were previously never enforced.
    # Validate them the same way the other sections validate their
    # integer options.
    errors.extend(_validate_section_ints(conf, sec, ints, err_tmpl))
    return errors
def _validate_destinations(conf):
    """Validate ``destinations`` and every enabled ``destinations.*`` section.

    Rules for the ``destinations`` section:

    - ``usability_test_interval`` must be an integer >= 1;
    - every other key must be a boolean, and keys set to true name the
      enabled destinations.

    Each enabled destination must have its own ``destinations.<name>``
    section. That section is checked for its key names, its ``url`` and its
    ``country``.

    :returns: a list of error messages.
    """
    errors = []
    sec = "destinations"
    section = conf[sec]
    err_tmpl = Template("$sec/$key ($val): $e")
    enabled = []
    for key in section.keys():
        value = section[key]
        is_interval = key == "usability_test_interval"
        if is_interval:
            valid, error_msg = _validate_int(section, key, minimum=1)
        else:
            valid, error_msg = _validate_boolean(section, key)
        if not valid:
            errors.append(
                err_tmpl.substitute(sec=sec, key=key, val=value, e=error_msg)
            )
        elif not is_interval and section.getboolean(key):
            enabled.append("{}.{}".format(sec, key))
    urls = {"url": {}}
    dest_keys = list(urls) + ["verify", "country", "max_num_failures"]
    for dest_sec in enabled:
        if dest_sec not in conf:
            errors.append(
                "{} is an enabled destination but is not a "
                "section in the config".format(dest_sec)
            )
            continue
        errors.extend(
            _validate_section_keys(
                conf,
                dest_sec,
                dest_keys,
                err_tmpl,
                allow_missing=["verify", "max_num_failures"],
            )
        )
        errors.extend(_validate_section_urls(conf, dest_sec, urls, err_tmpl))
        errors.extend(_validate_country(conf, dest_sec, "country", err_tmpl))
    return errors
def _validate_section_keys(conf, sec, keys, tmpl, allow_missing=None):
    """Report unknown and missing keys of section ``sec``.

    :param keys: every key the section may contain.
    :param tmpl: ``string.Template`` used to format each message.
    :param allow_missing: keys from ``keys`` that may be absent.
    :returns: one "Unknown key" message per section key not in ``keys`` (in
        section order), followed by one "Missing key" message per required
        key absent from the section (in ``keys`` order).
    """
    optional = [] if allow_missing is None else allow_missing
    section = conf[sec]
    unknown = [
        tmpl.substitute(sec=sec, key=key, val=section[key], e="Unknown key")
        for key in section
        if key not in keys
    ]
    missing = [
        tmpl.substitute(sec=sec, key=key, val="[NOT SET]", e="Missing key")
        for key in keys
        if key not in section and key not in optional
    ]
    return unknown + missing
def _validate_section_ints(conf, sec, ints, tmpl):
    """Check that each key of ``ints`` in ``sec`` is an in-range integer.

    :param ints: mapping of key to ``{"minimum": ..., "maximum": ...}``
        bounds; ``None`` means unbounded.
    :returns: one formatted error message per invalid key.
    """
    section = conf[sec]
    errors = []
    for key, bounds in ints.items():
        ok, reason = _validate_int(
            section,
            key,
            minimum=bounds["minimum"],
            maximum=bounds["maximum"],
        )
        if ok:
            continue
        errors.append(
            tmpl.substitute(sec=sec, key=key, val=section[key], e=reason)
        )
    return errors
def _validate_section_floats(conf, sec, floats, tmpl):
    """Check that each key of ``floats`` in ``sec`` is an in-range float.

    :param floats: mapping of key to ``{"minimum": ..., "maximum": ...}``
        bounds; ``None`` means unbounded.
    :returns: one formatted error message per invalid key.
    """
    section = conf[sec]
    errors = []
    for key, bounds in floats.items():
        ok, reason = _validate_float(
            section,
            key,
            minimum=bounds["minimum"],
            maximum=bounds["maximum"],
        )
        if ok:
            continue
        errors.append(
            tmpl.substitute(sec=sec, key=key, val=section[key], e=reason)
        )
    return errors
def _validate_section_hosts(conf, sec, hosts, tmpl):
    """Check each key of ``hosts`` in ``sec`` with ``_validate_host``.

    :returns: one formatted error message per invalid key.
    """
    section = conf[sec]
    errors = []
    for key in hosts:
        ok, reason = _validate_host(section, key)
        if ok:
            continue
        errors.append(
            tmpl.substitute(sec=sec, key=key, val=section[key], e=reason)
        )
    return errors
def _validate_section_ports(conf, sec, ports, tmpl):
    """Check that each key of ``ports`` in ``sec`` is a valid port number.

    Valid ports are the integers from 1 to 65535 inclusive.

    :returns: one formatted error message per invalid key.
    """
    # Port numbers are 16-bit unsigned, so the largest is 2**16 - 1. The
    # previous bound of 2**16 wrongly accepted 65536.
    max_port = 2**16 - 1
    errors = []
    section = conf[sec]
    for key in ports:
        valid, error = _validate_int(
            section, key, minimum=1, maximum=max_port
        )
        if not valid:
            errors.append(
                tmpl.substitute(
                    sec=sec,
                    key=key,
                    val=section[key],
                    e="Not a valid port ({})".format(error),
                )
            )
    return errors
def _validate_section_bools(conf, sec, bools, tmpl):
    """Check that each key of ``bools`` in ``sec`` parses as a boolean.

    :returns: one formatted error message per invalid key.
    """
    section = conf[sec]
    errors = []
    for key in bools:
        ok, reason = _validate_boolean(section, key)
        if ok:
            continue
        message = "Not a valid boolean string ({})".format(reason)
        errors.append(
            tmpl.substitute(sec=sec, key=key, val=section[key], e=message)
        )
    return errors
def _validate_section_fingerprints(conf, sec, fps, tmpl):
    """Check each key of ``fps`` in ``sec`` with ``_validate_fingerprint``.

    :returns: one formatted error message per invalid key.
    """
    section = conf[sec]
    errors = []
    for key in fps:
        ok, reason = _validate_fingerprint(section, key)
        if ok:
            continue
        message = "Not a valid fingerprint ({})".format(reason)
        errors.append(
            tmpl.substitute(sec=sec, key=key, val=section[key], e=message)
        )
    return errors
def _validate_section_urls(conf, sec, urls, tmpl):
    """Check each key of ``urls`` in ``sec`` with ``_validate_url``.

    :returns: one formatted error message per invalid key.
    """
    section = conf[sec]
    errors = []
    for key in urls:
        ok, reason = _validate_url(section, key)
        if ok:
            continue
        message = "Not a valid url ({})".format(reason)
        errors.append(
            tmpl.substitute(sec=sec, key=key, val=section[key], e=message)
        )
    return errors
def _validate_section_enums(conf, sec, enums, tmpl):
    """Check that each key of ``enums`` in ``sec`` holds an allowed choice.

    :param enums: mapping of key to ``{"choices": [...]}``.
    :returns: one formatted error message per key whose value is not among
        its choices. The message lists the allowed choices.
    """
    section = conf[sec]
    errors = []
    for key, spec in enums.items():
        choices = spec["choices"]
        ok, _ = _validate_enum(section, key, choices)
        if ok:
            continue
        message = "Not a valid enum choice ({})".format(", ".join(choices))
        errors.append(
            tmpl.substitute(sec=sec, key=key, val=section[key], e=message)
        )
    return errors
def _validate_enum(section, key, choices):
    """Check that ``section[key]`` is one of ``choices``.

    :returns: ``(True, "")`` when it is. Otherwise returns
        ``(False, message)``, where the message lists the allowed choices.
    """
    value = section[key]
    if value in choices:
        return True, ""
    allowed = ", ".join(choices)
    return False, "{} not in allowed choices: {}".format(value, allowed)
def _validate_url(section, key):
    """Check that ``section[key]`` is a URL with a host and the https scheme.

    ``127.0.0.1`` (the test server) is also accepted without https.

    :returns: ``(True, "")`` or ``(False, reason)``.
    """
    parsed = urlparse(section[key])
    if not parsed.netloc:
        return False, "Does not appear to contain a hostname"
    # A URL starting with http:// could still use TLS, but python requests
    # only checks whether the scheme is https when deciding to verify the
    # certificate:
    # https://github.com/requests/requests/blob/master/requests/adapters.py#L215 # noqa
    # When the scheme is https but the protocol is not TLS, requests will
    # hang. Tests use 127.0.0.1, so that address is allowed without https.
    if parsed.scheme == "https" or parsed.hostname == "127.0.0.1":
        return True, ""
    return False, "URL scheme must be HTTPS (except for the test server)"
def _validate_int(section, key, minimum=None, maximum=None):
    """Check that ``section[key]`` is an integer within optional bounds.

    :param minimum: inclusive lower bound, or ``None`` for no bound.
    :param maximum: inclusive upper bound, or ``None`` for no bound.
    :returns: one of:

        - ``(True, "")`` when the value is valid;
        - ``(False, exc)`` with the ``ValueError`` when it is not an integer;
        - ``(False, message)`` when it is out of bounds.
    """
    try:
        number = section.getint(key)
    except ValueError as exc:
        return False, exc
    if minimum is not None and number < minimum:
        return False, "Cannot be less than {}".format(minimum)
    if maximum is not None and number > maximum:
        return False, "Cannot be greater than {}".format(maximum)
    return True, ""
def _validate_boolean(section, key):
    """Check that ``section[key]`` parses as a ConfigParser boolean.

    :returns: ``(True, "")`` when it does. Otherwise returns
        ``(False, exc)`` with the ``ValueError`` that was raised.
    """
    try:
        section.getboolean(key)
    except ValueError as exc:
        return False, exc
    else:
        return True, ""
def _validate_float(section, key, minimum=None, maximum=None):
    """Check that ``section[key]`` is a float within optional bounds.

    :param minimum: inclusive lower bound, or ``None`` for no bound.
    :param maximum: inclusive upper bound, or ``None`` for no bound.
    :returns: one of:

        - ``(True, "")`` when the value is valid;
        - ``(False, exc)`` with the ``ValueError`` when it is not a float;
        - ``(False, message)`` when it is out of bounds.
    """
    try:
        number = section.getfloat(key)
    except ValueError as exc:
        return False, exc
    if minimum is not None and number < minimum:
        return False, "Cannot be less than {}".format(minimum)
    if maximum is not None and number > maximum:
        return False, "Cannot be greater than {}".format(maximum)
    return True, ""
def _validate_host(section, key):
    """Placeholder host validator that accepts every value.

    :returns: always ``(True, "")``.
    """
    # XXX: Implement this; no host check is performed yet.
    valid, reason = True, ""
    return valid, reason
def _validate_fingerprint(section, key):
    """Check that ``section[key]`` is 40 upper-case hexadecimal characters.

    :returns: ``(valid, reason)`` as produced by ``_validate_string``.
    """
    fingerprint_len = 40
    return _validate_string(
        section,
        key,
        min_len=fingerprint_len,
        max_len=fingerprint_len,
        alphabet=_HEX,
    )
def _validate_nickname(section, key):
    """Check that ``section[key]`` is a 1-32 character nickname.

    Allowed characters are ASCII letters, digits and the punctuation in
    ``_SYMBOLS_NO_QUOTES``.

    :returns: ``(valid, reason)`` as produced by ``_validate_string``.
    """
    return _validate_string(
        section,
        key,
        min_len=1,
        max_len=32,
        alphabet=_ALPHANUM + _SYMBOLS_NO_QUOTES,
    )
def _validate_string(
    section, key, min_len=None, max_len=None, alphabet=None, starts_with=None
):
    """Check ``section[key]`` against optional string constraints.

    The checks run in this order, and the first failure is reported:
    minimum length, maximum length, allowed characters, required prefix.

    :returns: ``(True, "")`` or ``(False, reason)``.
    """
    value = section[key]
    length = len(value)
    if min_len is not None and length < min_len:
        return False, "{} is below minimum allowed length {}".format(
            length, min_len
        )
    if max_len is not None and length > max_len:
        return False, "{} is above maximum allowed length {}".format(
            length, max_len
        )
    if alphabet is not None:
        offender = next(
            (
                (pos, char)
                for pos, char in enumerate(value)
                if char not in alphabet
            ),
            None,
        )
        if offender is not None:
            pos, char = offender
            return (
                False,
                "Letter {} at position {} is not in allowed "
                'characters "{}"'.format(char, pos, alphabet),
            )
    if starts_with is not None and not value.startswith(starts_with):
        return False, "{} does not start with {}".format(value, starts_with)
    return True, ""