diff --git a/.gitignore b/.gitignore index 8aa81acd..ef94a9f9 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ _trial_temp.lock /sydent.db /sydent.pid /matrix_is_test/sydent.stderr +/matrix_is_test/sydent.log diff --git a/changelog.d/402.misc b/changelog.d/402.misc new file mode 100644 index 00000000..269b25c9 --- /dev/null +++ b/changelog.d/402.misc @@ -0,0 +1 @@ +Convert ConfigParser object to dict and prevent sydent from creating DEFAULT section in config file (#287). \ No newline at end of file diff --git a/matrix_is_test/launcher.py b/matrix_is_test/launcher.py index 8d60263e..5f9df0cc 100644 --- a/matrix_is_test/launcher.py +++ b/matrix_is_test/launcher.py @@ -33,6 +33,7 @@ terms.path = {terms_path} templates.path = {testsubject_path}/res brand.default = is-test +log.path = {testsubject_path}/sydent.log ip.whitelist = 127.0.0.1 @@ -89,18 +90,15 @@ def launch(self): } ) - stderr_fp = open(os.path.join(testsubject_path, "sydent.stderr"), "w") - pybin = os.getenv("SYDENT_PYTHON", "python") self.process = Popen( args=[pybin, "-m", "sydent.sydent"], cwd=self.tmpdir, env=newEnv, - stderr=stderr_fp, ) # XXX: wait for startup in a sensible way - time.sleep(2) + time.sleep(10) self._baseUrl = "http://localhost:%d" % (port,) diff --git a/sydent/config/__init__.py b/sydent/config/__init__.py index b4c16c1d..b21b586e 100644 --- a/sydent/config/__init__.py +++ b/sydent/config/__init__.py @@ -21,6 +21,8 @@ from twisted.python import log +from sydent.config._base import CONFIG_PARSER_DICT +from sydent.config._configparser import SydentConfigParser from sydent.config.crypto import CryptoConfig from sydent.config.database import DatabaseConfig from sydent.config.email import EmailConfig @@ -182,7 +184,7 @@ def __init__(self): self.http, ] - def _parse_config(self, cfg: ConfigParser) -> bool: + def _parse_config(self, cfg: CONFIG_PARSER_DICT) -> bool: """ Run the parse_config method on each of the objects in self.config_sections @@ -201,9 +203,9 @@ def _parse_config(self, cfg: ConfigParser) -> bool: return needs_saving - def _parse_from_config_parser(self, cfg: ConfigParser) -> bool: + def _parse_from_dict(self, cfg: CONFIG_PARSER_DICT) -> bool: """ - Parse the configuration from a ConfigParser object + Parse the configuration from a dict :param cfg: the configuration to be parsed @@ -214,6 +216,30 @@ def _parse_from_config_parser(self, cfg: ConfigParser) -> bool: """ return self._parse_config(cfg) + def _parse_from_sydent_config_parser(self, cfg: SydentConfigParser) -> bool: + """ + Parse the configuration from a SydentConfigParser object + + :param cfg: the configuration to be parsed + + :return: whether or not the config file needs updating. This method CAN + return True, but it *shouldn't*. Instead a ConfigError exception + should be raised. This is left in for the soon to be deprecated way + of generating config files. + """ + config_dict: CONFIG_PARSER_DICT = {} + for section in cfg.sections(): + config_dict[section] = {} + # Copy in any values that are in the DEFAULT section + # This must be done first as they might be overwritten + for key, val in cfg.items(DEFAULTSECT): + config_dict[section][key] = val + # Copy in the values set in this section + for key, val in cfg.items(section): + config_dict[section][key] = val + + return self._parse_from_dict(config_dict) + def parse_config_file( self, config_file: str, skip_logging_setup: bool = False ) -> None: @@ -224,15 +250,14 @@ def parse_config_file( :param config_file: the file to be parsed """ # If the config file doesn't exist, prepopulate the config object - # with the defaults, in the DEFAULT section. + # with the defaults. new_config_file = not os.path.exists(config_file) - cfg = ConfigParser() + cfg = SydentConfigParser() for sect, entries in CONFIG_DEFAULTS.items(): cfg.add_section(sect) for k, v in entries.items(): - cfg.set(DEFAULTSECT if new_config_file else sect, k, v) - + cfg.set(sect, k, v) cfg.read(config_file) # Logging is configured in cfg, but these options must be parsed first @@ -240,7 +265,7 @@ def parse_config_file( if not skip_logging_setup: setup_logging(cfg) - needs_updating = self._parse_from_config_parser(cfg) + needs_updating = self._parse_from_sydent_config_parser(cfg) # Don't edit config file when starting Sydent unless it's the first run if new_config_file: @@ -268,17 +293,10 @@ def parse_config_dict(self, config_dict: Dict) -> None: for option in section_dict.keys(): config[section][option] = config_dict[section][option] - # Build a ConfigParser from the merged dictionary - cfg = ConfigParser() - for section, section_dict in config.items(): - cfg.add_section(section) - for option, value in section_dict.items(): - cfg.set(section, option, value) - # This is only ever called by tests so don't configure logging # as tests do this themselves - self._parse_from_config_parser(cfg) + self._parse_from_dict(config) def setup_logging(cfg: ConfigParser) -> None: diff --git a/sydent/config/_base.py b/sydent/config/_base.py index 1ce575a3..96a66d72 100644 --- a/sydent/config/_base.py +++ b/sydent/config/_base.py @@ -13,12 +13,15 @@ # limitations under the License. from abc import ABC, abstractmethod -from configparser import ConfigParser +from typing import Dict + +# The type of dict that the SydentConfigParser object get's converted into +CONFIG_PARSER_DICT = Dict[str, Dict[str, str]] class BaseConfig(ABC): @abstractmethod - def parse_config(self, cfg: ConfigParser) -> bool: + def parse_config(self, cfg: CONFIG_PARSER_DICT) -> bool: """ Parse the a section of the config @@ -29,3 +32,13 @@ def parse_config(self, cfg: ConfigParser) -> bool: config file. """ pass + + +def parse_cfg_bool(value: str): + """ + Parse a string config option into a boolean + This method ignores capitalisation + + :param value: the string to be parsed + """ + return value.lower() == "true" diff --git a/sydent/config/_configparser.py b/sydent/config/_configparser.py new file mode 100644 index 00000000..b4302df9 --- /dev/null +++ b/sydent/config/_configparser.py @@ -0,0 +1,58 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from configparser import BasicInterpolation, ConfigParser, Interpolation + + +class SydentInterpolation(Interpolation): + """Interpolation that uses BasicInterpolation with a blacklist""" + + # Options to never interpolate for backwards compatablility + BLACLIST = ["email.invite.subject", "email.invite.subject_space"] + + # The BasicInterpolation object to use + _basic_interpolation = BasicInterpolation() + + def before_get(self, parser, section, option, value, defaults): + if option in self.BLACLIST: + return value + else: + return self._basic_interpolation.before_get( + parser, section, option, value, defaults + ) + + def before_set(self, parser, section, option, value): + if option in self.BLACLIST: + return value + else: + return self._basic_interpolation.before_set(parser, section, option, value) + + def before_read(self, parser, section, option, value): + if option in self.BLACLIST: + return value + else: + return self._basic_interpolation.before_read(parser, section, option, value) + + def before_write(self, parser, section, option, value): + if option in self.BLACLIST: + return value + else: + return self._basic_interpolation.before_write( + parser, section, option, value + ) + + +class SydentConfigParser(ConfigParser): + + _DEFAULT_INTERPOLATION = SydentInterpolation() diff --git a/sydent/config/crypto.py b/sydent/config/crypto.py index eb85e5ca..692749b6 100644 --- a/sydent/config/crypto.py +++ b/sydent/config/crypto.py @@ -13,27 +13,26 @@ # limitations under the License. import logging -from configparser import ConfigParser import nacl import signedjson.key -from sydent.config._base import BaseConfig +from sydent.config._base import CONFIG_PARSER_DICT, BaseConfig logger = logging.getLogger(__name__) class CryptoConfig(BaseConfig): - def parse_config(self, cfg: "ConfigParser") -> bool: + def parse_config(self, cfg: CONFIG_PARSER_DICT) -> bool: """ Parse the crypto section of the config :param cfg: the configuration to be parsed """ + config = cfg.get("crypto") - signing_key_str = cfg.get("crypto", "ed25519.signingkey") - signing_key_parts = signing_key_str.split(" ") + signing_key_str = config.get("ed25519.signingkey") or None - if signing_key_str == "": + if signing_key_str is None: logger.warning( "'ed25519.signingkey' cannot be blank. Please generate a new" " signing key with the 'generate-key' script." @@ -42,7 +41,10 @@ def parse_config(self, cfg: "ConfigParser") -> bool: self.signing_key = signedjson.key.generate_signing_key("0") return True - elif len(signing_key_parts) == 1: + + signing_key_parts = signing_key_str.split(" ") + + if len(signing_key_parts) == 1: # old format key logger.warning( "Updating signing key format for this run. Please run the" diff --git a/sydent/config/database.py b/sydent/config/database.py index 0da8171d..1e1c50a4 100644 --- a/sydent/config/database.py +++ b/sydent/config/database.py @@ -12,18 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from configparser import ConfigParser - -from sydent.config._base import BaseConfig +from sydent.config._base import CONFIG_PARSER_DICT, BaseConfig class DatabaseConfig(BaseConfig): - def parse_config(self, cfg: "ConfigParser") -> bool: + def parse_config(self, cfg: CONFIG_PARSER_DICT) -> bool: """ Parse the database section of the config :param cfg: the configuration to be parsed """ - self.database_path = cfg.get("db", "db.file") + config = cfg.get("db") + + self.database_path = config.get("db.file") return False diff --git a/sydent/config/email.py b/sydent/config/email.py index b6128b9b..b3567132 100644 --- a/sydent/config/email.py +++ b/sydent/config/email.py @@ -13,56 +13,54 @@ # limitations under the License. import socket -from configparser import ConfigParser from typing import Optional -from sydent.config._base import BaseConfig +from sydent.config._base import CONFIG_PARSER_DICT, BaseConfig class EmailConfig(BaseConfig): - def parse_config(self, cfg: "ConfigParser") -> bool: + def parse_config(self, cfg: CONFIG_PARSER_DICT) -> bool: """ Parse the email section of the config :param cfg: the configuration to be parsed """ + config = cfg.get("email") # These two options are deprecated - self.template: Optional[str] = cfg.get("email", "email.template", fallback=None) + self.template: Optional[str] = config.get("email.template", None) - self.invite_template = cfg.get("email", "email.invite_template", fallback=None) + self.invite_template = config.get("email.invite_template", None) # This isn't used anywhere... - self.validation_subject = cfg.get("email", "email.subject") + self.validation_subject = config.get("email.subject") - self.invite_subject = cfg.get("email", "email.invite.subject", raw=True) - self.invite_subject_space = cfg.get( - "email", "email.invite.subject_space", raw=True - ) + # Interpolation is turned off for these two options + # This allows them to use %(variable)s substitution without raising errors + self.invite_subject = config.get("email.invite.subject") + self.invite_subject_space = config.get("email.invite.subject_space") - self.smtp_server = cfg.get("email", "email.smtphost") - self.smtp_port = cfg.get("email", "email.smtpport") - self.smtp_username = cfg.get("email", "email.smtpusername") - self.smtp_password = cfg.get("email", "email.smtppassword") - self.tls_mode = cfg.get("email", "email.tlsmode") + self.smtp_server = config.get("email.smtphost") + self.smtp_port = config.get("email.smtpport") + self.smtp_username = config.get("email.smtpusername") + self.smtp_password = config.get("email.smtppassword") + self.tls_mode = config.get("email.tlsmode") # This is the fully qualified domain name for SMTP HELO/EHLO - self.host_name = cfg.get("email", "email.hostname") - if self.host_name == "": - self.host_name = socket.getfqdn() + self.host_name = config.get("email.hostname") or socket.getfqdn() - self.sender = cfg.get("email", "email.from") + self.sender = config.get("email.from") - self.default_web_client_location = cfg.get( - "email", "email.default_web_client_location" + self.default_web_client_location = config.get( + "email.default_web_client_location" ) - self.username_obfuscate_characters = cfg.getint( - "email", "email.third_party_invite_username_obfuscate_characters" + self.username_obfuscate_characters = int( + config.get("email.third_party_invite_username_obfuscate_characters") ) - self.domain_obfuscate_characters = cfg.getint( - "email", "email.third_party_invite_domain_obfuscate_characters" + self.domain_obfuscate_characters = int( + config.get("email.third_party_invite_domain_obfuscate_characters") ) return False diff --git a/sydent/config/general.py b/sydent/config/general.py index dab20847..d68b731e 100644 --- a/sydent/config/general.py +++ b/sydent/config/general.py @@ -14,27 +14,28 @@ import logging import os -from configparser import ConfigParser from typing import List from jinja2.environment import Environment from jinja2.loaders import FileSystemLoader -from sydent.config._base import BaseConfig +from sydent.config._base import CONFIG_PARSER_DICT, BaseConfig, parse_cfg_bool from sydent.util.ip_range import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set logger = logging.getLogger(__name__) class GeneralConfig(BaseConfig): - def parse_config(self, cfg: "ConfigParser") -> bool: + def parse_config(self, cfg: CONFIG_PARSER_DICT) -> bool: """ Parse the 'general' section of the config :param cfg: the configuration to be parsed """ - self.server_name = cfg.get("general", "server.name") - if self.server_name == "": + config = cfg.get("general") + + self.server_name = config.get("server.name") or None + if self.server_name is None: self.server_name = os.uname()[1] logger.warning( "You have not specified a server name. I have guessed that this server is called '%s'. " @@ -44,7 +45,7 @@ def parse_config(self, cfg: "ConfigParser") -> bool: # Get the possible brands by looking at directories under the # templates.path directory. - self.templates_path = cfg.get("general", "templates.path") + self.templates_path = config.get("templates.path") if os.path.exists(self.templates_path): self.valid_brands = { p @@ -60,40 +61,41 @@ def parse_config(self, cfg: "ConfigParser") -> bool: self.valid_brands = set() self.template_environment = Environment( - loader=FileSystemLoader(cfg.get("general", "templates.path")), + loader=FileSystemLoader(self.templates_path), autoescape=True, ) - self.default_brand = cfg.get("general", "brand.default") + self.default_brand = config.get("brand.default") - self.pidfile = cfg.get("general", "pidfile.path") + self.pidfile = config.get("pidfile.path") - self.terms_path = cfg.get("general", "terms.path") + self.terms_path = config.get("terms.path") - self.address_lookup_limit = cfg.getint("general", "address_lookup_limit") + self.address_lookup_limit = int(config.get("address_lookup_limit")) - self.prometheus_port = cfg.getint("general", "prometheus_port", fallback=None) - self.prometheus_addr = cfg.get("general", "prometheus_addr", fallback=None) - self.prometheus_enabled = ( - self.prometheus_port is not None and self.prometheus_addr is not None - ) + self.prometheus_port = config.get("prometheus_port", None) + self.prometheus_addr = config.get("prometheus_addr", None) - self.sentry_enabled = cfg.has_option("general", "sentry_dsn") - self.sentry_dsn = cfg.get("general", "sentry_dsn", fallback=None) + if self.prometheus_port is not None and self.prometheus_addr is not None: + self.prometheus_enabled = True + self.prometheus_port = int(self.prometheus_port) + else: + self.prometheus_enabled = False + + self.sentry_dsn = config.get("sentry_dsn", None) + self.sentry_enabled = self.sentry_dsn is not None self.enable_v1_associations = parse_cfg_bool( - cfg.get("general", "enable_v1_associations") + config.get("enable_v1_associations") ) - self.delete_tokens_on_bind = parse_cfg_bool( - cfg.get("general", "delete_tokens_on_bind") - ) + self.delete_tokens_on_bind = parse_cfg_bool(config.get("delete_tokens_on_bind")) - ip_blacklist = list_from_comma_sep_string(cfg.get("general", "ip.blacklist")) + ip_blacklist = list_from_comma_sep_string(config.get("ip.blacklist")) if not ip_blacklist: ip_blacklist = DEFAULT_IP_RANGE_BLACKLIST - ip_whitelist = list_from_comma_sep_string(cfg.get("general", "ip.whitelist")) + ip_whitelist = list_from_comma_sep_string(config.get("ip.whitelist")) self.ip_blacklist = generate_ip_set(ip_blacklist) self.ip_whitelist = generate_ip_set(ip_whitelist) @@ -110,13 +112,3 @@ def list_from_comma_sep_string(rawstr: str) -> List[str]: if rawstr == "": return [] return [x.strip() for x in rawstr.split(",")] - - -def parse_cfg_bool(value: str): - """ - Parse a string config option into a boolean - This method ignores capitalisation - - :param value: the string to be parsed - """ - return value.lower() == "true" diff --git a/sydent/config/http.py b/sydent/config/http.py index 141cb36d..8f3948b6 100644 --- a/sydent/config/http.py +++ b/sydent/config/http.py @@ -12,61 +12,58 @@ # See the License for the specific language governing permissions and # limitations under the License. -from configparser import ConfigParser -from typing import Optional -from sydent.config._base import BaseConfig +from sydent.config._base import CONFIG_PARSER_DICT, BaseConfig, parse_cfg_bool class HTTPConfig(BaseConfig): - def parse_config(self, cfg: "ConfigParser") -> bool: + def parse_config(self, cfg: CONFIG_PARSER_DICT) -> bool: """ Parse the http section of the config :param cfg: the configuration to be parsed """ + config = cfg.get("http") + # This option is deprecated - self.verify_response_template = cfg.get( - "http", "verify_response_template", fallback=None - ) + self.verify_response_template = config.get("verify_response_template", None) - self.client_bind_address = cfg.get("http", "clientapi.http.bind_address") - self.client_port = cfg.getint("http", "clientapi.http.port") + self.client_bind_address = config.get("clientapi.http.bind_address") + self.client_port = int(config.get("clientapi.http.port")) # internal port is allowed to be set to an empty string in the config - internal_api_port = cfg.get("http", "internalapi.http.port") - self.internal_bind_address = cfg.get( - "http", "internalapi.http.bind_address", fallback="::1" - ) + internal_api_port = config.get("internalapi.http.port") + self.internal_bind_address = config.get("internalapi.http.bind_address", "::1") + if internal_api_port != "": self.internal_api_enabled = True - self.internal_port: Optional[int] = int(internal_api_port) + self.internal_port = int(internal_api_port) else: self.internal_api_enabled = False - self.internal_port = None - self.cert_file = cfg.get("http", "replication.https.certfile") - self.ca_cert_file = cfg.get("http", "replication.https.cacert") + self.cert_file = config.get("replication.https.certfile") + self.ca_cert_file = config.get("replication.https.cacert") - self.replication_bind_address = cfg.get( - "http", "replication.https.bind_address" - ) - self.replication_port = cfg.getint("http", "replication.https.port") + self.replication_bind_address = config.get("replication.https.bind_address") + self.replication_port = int(config.get("replication.https.port")) - self.obey_x_forwarded_for = cfg.getboolean("http", "obey_x_forwarded_for") + self.obey_x_forwarded_for = parse_cfg_bool(config.get("obey_x_forwarded_for")) - self.verify_federation_certs = cfg.getboolean("http", "federation.verifycerts") + self.verify_federation_certs = parse_cfg_bool( + config.get("federation.verifycerts") + ) - self.server_http_url_base = cfg.get("http", "client_http_base") + self.server_http_url_base = config.get("client_http_base") self.base_replication_urls = {} - for section in cfg.sections(): + for section in cfg.keys(): if section.startswith("peer."): # peer name is all the characters after 'peer.' peer = section[5:] - if cfg.has_option(section, "base_replication_url"): - base_url = cfg.get(section, "base_replication_url") + peer_config = cfg.get(section) + if "base_replication_url" in peer_config.keys(): + base_url = peer_config.get("base_replication_url") self.base_replication_urls[peer] = base_url return False diff --git a/sydent/config/sms.py b/sydent/config/sms.py index b0b977cd..da427de2 100644 --- a/sydent/config/sms.py +++ b/sydent/config/sms.py @@ -12,33 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -from configparser import ConfigParser from typing import Dict, List -from sydent.config._base import BaseConfig +from sydent.config._base import CONFIG_PARSER_DICT, BaseConfig class SMSConfig(BaseConfig): - def parse_config(self, cfg: "ConfigParser") -> bool: + def parse_config(self, cfg: CONFIG_PARSER_DICT) -> bool: """ Parse the sms section of the config :param cfg: the configuration to be parsed """ - self.body_template = cfg.get("sms", "bodyTemplate") + config = cfg.get("sms") + + self.body_template = config.get("bodyTemplate") # Make sure username and password are bytes otherwise we can't use them with # b64encode. - self.api_username = cfg.get("sms", "username").encode("UTF-8") - self.api_password = cfg.get("sms", "password").encode("UTF-8") + self.api_username = config.get("username").encode("UTF-8") + self.api_password = config.get("password").encode("UTF-8") self.originators: Dict[str, List[Dict[str, str]]] = {} self.smsRules = {} - for opt in cfg.options("sms"): + for opt in config.keys(): if opt.startswith("originators."): country = opt.split(".")[1] - rawVal = cfg.get("sms", opt) + rawVal = config.get(opt) rawList = [i.strip() for i in rawVal.split(",")] self.originators[country] = [] @@ -60,7 +61,7 @@ def parse_config(self, cfg: "ConfigParser") -> bool: ) elif opt.startswith("smsrule."): country = opt.split(".")[1] - action = cfg.get("sms", opt) + action = config.get(opt) if action not in ["allow", "reject"]: raise Exception(