From 62c3fadf96f8baf93b503e3e4fc61dcdff25906f Mon Sep 17 00:00:00 2001 From: Benjamin Renard Date: Mon, 16 Jan 2023 12:56:12 +0100 Subject: [PATCH] Introduce pyupgrade,isort,black and configure pre-commit hooks to run all testing tools before commit --- .pre-commit-config.yaml | 39 ++ .pylintrc | 5 +- README.md | 35 ++ mylib/__init__.py | 41 +- mylib/config.py | 660 ++++++++++++------------ mylib/db.py | 173 +++---- mylib/email.py | 403 ++++++++------- mylib/ldap.py | 525 +++++++++++-------- mylib/mapping.py | 5 +- mylib/mysql.py | 47 +- mylib/opening_hours.py | 161 +++--- mylib/oracle.py | 29 +- mylib/pbar.py | 16 +- mylib/pgsql.py | 57 +- mylib/report.py | 92 ++-- mylib/scripts/email_test.py | 50 +- mylib/scripts/email_test_with_config.py | 39 +- mylib/scripts/helpers.py | 190 +++---- mylib/scripts/ldap_test.py | 155 ++++-- mylib/scripts/map_test.py | 6 +- mylib/scripts/pbar_test.py | 23 +- mylib/scripts/report_test.py | 30 +- mylib/scripts/sftp_test.py | 76 ++- mylib/sftp.py | 129 +++-- mylib/telltale.py | 32 +- setup.cfg | 1 + setup.py | 86 ++- tests.sh | 15 +- tests/test_config.py | 196 +++---- tests/test_mysql.py | 291 ++++++----- tests/test_opening_hours.py | 215 ++++---- tests/test_oracle.py | 272 +++++----- tests/test_pgsql.py | 279 +++++----- tests/test_telltale.py | 9 +- 34 files changed, 2356 insertions(+), 2026 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..23b00c0 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,39 @@ +# Pre-commit hooks to run tests and ensure code is cleaned. +# See https://pre-commit.com for more information +repos: +- repo: local + hooks: + - id: pytest + name: pytest + entry: python3 -m pytest tests + language: system + pass_filenames: false + always_run: true +- repo: local + hooks: + - id: pylint + name: pylint + entry: pylint --extension-pkg-whitelist=cx_Oracle + language: system + types: [python] + require_serial: true +- repo: https://github.com/PyCQA/flake8 + rev: 6.0.0 + hooks: + - id: flake8 + args: ['--max-line-length=100'] +- repo: https://github.com/asottile/pyupgrade + rev: v3.3.1 + hooks: + - id: pyupgrade + args: ['--keep-percent-format', '--py37-plus'] +- repo: https://github.com/psf/black + rev: 22.12.0 + hooks: + - id: black + args: ['--target-version', 'py37', '--line-length', '100'] +- repo: https://github.com/PyCQA/isort + rev: 5.11.4 + hooks: + - id: isort + args: ['--profile', 'black', '--line-length', '100'] diff --git a/.pylintrc b/.pylintrc index 14f7ccf..b505075 100644 --- a/.pylintrc +++ b/.pylintrc @@ -8,5 +8,8 @@ disable=invalid-name, too-many-nested-blocks, too-many-instance-attributes, too-many-lines, - line-too-long, duplicate-code, + +[FORMAT] +# Maximum number of characters on a single line. +max-line-length=100 diff --git a/README.md b/README.md index 1e79d24..a48fd37 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,41 @@ Just run `python setup.py install` To know how to use these libs, you can take a look on *mylib.scripts* content or in *tests* directory. +## Code Style + +[pylint](https://pypi.org/project/pylint/) is used to check for errors and enforces a coding standard, using thoses parameters: + +```bash +pylint --extension-pkg-whitelist=cx_Oracle +``` + +[flake8](https://pypi.org/project/flake8/) is also used to check for errors and enforces a coding standard, using thoses parameters: + +```bash +flake8 --max-line-length=100 +``` + +[black](https://pypi.org/project/black/) is used to format the code, using thoses parameters: + +```bash +black --target-version py37 --line-length 100 +``` + +[isort](https://pypi.org/project/isort/) is used to format the imports, using those parameter: + +```bash +isort --profile black --line-length 100 +``` + +[pyupgrade](https://pypi.org/project/pyupgrade/) is used to automatically upgrade syntax, using those parameters: + +```bash +pyupgrade --keep-percent-format --py37-plus +``` + +**Note:** There is `.pre-commit-config.yaml` to use [pre-commit](https://pre-commit.com/) to automatically run these tools before commits. After cloning the repository, execute `pre-commit install` to install the git hook. + + ## Copyright Copyright (c) 2013-2021 Benjamin Renard diff --git a/mylib/__init__.py b/mylib/__init__.py index 6898a9f..8284814 100644 --- a/mylib/__init__.py +++ b/mylib/__init__.py @@ -6,12 +6,12 @@ def increment_prefix(prefix): - """ Increment the given prefix with two spaces """ + """Increment the given prefix with two spaces""" return f'{prefix if prefix else " "} ' -def pretty_format_value(value, encoding='utf8', prefix=None): - """ Returned pretty formated value to display """ +def pretty_format_value(value, encoding="utf8", prefix=None): + """Returned pretty formated value to display""" if isinstance(value, dict): return pretty_format_dict(value, encoding=encoding, prefix=prefix) if isinstance(value, list): @@ -22,10 +22,10 @@ def pretty_format_value(value, encoding='utf8', prefix=None): return f"'{value}'" if value is None: return "None" - return f'{value} ({type(value)})' + return f"{value} ({type(value)})" -def pretty_format_value_in_list(value, encoding='utf8', prefix=None): +def pretty_format_value_in_list(value, encoding="utf8", prefix=None): """ Returned pretty formated value to display in list @@ -34,42 +34,31 @@ def pretty_format_value_in_list(value, encoding='utf8', prefix=None): """ prefix = prefix if prefix else "" value = pretty_format_value(value, encoding, prefix) - if '\n' in value: + if "\n" in value: inc_prefix = increment_prefix(prefix) - value = "\n" + "\n".join([ - inc_prefix + line - for line in value.split('\n') - ]) + value = "\n" + "\n".join([inc_prefix + line for line in value.split("\n")]) return value -def pretty_format_dict(value, encoding='utf8', prefix=None): - """ Returned pretty formated dict to display """ +def pretty_format_dict(value, encoding="utf8", prefix=None): + """Returned pretty formated dict to display""" prefix = prefix if prefix else "" result = [] for key in sorted(value.keys()): result.append( - f'{prefix}- {key} : ' - + pretty_format_value_in_list( - value[key], - encoding=encoding, - prefix=prefix - ) + f"{prefix}- {key} : " + + pretty_format_value_in_list(value[key], encoding=encoding, prefix=prefix) ) return "\n".join(result) -def pretty_format_list(row, encoding='utf8', prefix=None): - """ Returned pretty formated list to display """ +def pretty_format_list(row, encoding="utf8", prefix=None): + """Returned pretty formated list to display""" prefix = prefix if prefix else "" result = [] for idx, values in enumerate(row): result.append( - f'{prefix}- #{idx} : ' - + pretty_format_value_in_list( - values, - encoding=encoding, - prefix=prefix - ) + f"{prefix}- #{idx} : " + + pretty_format_value_in_list(values, encoding=encoding, prefix=prefix) ) return "\n".join(result) diff --git a/mylib/config.py b/mylib/config.py index c7dbfba..df0c461 100644 --- a/mylib/config.py +++ b/mylib/config.py @@ -1,12 +1,8 @@ -# -*- coding: utf-8 -*- # pylint: disable=too-many-lines """ Configuration & options parser """ import argparse -from configparser import ConfigParser -from getpass import getpass -from logging.config import fileConfig import logging import os import re @@ -14,26 +10,38 @@ import stat import sys import textwrap import traceback +from configparser import ConfigParser +from getpass import getpass +from logging.config import fileConfig import argcomplete import keyring from systemd.journal import JournalHandler - log = logging.getLogger(__name__) # Constants -DEFAULT_ENCODING = 'utf-8' -DEFAULT_CONFIG_DIRPATH = os.path.expanduser('./') -DEFAULT_CONSOLE_LOG_FORMAT = '%(asctime)s - %(module)s:%(lineno)d - %(levelname)s - %(message)s' +DEFAULT_ENCODING = "utf-8" +DEFAULT_CONFIG_DIRPATH = os.path.expanduser("./") +DEFAULT_CONSOLE_LOG_FORMAT = "%(asctime)s - %(module)s:%(lineno)d - %(levelname)s - %(message)s" class BaseOption: # pylint: disable=too-many-instance-attributes - """ Base configuration option class """ + """Base configuration option class""" - def __init__(self, config, section, name, default=None, comment=None, - arg=None, short_arg=None, arg_help=None, no_arg=False): + def __init__( + self, + config, + section, + name, + default=None, + comment=None, + arg=None, + short_arg=None, + arg_help=None, + no_arg=False, + ): self.config = config self.section = section self.name = name @@ -47,53 +55,43 @@ class BaseOption: # pylint: disable=too-many-instance-attributes @property def _isset_in_options(self): - """ Check if option is defined in registered arguments parser options """ - return ( - self.config.options - and not self.no_arg - and self._from_options != self.default - ) + """Check if option is defined in registered arguments parser options""" + return self.config.options and not self.no_arg and self._from_options != self.default @property def _from_options(self): - """ Get option from arguments parser options """ + """Get option from arguments parser options""" value = ( getattr(self.config.options, self.parser_dest) - if self.config.options and not self.no_arg else None - ) - log.debug( - '_from_options(%s, %s) = %s', - self.section.name, self.name, value + if self.config.options and not self.no_arg + else None ) + log.debug("_from_options(%s, %s) = %s", self.section.name, self.name, value) return value @property def _isset_in_config_file(self): - """ Check if option is defined in the loaded configuration file """ - return ( - self.config.config_filepath - and self.config.config_parser.has_option(self.section.name, self.name) + """Check if option is defined in the loaded configuration file""" + return self.config.config_filepath and self.config.config_parser.has_option( + self.section.name, self.name ) @property def _from_config(self): - """ Get option value from ConfigParser """ + """Get option value from ConfigParser""" return self.config.config_parser.get(self.section.name, self.name) @property def _default_in_config(self): - """ Get option default value considering current value from configuration """ - return ( - self._from_config if self._isset_in_config_file - else self.default - ) + """Get option default value considering current value from configuration""" + return self._from_config if self._isset_in_config_file else self.default def isset(self): - """ Check if option is defined in the loaded configuration file """ + """Check if option is defined in the loaded configuration file""" return self._isset_in_config_file or self._isset_in_options def get(self): - """ Get option value from options, config or default """ + """Get option value from options, config or default""" if self._isset_in_options and not self._set: return self._from_options if self._isset_in_config_file: @@ -101,49 +99,45 @@ class BaseOption: # pylint: disable=too-many-instance-attributes return self.default def set(self, value): - """ Set option value to config file and options """ - if value == '': + """Set option value to config file and options""" + if value == "": value = None if value == self.default or value is None: # Remove option from config (is section exists) if self.config.config_parser.has_section(self.section.name): - self.config.config_parser.remove_option( - self.section.name, self.name) + self.config.config_parser.remove_option(self.section.name, self.name) else: # Store option to config if not self.config.config_parser.has_section(self.section.name): self.config.config_parser.add_section(self.section.name) - self.config.config_parser.set( - self.section.name, self.name, self.to_config(value) - ) + self.config.config_parser.set(self.section.name, self.name, self.to_config(value)) self._set = True @property def parser_action(self): - """ Get action as accept by argparse.ArgumentParser """ - return 'store' + """Get action as accept by argparse.ArgumentParser""" + return "store" @property def parser_type(self): - """ Get type as handle by argparse.ArgumentParser """ + """Get type as handle by argparse.ArgumentParser""" return str @property def parser_dest(self): - """ Get option name in arguments parser options """ - return f'{self.section.name}_{self.name}' + """Get option name in arguments parser options""" + return f"{self.section.name}_{self.name}" @property def parser_help(self): - """ Get option help message in arguments parser options """ + """Get option help message in arguments parser options""" if self.arg_help and self.default is not None: # pylint: disable=consider-using-f-string - return '{0} (Default: {1})'.format( - self.arg_help, - re.sub(r'%([^%])', r'%%\1', str(self._default_in_config)) + return "{} (Default: {})".format( + self.arg_help, re.sub(r"%([^%])", r"%%\1", str(self._default_in_config)) ) if self.arg_help: return self.arg_help @@ -151,14 +145,13 @@ class BaseOption: # pylint: disable=too-many-instance-attributes @property def parser_argument_name(self): - """ Get option argument name in parser options """ + """Get option argument name in parser options""" return ( - self.arg if self.arg else - f'--{self.section.name}-{self.name}'.lower().replace('_', '-') + self.arg if self.arg else f"--{self.section.name}-{self.name}".lower().replace("_", "-") ) def add_option_to_parser(self, section_opts): - """ Add option to arguments parser """ + """Add option to arguments parser""" if self.no_arg: return args = [self.parser_argument_name] @@ -171,53 +164,53 @@ class BaseOption: # pylint: disable=too-many-instance-attributes default=self.default, ) if self.parser_type: # pylint: disable=using-constant-test - kwargs['type'] = self.parser_type + kwargs["type"] = self.parser_type log.debug( - 'add_option_to_parser(%s, %s): argument name(s)=%s / kwargs=%s', - self.section.name, self.name, ', '.join(args), kwargs + "add_option_to_parser(%s, %s): argument name(s)=%s / kwargs=%s", + self.section.name, + self.name, + ", ".join(args), + kwargs, ) section_opts.add_argument(*args, **kwargs) def to_config(self, value=None): - """ Format value as stored in configuration file """ + """Format value as stored in configuration file""" value = value if value is not None else self.get() - return '' if value is None else str(value) + return "" if value is None else str(value) def export_to_config(self): - """ Export option to configuration file """ + """Export option to configuration file""" lines = [] if self.comment: - lines.append(f'# {self.comment}') + lines.append(f"# {self.comment}") value = self.to_config() - default_value = ( - '' if self.default is None else - self.to_config(self.default) - ) + default_value = "" if self.default is None else self.to_config(self.default) log.debug( - 'export_to_config(%s, %s): value=%s / default=%s', - self.section.name, self.name, value, default_value) + "export_to_config(%s, %s): value=%s / default=%s", + self.section.name, + self.name, + value, + default_value, + ) if default_value: - if isinstance(default_value, str) and '\n' in default_value: - lines.append('# Default:') - lines.extend([f'# {line}' for line in default_value.split('\n')]) + if isinstance(default_value, str) and "\n" in default_value: + lines.append("# Default:") + lines.extend([f"# {line}" for line in default_value.split("\n")]) else: - lines.append( - f'# Default: {default_value}' - ) + lines.append(f"# Default: {default_value}") if value and value != default_value: - if isinstance(value, str) and '\n' in value: - value_lines = value.split('\n') - lines.append(f'# Default: {value_lines[0]}') + if isinstance(value, str) and "\n" in value: + value_lines = value.split("\n") + lines.append(f"# Default: {value_lines[0]}") lines.extend([f' {line.replace("#", "%(hash)s")}' for line in value_lines[1:]]) else: - lines.append( - f'{self.name} = {value}' - ) + lines.append(f"{self.name} = {value}") else: - lines.append(f'# {self.name} =') - lines.append('') - return '\n'.join(lines) + lines.append(f"# {self.name} =") + lines.append("") + return "\n".join(lines) @staticmethod def _get_user_input(prompt): @@ -228,21 +221,21 @@ class BaseOption: # pylint: disable=too-many-instance-attributes return input(prompt) def _ask_value(self, prompt=None, **kwargs): - """ Ask to user to enter value of this option and return it """ + """Ask to user to enter value of this option and return it""" if self.comment: - print(f'# {self.comment}') - default_value = kwargs.get('default_value', self.get()) + print(f"# {self.comment}") + default_value = kwargs.get("default_value", self.get()) if not prompt: - prompt = f'{self.name}: ' + prompt = f"{self.name}: " if default_value is not None: - if isinstance(default_value, str) and '\n' in default_value: + if isinstance(default_value, str) and "\n" in default_value: prompt += "[\n %s\n] " % "\n ".join( - self.to_config(default_value).split('\n') + self.to_config(default_value).split("\n") ) else: - prompt += f'[{self.to_config(default_value)}] ' + prompt += f"[{self.to_config(default_value)}] " value = self._get_user_input(prompt) - return default_value if value == '' else value + return default_value if value == "" else value def ask_value(self, set_it=True): """ @@ -256,24 +249,24 @@ class BaseOption: # pylint: disable=too-many-instance-attributes class StringOption(BaseOption): - """ String configuration option class """ + """String configuration option class""" class BooleanOption(BaseOption): - """ Boolean configuration option class """ + """Boolean configuration option class""" @property def _from_config(self): - """ Get option value from ConfigParser """ + """Get option value from ConfigParser""" return self.config.config_parser.getboolean(self.section.name, self.name) def to_config(self, value=None): - """ Format value as stored in configuration file """ + """Format value as stored in configuration file""" return super().to_config(value).lower() @property def _isset_in_options(self): - """ Check if option is defined in registered arguments parser options """ + """Check if option is defined in registered arguments parser options""" return ( self.config.options and not self.no_arg @@ -282,11 +275,8 @@ class BooleanOption(BaseOption): @property def _from_options(self): - """ Get option from arguments parser options """ - return ( - not self._default_in_config if self._isset_in_options - else self._default_in_config - ) + """Get option from arguments parser options""" + return not self._default_in_config if self._isset_in_options else self._default_in_config @property def parser_action(self): @@ -298,43 +288,44 @@ class BooleanOption(BaseOption): @property def parser_argument_name(self): - """ Get option argument name in parser options """ + """Get option argument name in parser options""" # pylint: disable=consider-using-f-string return ( - self.arg if self.arg else - '--{0}-{1}-{2}'.format( - self.section.name, - 'enable' if not self._default_in_config else 'disable', - self.name - ).lower().replace('_', '-') + self.arg + if self.arg + else "--{}-{}-{}".format( + self.section.name, "enable" if not self._default_in_config else "disable", self.name + ) + .lower() + .replace("_", "-") ) def _ask_value(self, prompt=None, **kwargs): - """ Ask to user to enter value of this option and return it """ + """Ask to user to enter value of this option and return it""" default_value = self.get() - prompt = f'{self.name}: ' + prompt = f"{self.name}: " if default_value: - prompt += '[Y/n] ' + prompt += "[Y/n] " else: - prompt += '[y/N] ' + prompt += "[y/N] " while True: value = super()._ask_value(prompt, **kwargs) - if value in ['', None, default_value]: + if value in ["", None, default_value]: return default_value - if value.lower() == 'y': + if value.lower() == "y": return True - if value.lower() == 'n': + if value.lower() == "n": return False - print('Invalid answer. Possible values: Y or N (case insensitive)') + print("Invalid answer. Possible values: Y or N (case insensitive)") class FloatOption(BaseOption): - """ Float configuration option class """ + """Float configuration option class""" @property def _from_config(self): - """ Get option value from ConfigParser """ + """Get option value from ConfigParser""" return self.config.config_parser.getfloat(self.section.name, self.name) @property @@ -342,11 +333,11 @@ class FloatOption(BaseOption): return float def _ask_value(self, prompt=None, **kwargs): - """ Ask to user to enter value of this option and return it """ + """Ask to user to enter value of this option and return it""" default_value = self.get() while True: value = super()._ask_value(prompt, **kwargs) - if value in ['', None, default_value]: + if value in ["", None, default_value]: return default_value try: return float(value) @@ -355,58 +346,55 @@ class FloatOption(BaseOption): class IntegerOption(BaseOption): - """ Integer configuration option class """ + """Integer configuration option class""" @property def _from_config(self): - """ Get option value from ConfigParser """ + """Get option value from ConfigParser""" return self.config.config_parser.getint(self.section.name, self.name) def to_config(self, value=None): - """ Format value as stored in configuration file """ + """Format value as stored in configuration file""" value = value if value is not None else self.get() - return str(int(value)) if value is not None else '' + return str(int(value)) if value is not None else "" @property def parser_type(self): return int def _ask_value(self, prompt=None, **kwargs): - """ Ask to user to enter value of this option and return it """ - default_value = kwargs.pop('default_value', self.get()) + """Ask to user to enter value of this option and return it""" + default_value = kwargs.pop("default_value", self.get()) while True: value = super()._ask_value(prompt, default_value=default_value, **kwargs) - if value in ['', None, default_value]: + if value in ["", None, default_value]: return default_value try: return int(value) except ValueError: - print('Invalid answer. Must a integer value') + print("Invalid answer. Must a integer value") class PasswordOption(StringOption): - """ Password configuration option class """ + """Password configuration option class""" def __init__(self, *arg, username_option=None, keyring_value=None, **kwargs): super().__init__(*arg, **kwargs) self.username_option = username_option - self.keyring_value = keyring_value if keyring_value is not None else 'keyring' + self.keyring_value = keyring_value if keyring_value is not None else "keyring" @property def _keyring_service_name(self): - """ Return keyring service name """ - return '.'.join([ - self.config.shortname, - self.section.name, self.name - ]) + """Return keyring service name""" + return ".".join([self.config.shortname, self.section.name, self.name]) @property def _keyring_username(self): - """ Return keyring username """ + """Return keyring username""" return self.section.get(self.username_option) if self.username_option else self.name def get(self): - """ Get option value """ + """Get option value""" value = super().get() if value != self.keyring_value: @@ -414,53 +402,48 @@ class PasswordOption(StringOption): service_name = self._keyring_service_name username = self._keyring_username - log.debug( - 'Retreive password %s for username=%s from keyring', - service_name, username - ) + log.debug("Retreive password %s for username=%s from keyring", service_name, username) value = keyring.get_password(service_name, username) if value is None: # pylint: disable=consider-using-f-string value = getpass( - 'Please enter {0}{1}: '.format( - f'{self.section.name} {self.name}', - f' for {username}' if username != self.name else '' + "Please enter {}{}: ".format( + f"{self.section.name} {self.name}", + f" for {username}" if username != self.name else "", ) ) keyring.set_password(service_name, username, value) return value def to_config(self, value=None): - """ Format value as stored in configuration file """ + """Format value as stored in configuration file""" if super().get() == self.keyring_value: return self.keyring_value return super().to_config(value) def set(self, value, use_keyring=None): # pylint: disable=arguments-differ - """ Set option value to config file """ + """Set option value to config file""" if (use_keyring is None and super().get() == self.keyring_value) or use_keyring: - keyring.set_password( - self._keyring_service_name, self._keyring_username, - value) + keyring.set_password(self._keyring_service_name, self._keyring_username, value) value = self.keyring_value super().set(value) def _ask_value(self, prompt=None, **kwargs): - """ Ask to user to enter value of this option and return it """ + """Ask to user to enter value of this option and return it""" if self.comment: - print('# ' + self.comment) - default_value = kwargs.pop('default_value', self.get()) + print("# " + self.comment) + default_value = kwargs.pop("default_value", self.get()) if not prompt: - prompt = f'{self.name}: ' + prompt = f"{self.name}: " if default_value is not None: # Hide value only if it differed from default value if default_value == self.default: - prompt += f'[{default_value}] ' + prompt += f"[{default_value}] " else: - prompt += '[secret defined, leave to empty to keep it as unchange] ' + prompt += "[secret defined, leave to empty to keep it as unchange] " value = getpass(prompt) - return default_value if value == '' else value + return default_value if value == "" else value def ask_value(self, set_it=True): """ @@ -470,27 +453,26 @@ class PasswordOption(StringOption): value = self._ask_value() if set_it: use_keyring = None - default_use_keyring = (super().get() == self.keyring_value) + default_use_keyring = super().get() == self.keyring_value while use_keyring is None: prompt = ( - 'Do you want to use XDG keyring ? ' - f"[{'Y/n' if default_use_keyring else 'y/N'}] " + f"Do you want to use XDG keyring ? [{'Y/n' if default_use_keyring else 'y/N'}] " ) result = input(prompt).lower() - if result == '': + if result == "": use_keyring = default_use_keyring - elif result == 'y': + elif result == "y": use_keyring = True - elif result == 'n': + elif result == "n": use_keyring = False else: - print('Invalid answer. Possible values: Y or N (case insensitive)') + print("Invalid answer. Possible values: Y or N (case insensitive)") return self.set(value, use_keyring=use_keyring) return value class ConfigSection: - """ Configuration section class """ + """Configuration section class""" def __init__(self, config, name, comment=None, order=None): self.config = config @@ -507,48 +489,47 @@ class ConfigSection: :param name: Option name :param **kwargs: Dict of raw option for type class """ - assert not self.defined(name), f'Duplicated option {name}' + assert not self.defined(name), f"Duplicated option {name}" self.options[name] = _type(self.config, self, name, **kwargs) return self.options[name] def defined(self, option): - """ Check if option is defined """ + """Check if option is defined""" return option in self.options def isset(self, option): - """ Check if option is set """ + """Check if option is set""" return self.defined(option) and self.options[option].isset() def get(self, option): - """ Get option value """ - assert self.defined(option), f'Option {option} unknown' + """Get option value""" + assert self.defined(option), f"Option {option} unknown" return self.options[option].get() def set(self, option, value): - """ Set option value """ - assert self.defined(option), f'Option {option} unknown' + """Set option value""" + assert self.defined(option), f"Option {option} unknown" return self.options[option].set(value) def add_options_to_parser(self, parser): - """ Add section to argparse.ArgumentParser """ + """Add section to argparse.ArgumentParser""" assert isinstance(parser, argparse.ArgumentParser) section_opts = parser.add_argument_group( - self.comment if self.comment - else self.name.capitalize() + self.comment if self.comment else self.name.capitalize() ) for option in self.options: # pylint: disable=consider-using-dict-items self.options[option].add_option_to_parser(section_opts) def export_to_config(self): - """ Export section and their options to configuration file """ + """Export section and their options to configuration file""" lines = [] if self.comment: - lines.append(f'# {self.comment}') - lines.append(f'[{self.name}]') + lines.append(f"# {self.comment}") + lines.append(f"[{self.name}]") for option in self.options: # pylint: disable=consider-using-dict-items lines.append(self.options[option].export_to_config()) - return '\n'.join(lines) + return "\n".join(lines) def ask_values(self, set_it=True): """ @@ -562,8 +543,8 @@ class ConfigSection: :rtype: bool of dict """ if self.comment: - print(f'# {self.comment}') - print(f'[{self.name}]''\n') + print(f"# {self.comment}") + print(f"[{self.name}]\n") result = {} error = False for name, option in self.options.items(): @@ -592,7 +573,7 @@ class RawWrappedTextHelpFormatter(argparse.RawDescriptionHelpFormatter): result.append(line) continue # Split ident prefix and line text - m = re.match('^( *)(.*)$', line) + m = re.match("^( *)(.*)$", line) ident = m.group(1) line_text = m.group(2) # Wrap each lines and add in result with ident prefix @@ -602,14 +583,21 @@ class RawWrappedTextHelpFormatter(argparse.RawDescriptionHelpFormatter): class Config: # pylint: disable=too-many-instance-attributes - """ Configuration helper """ + """Configuration helper""" - def __init__(self, appname, shortname=None, version=None, encoding=None, - config_file_env_variable=None, default_config_dirpath=None): + def __init__( + self, + appname, + shortname=None, + version=None, + encoding=None, + config_file_env_variable=None, + default_config_dirpath=None, + ): self.appname = appname self.shortname = shortname - self.version = version if version else '0.0' - self.encoding = encoding if encoding else 'utf-8' + self.version = version if version else "0.0" + self.encoding = encoding if encoding else "utf-8" self.config_parser = None self.options_parser = None self.options = None @@ -626,11 +614,12 @@ class Config: # pylint: disable=too-many-instance-attributes Add section : param name: The section name - : param loaded_callback: An optional callback method that will be executed after configuration is loaded - The specified callback method will receive Config object as parameter. + : param loaded_callback: An optional callback method that will be executed after + configuration is loaded. The specified callback method will receive + Config object as parameter. : param ** kwargs: Raw parameters dict pass to ConfigSection __init__() method """ - assert name not in self.sections, f'Duplicated section {name}' + assert name not in self.sections, f"Duplicated section {name}" self.sections[name] = ConfigSection(self, name, **kwargs) if loaded_callback: @@ -641,45 +630,43 @@ class Config: # pylint: disable=too-many-instance-attributes return self.sections[name] def defined(self, section, option): - """ Check option is defined in specified section """ + """Check option is defined in specified section""" return section in self.sections and self.sections[section].defined(option) def isset(self, section, option): - """ Check option is set in specified section """ + """Check option is set in specified section""" return section in self.sections and self.sections[section].isset(option) def get(self, section, option): - """ Get option value """ - assert self.defined( - section, option), f'Unknown option {section}.{option}' + """Get option value""" + assert self.defined(section, option), f"Unknown option {section}.{option}" value = self.sections[section].get(option) - log.debug('get(%s, %s): %s (%s)', section, option, value, type(value)) + log.debug("get(%s, %s): %s (%s)", section, option, value, type(value)) return value def get_option(self, option, default=None): - """ Get an argument parser option value """ + """Get an argument parser option value""" if self.options and hasattr(self.options, option): return getattr(self.options, option) return default def __getitem__(self, key): - assert key in self.sections, f'Unknown section {key}' + assert key in self.sections, f"Unknown section {key}" return ConfigSectionAsDictWrapper(self.sections[key]) def set(self, section, option, value): - """ Set option value """ - assert self.defined( - section, option), f'Unknown option {section}.{option}' + """Set option value""" + assert self.defined(section, option), f"Unknown option {section}.{option}" self._init_config_parser() self.sections[section].set(option, value) def _init_config_parser(self, force=False): - """ Initialize ConfigParser object """ + """Initialize ConfigParser object""" if not self.config_parser or force: - self.config_parser = ConfigParser(defaults={'hash': '#'}) + self.config_parser = ConfigParser(defaults={"hash": "#"}) def load_file(self, filepath, execute_callback=True): - """ Read configuration file """ + """Read configuration file""" self._init_config_parser(force=True) self._filepath = filepath @@ -694,20 +681,16 @@ class Config: # pylint: disable=too-many-instance-attributes self.config_parser.read(filepath, encoding=self.encoding) except Exception: # pylint: disable=broad-except self._init_config_parser(force=True) - log.exception('Failed to read configuration file %s', filepath) + log.exception("Failed to read configuration file %s", filepath) return False # Logging initialization - if self.config_parser.has_section('loggers'): + if self.config_parser.has_section("loggers"): fileConfig(filepath) else: # Otherwise, use systemd journal handler handler = JournalHandler(SYSLOG_IDENTIFIER=self.shortname) - handler.setFormatter( - logging.Formatter( - '%(levelname)s | %(name)s | %(message)s' - ) - ) + handler.setFormatter(logging.Formatter("%(levelname)s | %(name)s | %(message)s")) logging.getLogger().addHandler(handler) self._filepath = filepath @@ -718,7 +701,7 @@ class Config: # pylint: disable=too-many-instance-attributes return True def _loaded(self): - """ Execute loaded callbacks """ + """Execute loaded callbacks""" error = False for callback in self._loaded_callbacks: if callback in self._loaded_callbacks_executed: @@ -729,9 +712,9 @@ class Config: # pylint: disable=too-many-instance-attributes return not error def save(self, filepath=None): - """ Save configuration file """ + """Save configuration file""" filepath = filepath if filepath else self._filepath - assert filepath, 'Configuration filepath is not set or provided' + assert filepath, "Configuration filepath is not set or provided" # Checking access of target file/directory dirpath = os.path.dirname(filepath) @@ -740,115 +723,115 @@ class Config: # pylint: disable=too-many-instance-attributes log.error('Configuration file "%s" is not writable', filepath) return False elif not os.path.isdir(dirpath) or not os.access(dirpath, os.R_OK | os.W_OK | os.X_OK): - log.error( - 'Configuration directory "%s" does not exist (or not writable)', dirpath) + log.error('Configuration directory "%s" does not exist (or not writable)', dirpath) return False - lines = [ - '#\n' - f'# {self.appname} configuration' - '\n#\n' - ] + lines = [f"#\n# {self.appname} configuration\n#\n"] for section_name in self._ordered_section_names: - lines.append('') + lines.append("") lines.append(self.sections[section_name].export_to_config()) try: - with open(filepath, 'wb') as fd: - fd.write( - '\n'.join(lines).encode(self.encoding) - ) + with open(filepath, "wb") as fd: + fd.write("\n".join(lines).encode(self.encoding)) # Privacy! os.chmod(filepath, stat.S_IRUSR | stat.S_IWUSR) except Exception: # pylint: disable=broad-except - log.exception( - 'Failed to write generated configuration file %s', filepath) + log.exception("Failed to write generated configuration file %s", filepath) return False self.load_file(filepath) return True @property def _ordered_section_names(self): - """ Get ordered list of section names """ + """Get ordered list of section names""" return sorted(self.sections.keys(), key=lambda section: self.sections[section].order) def get_arguments_parser(self, reconfigure=False, **kwargs): - """ Get arguments parser """ + """Get arguments parser""" if self.options_parser: return self.options_parser self.options_parser = argparse.ArgumentParser( - description=kwargs.pop('description', self.appname), + description=kwargs.pop("description", self.appname), formatter_class=RawWrappedTextHelpFormatter, - **kwargs) - - config_file_help = ( - f'Configuration file to use (default: {self.config_filepath})' + **kwargs, ) + + config_file_help = f"Configuration file to use (default: {self.config_filepath})" if self.config_file_env_variable: config_file_help += ( - '\n\nYou also could set ' - f'{self.config_file_env_variable} environment variable to ' - 'specify your configuration file path.' + "\n\nYou also could set " + f"{self.config_file_env_variable} environment variable to " + "specify your configuration file path." ) self.options_parser.add_argument( - '-c', - '--config', - default=self.config_filepath, - help=config_file_help + "-c", "--config", default=self.config_filepath, help=config_file_help ) self.options_parser.add_argument( - '--save', - action='store_true', - dest='save', - help='Save current configuration to file', + "--save", + action="store_true", + dest="save", + help="Save current configuration to file", ) if reconfigure: self.options_parser.add_argument( - '--reconfigure', - action='store_true', - dest='mylib_config_reconfigure', - help='Reconfigure and update configuration file', + "--reconfigure", + action="store_true", + dest="mylib_config_reconfigure", + help="Reconfigure and update configuration file", ) self.options_parser.add_argument( - '-d', - '--debug', - action='store_true', - help='Show debug messages' + "-d", "--debug", action="store_true", help="Show debug messages" ) self.options_parser.add_argument( - '-v', - '--verbose', - action='store_true', - help='Show verbose messages' + "-v", "--verbose", action="store_true", help="Show verbose messages" ) - section = self.add_section('console', comment='Console logging') + section = self.add_section("console", comment="Console logging") section.add_option( - BooleanOption, 'enabled', default=False, - arg='--console', short_arg='-C', - comment='Enable/disable console log') + BooleanOption, + "enabled", + default=False, + arg="--console", + short_arg="-C", + comment="Enable/disable console log", + ) section.add_option( - BooleanOption, 'force_stderr', default=False, - arg='--console-stderr', - comment='Force console log on stderr') + BooleanOption, + "force_stderr", + default=False, + arg="--console-stderr", + comment="Force console log on stderr", + ) section.add_option( - StringOption, 'log_format', default=DEFAULT_CONSOLE_LOG_FORMAT, - arg='--console-log-format', comment='Console log format') + StringOption, + "log_format", + default=DEFAULT_CONSOLE_LOG_FORMAT, + arg="--console-log-format", + comment="Console log format", + ) self.add_options_to_parser(self.options_parser) return self.options_parser - def parse_arguments_options(self, argv=None, parser=None, create=True, ask_values=True, - exit_after_created=True, execute_callback=True, - hardcoded_options=None): + def parse_arguments_options( + self, + argv=None, + parser=None, + create=True, + ask_values=True, + exit_after_created=True, + execute_callback=True, + hardcoded_options=None, + ): """ Parse arguments options @@ -879,10 +862,7 @@ class Config: # pylint: disable=too-many-instance-attributes already_saved = False if not os.path.isfile(options.config) and (create or options.save): - log.warning( - "Configuration file is missing, generate it (%s)", - options.config - ) + log.warning("Configuration file is missing, generate it (%s)", options.config) if ask_values: self.ask_values(set_it=True) self.save(options.config) @@ -891,10 +871,10 @@ class Config: # pylint: disable=too-many-instance-attributes already_saved = True # Load configuration file - if os.path.isfile(options.config) and not self.load_file(options.config, execute_callback=False): - parser.error( - f'Failed to load configuration from file {options.config}' - ) + if os.path.isfile(options.config) and not self.load_file( + options.config, execute_callback=False + ): + parser.error(f"Failed to load configuration from file {options.config}") if options.save and not already_saved: self.save() @@ -907,27 +887,28 @@ class Config: # pylint: disable=too-many-instance-attributes if hardcoded_options: assert isinstance(hardcoded_options, list), ( - 'hardcoded_options must be a list of tuple of 3 elements: ' - 'the section and the option names and the value.') + "hardcoded_options must be a list of tuple of 3 elements: " + "the section and the option names and the value." + ) for opt_info in hardcoded_options: assert isinstance(opt_info, tuple) and len(opt_info) == 3, ( - 'Invalid hard-coded option value: it must be a tuple of 3 ' - 'elements: the section and the option names and the value.' + "Invalid hard-coded option value: it must be a tuple of 3 " + "elements: the section and the option names and the value." ) self.set(*opt_info) - if self.get('console', 'enabled'): + if self.get("console", "enabled"): stdout_console_handler = logging.StreamHandler( - sys.stderr if self.get('console', 'force_stderr') - else sys.stdout) + sys.stderr if self.get("console", "force_stderr") else sys.stdout + ) stdout_console_handler.addFilter(StdoutInfoFilter()) stdout_console_handler.setLevel(logging.DEBUG) stderr_console_handler = logging.StreamHandler(sys.stderr) stderr_console_handler.setLevel(logging.WARNING) - if self.get('console', 'log_format'): - console_formater = logging.Formatter(self.get('console', 'log_format')) + if self.get("console", "log_format"): + console_formater = logging.Formatter(self.get("console", "log_format")) stdout_console_handler.setFormatter(console_formater) stderr_console_handler.setFormatter(console_formater) @@ -937,7 +918,7 @@ class Config: # pylint: disable=too-many-instance-attributes if execute_callback: self._loaded() - if self.get_option('mylib_config_reconfigure', default=False): + if self.get_option("mylib_config_reconfigure", default=False): if self.ask_values(set_it=True) and self.save(): sys.exit(0) sys.exit(1) @@ -945,15 +926,15 @@ class Config: # pylint: disable=too-many-instance-attributes return options def load_options(self, options, execute_callback=True): - """ Register arguments parser options """ + """Register arguments parser options""" assert isinstance(options, argparse.Namespace) self.options = options - log.debug('Argument options: %s', options) + log.debug("Argument options: %s", options) if execute_callback: self._loaded() def add_options_to_parser(self, parser): - """ Add sections and their options to parser """ + """Add sections and their options to parser""" for section in self._ordered_section_names: self.sections[section].add_options_to_parser(parser) @@ -998,63 +979,70 @@ class Config: # pylint: disable=too-many-instance-attributes ) parser.add_argument( - '-i', '--interactive', - action='store_true', dest='interactive', - help="Enable configuration interactive mode" + "-i", + "--interactive", + action="store_true", + dest="interactive", + help="Enable configuration interactive mode", ) parser.add_argument( - '-O', '--overwrite', - action='store_true', dest='overwrite', - help="Overwrite configuration file if exists" + "-O", + "--overwrite", + action="store_true", + dest="overwrite", + help="Overwrite configuration file if exists", ) parser.add_argument( - '-V', '--validate', - action='store_true', dest='validate', + "-V", + "--validate", + action="store_true", + dest="validate", help=( - "Validate configuration: initialize application to test if provided parameters works.\n\n" - "Note: Validation will occured after configuration file creation or update. On error, " - "re-run with -O/--overwrite parameter to fix it." - ) + "Validate configuration: initialize application to test if provided parameters" + " works.\n\nNote: Validation will occured after configuration file creation or" + " update. On error, re-run with -O/--overwrite parameter to fix it." + ), ) - options = self.parse_arguments_options( - argv, create=False, execute_callback=False) + options = self.parse_arguments_options(argv, create=False, execute_callback=False) if os.path.exists(options.config) and not options.overwrite: - print(f'Configuration file {options.config} already exists') + print(f"Configuration file {options.config} already exists") sys.exit(1) if options.interactive: self.ask_values(set_it=True) if self.save(options.config): - print(f'Configuration file {options.config} created.') + print(f"Configuration file {options.config} created.") if options.validate: - print('Validate your configuration...') + print("Validate your configuration...") try: if self._loaded(): - print('Your configuration seem valid.') + print("Your configuration seem valid.") else: - print('Error(s) occurred validating your configuration. See logs for details.') + print( + "Error(s) occurred validating your configuration. See logs for details." + ) sys.exit(1) except Exception: # pylint: disable=broad-except print( - 'Exception occurred validating your configuration:\n' - f'{traceback.format_exc()}' - '\n\nSee logs for details.' + "Exception occurred validating your configuration:\n" + f"{traceback.format_exc()}" + "\n\nSee logs for details." ) sys.exit(2) else: - print(f'Error occured creating configuration file {options.config}') + print(f"Error occured creating configuration file {options.config}") sys.exit(1) sys.exit(0) @property def config_dir(self): - """ Retrieve configuration directory path """ + """Retrieve configuration directory path""" if self._filepath: return os.path.dirname(self._filepath) if self.default_config_dirpath: @@ -1063,15 +1051,13 @@ class Config: # pylint: disable=too-many-instance-attributes @property def config_filepath(self): - """ Retrieve configuration file path """ + """Retrieve configuration file path""" if self._filepath: return self._filepath if self.config_file_env_variable and os.environ.get(self.config_file_env_variable): return os.environ.get(self.config_file_env_variable) return os.path.join( - self.config_dir, - f'{self.shortname}.ini' if self.shortname - else 'config.ini' + self.config_dir, f"{self.shortname}.ini" if self.shortname else "config.ini" ) @@ -1104,11 +1090,11 @@ class ConfigurableObject: _config = None _config_section = None - def __init__(self, options=None, options_prefix=None, config=None, config_section=None, - **kwargs): - + def __init__( + self, options=None, options_prefix=None, config=None, config_section=None, **kwargs + ): for key, value in kwargs.items(): - assert key in self._defaults, f'Unknown {key} option' + assert key in self._defaults, f"Unknown {key} option" self._kwargs[key] = value if options: @@ -1116,9 +1102,9 @@ class ConfigurableObject: if options_prefix is not None: self._options_prefix = options_prefix elif self._config_name: - self._options_prefix = self._config_name + '_' + self._options_prefix = self._config_name + "_" else: - raise Exception(f'No configuration name defined for {__name__}') + raise Exception(f"No configuration name defined for {__name__}") if config: self._config = config @@ -1127,10 +1113,10 @@ class ConfigurableObject: elif self._config_name: self._config_section = self._config_name else: - raise Exception(f'No configuration name defined for {__name__}') + raise Exception(f"No configuration name defined for {__name__}") def _get_option(self, option, default=None, required=False): - """ Retreive option value """ + """Retreive option value""" if self._kwargs and option in self._kwargs: return self._kwargs[option] @@ -1140,21 +1126,26 @@ class ConfigurableObject: if self._config and self._config.defined(self._config_section, option): return self._config.get(self._config_section, option) - assert not required, f'Options {option} not defined' + assert not required, f"Options {option} not defined" return default if default is not None else self._defaults.get(option) - def configure(self, comment=None, ** kwargs): - """ Configure options on registered mylib.Config object """ - assert self._config, "mylib.Config object not registered. Must be passed to __init__ as config keyword argument." + def configure(self, comment=None, **kwargs): + """Configure options on registered mylib.Config object""" + assert self._config, ( + "mylib.Config object not registered. Must be passed to __init__ as config keyword" + " argument." + ) return self._config.add_section( self._config_section, comment=comment if comment else self._config_comment, - loaded_callback=self.initialize, **kwargs) + loaded_callback=self.initialize, + **kwargs, + ) def initialize(self, loaded_config=None): - """ Configuration initialized hook """ + """Configuration initialized hook""" if loaded_config: self.config = loaded_config # pylint: disable=attribute-defined-outside-init @@ -1177,11 +1168,12 @@ class ConfigSectionAsDictWrapper: self.__section.set(key, value) def __delitem__(self, key): - raise Exception('Deleting a configuration option is not supported') + raise Exception("Deleting a configuration option is not supported") # pylint: disable=too-few-public-methods class StdoutInfoFilter(logging.Filter): - """ Logging filter to keep messages only >= logging.INFO """ + """Logging filter to keep messages only >= logging.INFO""" + def filter(self, record): return record.levelno in (logging.DEBUG, logging.INFO) diff --git a/mylib/db.py b/mylib/db.py index 227208a..cfebaa1 100644 --- a/mylib/db.py +++ b/mylib/db.py @@ -1,10 +1,7 @@ -# -*- coding: utf-8 -*- - """ Basic SQL DB client """ import logging - log = logging.getLogger(__name__) @@ -12,8 +9,9 @@ log = logging.getLogger(__name__) # Exceptions # + class DBException(Exception): - """ That is the base exception class for all the other exceptions provided by this module. """ + """That is the base exception class for all the other exceptions provided by this module.""" def __init__(self, error, *args, **kwargs): for arg, value in kwargs.items(): @@ -29,7 +27,8 @@ class DBNotImplemented(DBException, RuntimeError): def __init__(self, method, class_name): super().__init__( "The method {method} is not yet implemented in class {class_name}", - method=method, class_name=class_name + method=method, + class_name=class_name, ) @@ -39,10 +38,7 @@ class DBFailToConnect(DBException, RuntimeError): """ def __init__(self, uri): - super().__init__( - "An error occured during database connection ({uri})", - uri=uri - ) + super().__init__("An error occured during database connection ({uri})", uri=uri) class DBDuplicatedSQLParameter(DBException, KeyError): @@ -53,8 +49,7 @@ class DBDuplicatedSQLParameter(DBException, KeyError): def __init__(self, parameter_name): super().__init__( - "Duplicated SQL parameter '{parameter_name}'", - parameter_name=parameter_name + "Duplicated SQL parameter '{parameter_name}'", parameter_name=parameter_name ) @@ -65,10 +60,7 @@ class DBUnsupportedWHEREClauses(DBException, TypeError): """ def __init__(self, where_clauses): - super().__init__( - "Unsupported WHERE clauses: {where_clauses}", - where_clauses=where_clauses - ) + super().__init__("Unsupported WHERE clauses: {where_clauses}", where_clauses=where_clauses) class DBInvalidOrderByClause(DBException, TypeError): @@ -79,13 +71,14 @@ class DBInvalidOrderByClause(DBException, TypeError): def __init__(self, order_by): super().__init__( - "Invalid ORDER BY clause: {order_by}. Must be a string or a list of two values (ordering field name and direction)", - order_by=order_by + "Invalid ORDER BY clause: {order_by}. Must be a string or a list of two values" + " (ordering field name and direction)", + order_by=order_by, ) class DB: - """ Database client """ + """Database client""" just_try = False @@ -93,14 +86,14 @@ class DB: self.just_try = just_try self._conn = None for arg, value in kwargs.items(): - setattr(self, f'_{arg}', value) + setattr(self, f"_{arg}", value) def connect(self, exit_on_error=True): - """ Connect to DB server """ - raise DBNotImplemented('connect', self.__class__.__name__) + """Connect to DB server""" + raise DBNotImplemented("connect", self.__class__.__name__) def close(self): - """ Close connection with DB server (if opened) """ + """Close connection with DB server (if opened)""" if self._conn: self._conn.close() self._conn = None @@ -110,12 +103,11 @@ class DB: log.debug( 'Run SQL query "%s" %s', sql, - "with params = {0}".format( # pylint: disable=consider-using-f-string - ', '.join([ - f'{key} = {value}' - for key, value in params.items() - ]) if params else "without params" - ) + "with params = {}".format( # pylint: disable=consider-using-f-string + ", ".join([f"{key} = {value}" for key, value in params.items()]) + if params + else "without params" + ), ) @staticmethod @@ -123,12 +115,11 @@ class DB: log.exception( 'Error during SQL query "%s" %s', sql, - "with params = {0}".format( # pylint: disable=consider-using-f-string - ', '.join([ - f'{key} = {value}' - for key, value in params.items() - ]) if params else "without params" - ) + "with params = {}".format( # pylint: disable=consider-using-f-string + ", ".join([f"{key} = {value}" for key, value in params.items()]) + if params + else "without params" + ), ) def doSQL(self, sql, params=None): @@ -141,7 +132,7 @@ class DB: :return: True on success, False otherwise :rtype: bool """ - raise DBNotImplemented('doSQL', self.__class__.__name__) + raise DBNotImplemented("doSQL", self.__class__.__name__) def doSelect(self, sql, params=None): """ @@ -153,7 +144,7 @@ class DB: :return: List of selected rows as dict on success, False otherwise :rtype: list, bool """ - raise DBNotImplemented('doSelect', self.__class__.__name__) + raise DBNotImplemented("doSelect", self.__class__.__name__) # # SQL helpers @@ -161,22 +152,20 @@ class DB: @staticmethod def _quote_table_name(table): - """ Quote table name """ - return '"{0}"'.format( # pylint: disable=consider-using-f-string - '"."'.join( - table.split('.') - ) + """Quote table name""" + return '"{}"'.format( # pylint: disable=consider-using-f-string + '"."'.join(table.split(".")) ) @staticmethod def _quote_field_name(field): - """ Quote table name """ + """Quote table name""" return f'"{field}"' @staticmethod def format_param(param): - """ Format SQL query parameter for prepared query """ - return f'%({param})s' + """Format SQL query parameter for prepared query""" + return f"%({param})s" @classmethod def _combine_params(cls, params, to_add=None, **kwargs): @@ -201,7 +190,8 @@ class DB: - a dict of WHERE clauses with field name as key and WHERE clause value as value - a list of any of previous valid WHERE clauses :param params: Dict of other already set SQL query parameters (optional) - :param where_op: SQL operator used to combine WHERE clauses together (optional, default: AND) + :param where_op: SQL operator used to combine WHERE clauses together (optional, default: + AND) :return: A tuple of two elements: raw SQL WHERE combined clauses and parameters on success :rtype: string, bool @@ -209,24 +199,27 @@ class DB: if params is None: params = {} if where_op is None: - where_op = 'AND' + where_op = "AND" if isinstance(where_clauses, str): return (where_clauses, params) - if isinstance(where_clauses, tuple) and len(where_clauses) == 2 and isinstance(where_clauses[1], dict): + if ( + isinstance(where_clauses, tuple) + and len(where_clauses) == 2 + and isinstance(where_clauses[1], dict) + ): cls._combine_params(params, where_clauses[1]) return (where_clauses[0], params) if isinstance(where_clauses, (list, tuple)): sql_where_clauses = [] for where_clause in where_clauses: - sql2, params = cls._format_where_clauses(where_clause, params=params, where_op=where_op) + sql2, params = cls._format_where_clauses( + where_clause, params=params, where_op=where_op + ) sql_where_clauses.append(sql2) - return ( - f' {where_op} '.join(sql_where_clauses), - params - ) + return (f" {where_op} ".join(sql_where_clauses), params) if isinstance(where_clauses, dict): sql_where_clauses = [] @@ -235,16 +228,13 @@ class DB: if field in params: idx = 1 while param in params: - param = f'{field}_{idx}' + param = f"{field}_{idx}" idx += 1 cls._combine_params(params, {param: value}) sql_where_clauses.append( - f'{cls._quote_field_name(field)} = {cls.format_param(param)}' + f"{cls._quote_field_name(field)} = {cls.format_param(param)}" ) - return ( - f' {where_op} '.join(sql_where_clauses), - params - ) + return (f" {where_op} ".join(sql_where_clauses), params) raise DBUnsupportedWHEREClauses(where_clauses) @classmethod @@ -255,29 +245,26 @@ class DB: :param sql: The SQL query to complete :param params: The dict of parameters of the SQL query to complete :param where_clauses: The WHERE clause (see _format_where_clauses()) - :param where_op: SQL operator used to combine WHERE clauses together (optional, default: see _format_where_clauses()) + :param where_op: SQL operator used to combine WHERE clauses together (optional, default: + see _format_where_clauses()) :return: :rtype: A tuple of two elements: raw SQL WHERE combined clauses and parameters """ if where_clauses: - sql_where, params = cls._format_where_clauses(where_clauses, params=params, where_op=where_op) + sql_where, params = cls._format_where_clauses( + where_clauses, params=params, where_op=where_op + ) sql += " WHERE " + sql_where return (sql, params) def insert(self, table, values, just_try=False): - """ Run INSERT SQL query """ + """Run INSERT SQL query""" # pylint: disable=consider-using-f-string - sql = 'INSERT INTO {0} ({1}) VALUES ({2})'.format( + sql = "INSERT INTO {} ({}) VALUES ({})".format( self._quote_table_name(table), - ', '.join([ - self._quote_field_name(field) - for field in values.keys() - ]), - ", ".join([ - self.format_param(key) - for key in values - ]) + ", ".join([self._quote_field_name(field) for field in values.keys()]), + ", ".join([self.format_param(key) for key in values]), ) if just_try: @@ -291,21 +278,20 @@ class DB: return True def update(self, table, values, where_clauses, where_op=None, just_try=False): - """ Run UPDATE SQL query """ + """Run UPDATE SQL query""" # pylint: disable=consider-using-f-string - sql = 'UPDATE {0} SET {1}'.format( + sql = "UPDATE {} SET {}".format( self._quote_table_name(table), - ", ".join([ - f'{self._quote_field_name(key)} = {self.format_param(key)}' - for key in values - ]) + ", ".join( + [f"{self._quote_field_name(key)} = {self.format_param(key)}" for key in values] + ), ) params = values try: sql, params = self._add_where_clauses(sql, params, where_clauses, where_op=where_op) except (DBDuplicatedSQLParameter, DBUnsupportedWHEREClauses): - log.error('Fail to add WHERE clauses', exc_info=True) + log.error("Fail to add WHERE clauses", exc_info=True) return False if just_try: @@ -318,15 +304,15 @@ class DB: return False return True - def delete(self, table, where_clauses, where_op='AND', just_try=False): - """ Run DELETE SQL query """ - sql = f'DELETE FROM {self._quote_table_name(table)}' + def delete(self, table, where_clauses, where_op="AND", just_try=False): + """Run DELETE SQL query""" + sql = f"DELETE FROM {self._quote_table_name(table)}" params = {} try: sql, params = self._add_where_clauses(sql, params, where_clauses, where_op=where_op) except (DBDuplicatedSQLParameter, DBUnsupportedWHEREClauses): - log.error('Fail to add WHERE clauses', exc_info=True) + log.error("Fail to add WHERE clauses", exc_info=True) return False if just_try: @@ -340,8 +326,8 @@ class DB: return True def truncate(self, table, just_try=False): - """ Run TRUNCATE SQL query """ - sql = f'TRUNCATE TABLE {self._quote_table_name(table)}' + """Run TRUNCATE SQL query""" + sql = f"TRUNCATE TABLE {self._quote_table_name(table)}" if just_try: log.debug("Just-try mode: execute TRUNCATE query: %s", sql) @@ -353,33 +339,36 @@ class DB: return False return True - def select(self, table, where_clauses=None, fields=None, where_op='AND', order_by=None, just_try=False): - """ Run SELECT SQL query """ + def select( + self, table, where_clauses=None, fields=None, where_op="AND", order_by=None, just_try=False + ): + """Run SELECT SQL query""" sql = "SELECT " if fields is None: sql += "*" elif isinstance(fields, str): - sql += f'{self._quote_field_name(fields)}' + sql += f"{self._quote_field_name(fields)}" else: - sql += ', '.join([self._quote_field_name(field) for field in fields]) + sql += ", ".join([self._quote_field_name(field) for field in fields]) - sql += f' FROM {self._quote_table_name(table)}' + sql += f" FROM {self._quote_table_name(table)}" params = {} try: sql, params = self._add_where_clauses(sql, params, where_clauses, where_op=where_op) except (DBDuplicatedSQLParameter, DBUnsupportedWHEREClauses): - log.error('Fail to add WHERE clauses', exc_info=True) + log.error("Fail to add WHERE clauses", exc_info=True) return False if order_by: if isinstance(order_by, str): - sql += f' ORDER BY {order_by}' + sql += f" ORDER BY {order_by}" elif ( - isinstance(order_by, (list, tuple)) and len(order_by) == 2 + isinstance(order_by, (list, tuple)) + and len(order_by) == 2 and isinstance(order_by[0], str) and isinstance(order_by[1], str) - and order_by[1].upper() in ('ASC', 'UPPER') + and order_by[1].upper() in ("ASC", "UPPER") ): sql += f' ORDER BY "{order_by[0]}" {order_by[1].upper()}' else: diff --git a/mylib/email.py b/mylib/email.py index ed26b0d..1cca112 100644 --- a/mylib/email.py +++ b/mylib/email.py @@ -1,49 +1,51 @@ -# -*- coding: utf-8 -*- - """ Email client to forge and send emails """ +import email.utils import logging import os import smtplib -import email.utils -from email.mime.text import MIMEText -from email.mime.multipart import MIMEMultipart -from email.mime.base import MIMEBase from email.encoders import encode_base64 +from email.mime.base import MIMEBase +from email.mime.multipart import MIMEMultipart +from email.mime.text import MIMEText from mako.template import Template as MakoTemplate -from mylib.config import ConfigurableObject -from mylib.config import BooleanOption -from mylib.config import IntegerOption -from mylib.config import PasswordOption -from mylib.config import StringOption +from mylib.config import ( + BooleanOption, + ConfigurableObject, + IntegerOption, + PasswordOption, + StringOption, +) log = logging.getLogger(__name__) -class EmailClient(ConfigurableObject): # pylint: disable=useless-object-inheritance,too-many-instance-attributes +class EmailClient( + ConfigurableObject +): # pylint: disable=useless-object-inheritance,too-many-instance-attributes """ Email client This class abstract all interactions with the SMTP server. """ - _config_name = 'email' - _config_comment = 'Email' + _config_name = "email" + _config_comment = "Email" _defaults = { - 'smtp_host': 'localhost', - 'smtp_port': 25, - 'smtp_ssl': False, - 'smtp_tls': False, - 'smtp_user': None, - 'smtp_password': None, - 'smtp_debug': False, - 'sender_name': 'No reply', - 'sender_email': 'noreply@localhost', - 'encoding': 'utf-8', - 'catch_all_addr': None, - 'just_try': False, + "smtp_host": "localhost", + "smtp_port": 25, + "smtp_ssl": False, + "smtp_tls": False, + "smtp_user": None, + "smtp_password": None, + "smtp_debug": False, + "sender_name": "No reply", + "sender_email": "noreply@localhost", + "encoding": "utf-8", + "catch_all_addr": None, + "just_try": False, } templates = {} @@ -55,61 +57,107 @@ class EmailClient(ConfigurableObject): # pylint: disable=useless-object-inherit self.templates = templates if templates else {} # pylint: disable=arguments-differ,arguments-renamed - def configure(self, use_smtp=True, just_try=True, ** kwargs): - """ Configure options on registered mylib.Config object """ + def configure(self, use_smtp=True, just_try=True, **kwargs): + """Configure options on registered mylib.Config object""" section = super().configure(**kwargs) if use_smtp: section.add_option( - StringOption, 'smtp_host', default=self._defaults['smtp_host'], - comment='SMTP server hostname/IP address') + StringOption, + "smtp_host", + default=self._defaults["smtp_host"], + comment="SMTP server hostname/IP address", + ) section.add_option( - IntegerOption, 'smtp_port', default=self._defaults['smtp_port'], - comment='SMTP server port') + IntegerOption, + "smtp_port", + default=self._defaults["smtp_port"], + comment="SMTP server port", + ) section.add_option( - BooleanOption, 'smtp_ssl', default=self._defaults['smtp_ssl'], - comment='Use SSL on SMTP server connection') + BooleanOption, + "smtp_ssl", + default=self._defaults["smtp_ssl"], + comment="Use SSL on SMTP server connection", + ) section.add_option( - BooleanOption, 'smtp_tls', default=self._defaults['smtp_tls'], - comment='Use TLS on SMTP server connection') + BooleanOption, + "smtp_tls", + default=self._defaults["smtp_tls"], + comment="Use TLS on SMTP server connection", + ) section.add_option( - StringOption, 'smtp_user', default=self._defaults['smtp_user'], - comment='SMTP authentication username') + StringOption, + "smtp_user", + default=self._defaults["smtp_user"], + comment="SMTP authentication username", + ) section.add_option( - PasswordOption, 'smtp_password', default=self._defaults['smtp_password'], + PasswordOption, + "smtp_password", + default=self._defaults["smtp_password"], comment='SMTP authentication password (set to "keyring" to use XDG keyring)', - username_option='smtp_user', keyring_value='keyring') + username_option="smtp_user", + keyring_value="keyring", + ) section.add_option( - BooleanOption, 'smtp_debug', default=self._defaults['smtp_debug'], - comment='Enable SMTP debugging') + BooleanOption, + "smtp_debug", + default=self._defaults["smtp_debug"], + comment="Enable SMTP debugging", + ) section.add_option( - StringOption, 'sender_name', default=self._defaults['sender_name'], - comment='Sender name') + StringOption, + "sender_name", + default=self._defaults["sender_name"], + comment="Sender name", + ) section.add_option( - StringOption, 'sender_email', default=self._defaults['sender_email'], - comment='Sender email address') + StringOption, + "sender_email", + default=self._defaults["sender_email"], + comment="Sender email address", + ) section.add_option( - StringOption, 'encoding', default=self._defaults['encoding'], - comment='Email encoding') + StringOption, "encoding", default=self._defaults["encoding"], comment="Email encoding" + ) section.add_option( - StringOption, 'catch_all_addr', default=self._defaults['catch_all_addr'], - comment='Catch all sent emails to this specified email address') + StringOption, + "catch_all_addr", + default=self._defaults["catch_all_addr"], + comment="Catch all sent emails to this specified email address", + ) if just_try: section.add_option( - BooleanOption, 'just_try', default=self._defaults['just_try'], - comment='Just-try mode: do not really send emails') + BooleanOption, + "just_try", + default=self._defaults["just_try"], + comment="Just-try mode: do not really send emails", + ) return section - def forge_message(self, rcpt_to, subject=None, html_body=None, text_body=None, # pylint: disable=too-many-arguments,too-many-locals - attachment_files=None, attachment_payloads=None, sender_name=None, - sender_email=None, encoding=None, template=None, **template_vars): + def forge_message( + self, + rcpt_to, + subject=None, + html_body=None, + text_body=None, # pylint: disable=too-many-arguments,too-many-locals + attachment_files=None, + attachment_payloads=None, + sender_name=None, + sender_email=None, + encoding=None, + template=None, + **template_vars, + ): """ Forge a message - :param rcpt_to: The recipient of the email. Could be a tuple(name, email) or just the email of the recipient. + :param rcpt_to: The recipient of the email. Could be a tuple(name, email) or + just the email of the recipient. :param subject: The subject of the email. :param html_body: The HTML body of the email :param text_body: The plain text body of the email @@ -122,64 +170,69 @@ class EmailClient(ConfigurableObject): # pylint: disable=useless-object-inherit All other parameters will be consider as template variables. """ - msg = MIMEMultipart('alternative') - msg['To'] = email.utils.formataddr(rcpt_to) if isinstance(rcpt_to, tuple) else rcpt_to - msg['From'] = email.utils.formataddr( + msg = MIMEMultipart("alternative") + msg["To"] = email.utils.formataddr(rcpt_to) if isinstance(rcpt_to, tuple) else rcpt_to + msg["From"] = email.utils.formataddr( ( - sender_name or self._get_option('sender_name'), - sender_email or self._get_option('sender_email') + sender_name or self._get_option("sender_name"), + sender_email or self._get_option("sender_email"), ) ) if subject: - msg['Subject'] = subject.format(**template_vars) - msg['Date'] = email.utils.formatdate(None, True) - encoding = encoding if encoding else self._get_option('encoding') + msg["Subject"] = subject.format(**template_vars) + msg["Date"] = email.utils.formatdate(None, True) + encoding = encoding if encoding else self._get_option("encoding") if template: - assert template in self.templates, f'Unknwon template {template}' + assert template in self.templates, f"Unknwon template {template}" # Handle subject from template if not subject: - assert self.templates[template].get('subject'), f'No subject defined in template {template}' - msg['Subject'] = self.templates[template]['subject'].format(**template_vars) + assert self.templates[template].get( + "subject" + ), f"No subject defined in template {template}" + msg["Subject"] = self.templates[template]["subject"].format(**template_vars) # Put HTML part in last one to prefered it parts = [] - if self.templates[template].get('text'): - if isinstance(self.templates[template]['text'], MakoTemplate): - parts.append((self.templates[template]['text'].render(**template_vars), 'plain')) + if self.templates[template].get("text"): + if isinstance(self.templates[template]["text"], MakoTemplate): + parts.append( + (self.templates[template]["text"].render(**template_vars), "plain") + ) else: - parts.append((self.templates[template]['text'].format(**template_vars), 'plain')) - if self.templates[template].get('html'): - if isinstance(self.templates[template]['html'], MakoTemplate): - parts.append((self.templates[template]['html'].render(**template_vars), 'html')) + parts.append( + (self.templates[template]["text"].format(**template_vars), "plain") + ) + if self.templates[template].get("html"): + if isinstance(self.templates[template]["html"], MakoTemplate): + parts.append((self.templates[template]["html"].render(**template_vars), "html")) else: - parts.append((self.templates[template]['html'].format(**template_vars), 'html')) + parts.append((self.templates[template]["html"].format(**template_vars), "html")) for body, mime_type in parts: msg.attach(MIMEText(body.encode(encoding), mime_type, _charset=encoding)) else: - assert subject, 'No subject provided' + assert subject, "No subject provided" if text_body: - msg.attach(MIMEText(text_body.encode(encoding), 'plain', _charset=encoding)) + msg.attach(MIMEText(text_body.encode(encoding), "plain", _charset=encoding)) if html_body: - msg.attach(MIMEText(html_body.encode(encoding), 'html', _charset=encoding)) + msg.attach(MIMEText(html_body.encode(encoding), "html", _charset=encoding)) if attachment_files: for filepath in attachment_files: - with open(filepath, 'rb') as fp: - part = MIMEBase('application', "octet-stream") + with open(filepath, "rb") as fp: + part = MIMEBase("application", "octet-stream") part.set_payload(fp.read()) encode_base64(part) part.add_header( - 'Content-Disposition', - f'attachment; filename="{os.path.basename(filepath)}"') + "Content-Disposition", + f'attachment; filename="{os.path.basename(filepath)}"', + ) msg.attach(part) if attachment_payloads: for filename, payload in attachment_payloads: - part = MIMEBase('application', "octet-stream") + part = MIMEBase("application", "octet-stream") part.set_payload(payload) encode_base64(part) - part.add_header( - 'Content-Disposition', - f'attachment; filename="{filename}"') + part.add_header("Content-Disposition", f'attachment; filename="{filename}"') msg.attach(part) return msg @@ -192,200 +245,184 @@ class EmailClient(ConfigurableObject): # pylint: disable=useless-object-inherit :param msg: The message of this email (as MIMEBase or derivated classes) :param subject: The subject of the email (only if the message is not provided using msg parameter) - :param just_try: Enable just try mode (do not really send email, default: as defined on initialization) + :param just_try: Enable just try mode (do not really send email, default: as defined on + initialization) All other parameters will be consider as parameters to forge the message (only if the message is not provided using msg parameter). """ msg = msg if msg else self.forge_message(rcpt_to, subject, **forge_args) - if just_try or self._get_option('just_try'): - log.debug('Just-try mode: do not really send this email to %s (subject="%s")', rcpt_to, subject or msg.get('subject', 'No subject')) + if just_try or self._get_option("just_try"): + log.debug( + 'Just-try mode: do not really send this email to %s (subject="%s")', + rcpt_to, + subject or msg.get("subject", "No subject"), + ) return True - catch_addr = self._get_option('catch_all_addr') + catch_addr = self._get_option("catch_all_addr") if catch_addr: - log.debug('Catch email originaly send to %s to %s', rcpt_to, catch_addr) + log.debug("Catch email originaly send to %s to %s", rcpt_to, catch_addr) rcpt_to = catch_addr - smtp_host = self._get_option('smtp_host') - smtp_port = self._get_option('smtp_port') + smtp_host = self._get_option("smtp_host") + smtp_port = self._get_option("smtp_port") try: - if self._get_option('smtp_ssl'): + if self._get_option("smtp_ssl"): logging.info("Establish SSL connection to server %s:%s", smtp_host, smtp_port) server = smtplib.SMTP_SSL(smtp_host, smtp_port) else: logging.info("Establish connection to server %s:%s", smtp_host, smtp_port) server = smtplib.SMTP(smtp_host, smtp_port) - if self._get_option('smtp_tls'): - logging.info('Start TLS on SMTP connection') + if self._get_option("smtp_tls"): + logging.info("Start TLS on SMTP connection") server.starttls() except smtplib.SMTPException: - log.error('Error connecting to SMTP server %s:%s', smtp_host, smtp_port, exc_info=True) + log.error("Error connecting to SMTP server %s:%s", smtp_host, smtp_port, exc_info=True) return False - if self._get_option('smtp_debug'): + if self._get_option("smtp_debug"): server.set_debuglevel(True) - smtp_user = self._get_option('smtp_user') - smtp_password = self._get_option('smtp_password') + smtp_user = self._get_option("smtp_user") + smtp_password = self._get_option("smtp_password") if smtp_user and smtp_password: try: - log.info('Try to authenticate on SMTP connection as %s', smtp_user) + log.info("Try to authenticate on SMTP connection as %s", smtp_user) server.login(smtp_user, smtp_password) except smtplib.SMTPException: log.error( - 'Error authenticating on SMTP server %s:%s with user %s', - smtp_host, smtp_port, smtp_user, exc_info=True) + "Error authenticating on SMTP server %s:%s with user %s", + smtp_host, + smtp_port, + smtp_user, + exc_info=True, + ) return False error = False try: - log.info('Sending email to %s', rcpt_to) + log.info("Sending email to %s", rcpt_to) server.sendmail( - self._get_option('sender_email'), + self._get_option("sender_email"), [rcpt_to[1] if isinstance(rcpt_to, tuple) else rcpt_to], - msg.as_string() + msg.as_string(), ) except smtplib.SMTPException: error = True - log.error('Error sending email to %s', rcpt_to, exc_info=True) + log.error("Error sending email to %s", rcpt_to, exc_info=True) finally: server.quit() return not error -if __name__ == '__main__': +if __name__ == "__main__": # Run tests + import argparse import datetime import sys - import argparse - # Options parser parser = argparse.ArgumentParser() parser.add_argument( - '-v', '--verbose', - action="store_true", - dest="verbose", - help="Enable verbose mode" + "-v", "--verbose", action="store_true", dest="verbose", help="Enable verbose mode" ) parser.add_argument( - '-d', '--debug', - action="store_true", - dest="debug", - help="Enable debug mode" + "-d", "--debug", action="store_true", dest="debug", help="Enable debug mode" ) parser.add_argument( - '-l', '--log-file', - action="store", - type=str, - dest="logfile", - help="Log file path" + "-l", "--log-file", action="store", type=str, dest="logfile", help="Log file path" ) parser.add_argument( - '-j', '--just-try', - action="store_true", - dest="just_try", - help="Enable just-try mode" + "-j", "--just-try", action="store_true", dest="just_try", help="Enable just-try mode" ) - email_opts = parser.add_argument_group('Email options') + email_opts = parser.add_argument_group("Email options") email_opts.add_argument( - '-H', '--smtp-host', - action="store", - type=str, - dest="email_smtp_host", - help="SMTP host" + "-H", "--smtp-host", action="store", type=str, dest="email_smtp_host", help="SMTP host" ) email_opts.add_argument( - '-P', '--smtp-port', - action="store", - type=int, - dest="email_smtp_port", - help="SMTP port" + "-P", "--smtp-port", action="store", type=int, dest="email_smtp_port", help="SMTP port" ) email_opts.add_argument( - '-S', '--smtp-ssl', - action="store_true", - dest="email_smtp_ssl", - help="Use SSL" + "-S", "--smtp-ssl", action="store_true", dest="email_smtp_ssl", help="Use SSL" ) email_opts.add_argument( - '-T', '--smtp-tls', - action="store_true", - dest="email_smtp_tls", - help="Use TLS" + "-T", "--smtp-tls", action="store_true", dest="email_smtp_tls", help="Use TLS" ) email_opts.add_argument( - '-u', '--smtp-user', - action="store", - type=str, - dest="email_smtp_user", - help="SMTP username" + "-u", "--smtp-user", action="store", type=str, dest="email_smtp_user", help="SMTP username" ) email_opts.add_argument( - '-p', '--smtp-password', + "-p", + "--smtp-password", action="store", type=str, dest="email_smtp_password", - help="SMTP password" + help="SMTP password", ) email_opts.add_argument( - '-D', '--smtp-debug', + "-D", + "--smtp-debug", action="store_true", dest="email_smtp_debug", - help="Debug SMTP connection" + help="Debug SMTP connection", ) email_opts.add_argument( - '-e', '--email-encoding', + "-e", + "--email-encoding", action="store", type=str, dest="email_encoding", - help="SMTP encoding" + help="SMTP encoding", ) email_opts.add_argument( - '-f', '--sender-name', + "-f", + "--sender-name", action="store", type=str, dest="email_sender_name", - help="Sender name" + help="Sender name", ) email_opts.add_argument( - '-F', '--sender-email', + "-F", + "--sender-email", action="store", type=str, dest="email_sender_email", - help="Sender email" + help="Sender email", ) email_opts.add_argument( - '-C', '--catch-all', + "-C", + "--catch-all", action="store", type=str, dest="email_catch_all", - help="Catch all sent email: specify catch recipient email address" + help="Catch all sent email: specify catch recipient email address", ) - test_opts = parser.add_argument_group('Test email options') + test_opts = parser.add_argument_group("Test email options") test_opts.add_argument( - '-t', '--to', + "-t", + "--to", action="store", type=str, dest="test_to", @@ -393,7 +430,8 @@ if __name__ == '__main__': ) test_opts.add_argument( - '-m', '--mako', + "-m", + "--mako", action="store_true", dest="test_mako", help="Test mako templating", @@ -402,11 +440,11 @@ if __name__ == '__main__': options = parser.parse_args() if not options.test_to: - parser.error('You must specify test email recipient using -t/--to parameter') + parser.error("You must specify test email recipient using -t/--to parameter") sys.exit(1) # Initialize logs - logformat = '%(asctime)s - Test EmailClient - %(levelname)s - %(message)s' + logformat = "%(asctime)s - Test EmailClient - %(levelname)s - %(message)s" if options.debug: loglevel = logging.DEBUG elif options.verbose: @@ -421,9 +459,10 @@ if __name__ == '__main__': if options.email_smtp_user and not options.email_smtp_password: import getpass - options.email_smtp_password = getpass.getpass('Please enter SMTP password: ') - logging.info('Initialize Email client') + options.email_smtp_password = getpass.getpass("Please enter SMTP password: ") + + logging.info("Initialize Email client") email_client = EmailClient( smtp_host=options.email_smtp_host, smtp_port=options.email_smtp_port, @@ -441,20 +480,24 @@ if __name__ == '__main__': test=dict( subject="Test email", text=( - "Just a test email sent at {sent_date}." if not options.test_mako else - MakoTemplate("Just a test email sent at ${sent_date}.") + "Just a test email sent at {sent_date}." + if not options.test_mako + else MakoTemplate("Just a test email sent at ${sent_date}.") ), html=( - "Just a test email. (sent at {sent_date})" if not options.test_mako else - MakoTemplate("Just a test email. (sent at ${sent_date})") - ) + "Just a test email. (sent at {sent_date})" + if not options.test_mako + else MakoTemplate( + "Just a test email. (sent at ${sent_date})" + ) + ), ) - ) + ), ) - logging.info('Send a test email to %s', options.test_to) - if email_client.send(options.test_to, template='test', sent_date=datetime.datetime.now()): - logging.info('Test email sent') + logging.info("Send a test email to %s", options.test_to) + if email_client.send(options.test_to, template="test", sent_date=datetime.datetime.now()): + logging.info("Test email sent") sys.exit(0) - logging.error('Fail to send test email') + logging.error("Fail to send test email") sys.exit(1) diff --git a/mylib/ldap.py b/mylib/ldap.py index 1ee1eca..924252f 100644 --- a/mylib/ldap.py +++ b/mylib/ldap.py @@ -1,15 +1,13 @@ -# -*- coding: utf-8 -*- - """ LDAP server connection helper """ import copy import datetime import logging -import pytz import dateutil.parser import dateutil.tz import ldap +import pytz from ldap import modlist from ldap.controls import SimplePagedResultsControl from ldap.controls.simple import RelaxRulesControl @@ -18,39 +16,33 @@ from ldap.dn import escape_dn_chars, explode_dn from mylib import pretty_format_dict log = logging.getLogger(__name__) -DEFAULT_ENCODING = 'utf-8' +DEFAULT_ENCODING = "utf-8" -def decode_ldap_value(value, encoding='utf-8'): - """ Decoding LDAP attribute values helper """ +def decode_ldap_value(value, encoding="utf-8"): + """Decoding LDAP attribute values helper""" if isinstance(value, bytes): return value.decode(encoding) if isinstance(value, list): return [decode_ldap_value(v) for v in value] if isinstance(value, dict): - return dict( - (key, decode_ldap_value(values)) - for key, values in value.items() - ) + return {key: decode_ldap_value(values) for key, values in value.items()} return value -def encode_ldap_value(value, encoding='utf-8'): - """ Encoding LDAP attribute values helper """ +def encode_ldap_value(value, encoding="utf-8"): + """Encoding LDAP attribute values helper""" if isinstance(value, str): return value.encode(encoding) if isinstance(value, list): return [encode_ldap_value(v) for v in value] if isinstance(value, dict): - return dict( - (key, encode_ldap_value(values)) - for key, values in value.items() - ) + return {key: encode_ldap_value(values) for key, values in value.items()} return value class LdapServer: - """ LDAP server connection helper """ # pylint: disable=useless-object-inheritance + """LDAP server connection helper""" # pylint: disable=useless-object-inheritance uri = None dn = None @@ -59,9 +51,17 @@ class LdapServer: con = 0 - def __init__(self, uri, dn=None, pwd=None, v2=None, - raiseOnError=False, logger=False, checkCert=True, - disableReferral=False): + def __init__( + self, + uri, + dn=None, + pwd=None, + v2=None, + raiseOnError=False, + logger=False, + checkCert=True, + disableReferral=False, + ): self.uri = uri self.dn = dn self.pwd = pwd @@ -81,7 +81,7 @@ class LdapServer: self.logger.log(level, error) def connect(self): - """ Start connection to LDAP server """ + """Start connection to LDAP server""" if self.con == 0: try: if not self.checkCert: @@ -98,37 +98,38 @@ class LdapServer: if self.dn: con.simple_bind_s(self.dn, self.pwd) - elif self.uri.startswith('ldapi://'): + elif self.uri.startswith("ldapi://"): con.sasl_interactive_bind_s("", ldap.sasl.external()) self.con = con return True except ldap.LDAPError as e: # pylint: disable=no-member self._error( - f'LdapServer - Error connecting and binding to LDAP server: {e}', - logging.CRITICAL) + f"LdapServer - Error connecting and binding to LDAP server: {e}", + logging.CRITICAL, + ) return False return True @staticmethod def get_scope(scope): - """ Map scope parameter to python-ldap value """ - if scope == 'base': + """Map scope parameter to python-ldap value""" + if scope == "base": return ldap.SCOPE_BASE # pylint: disable=no-member - if scope == 'one': + if scope == "one": return ldap.SCOPE_ONELEVEL # pylint: disable=no-member - if scope == 'sub': + if scope == "sub": return ldap.SCOPE_SUBTREE # pylint: disable=no-member raise Exception(f'Unknown LDAP scope "{scope}"') def search(self, basedn, filterstr=None, attrs=None, sizelimit=None, scope=None): - """ Run a search on LDAP server """ + """Run a search on LDAP server""" assert self.con or self.connect() res_id = self.con.search( basedn, - self.get_scope(scope if scope else 'sub'), - filterstr if filterstr else '(objectClass=*)', - attrs if attrs else [] + self.get_scope(scope if scope else "sub"), + filterstr if filterstr else "(objectClass=*)", + attrs if attrs else [], ) ret = {} c = 0 @@ -142,65 +143,64 @@ class LdapServer: return ret def get_object(self, dn, filterstr=None, attrs=None): - """ Retrieve a LDAP object specified by its DN """ - result = self.search(dn, filterstr=filterstr, scope='base', attrs=attrs) + """Retrieve a LDAP object specified by its DN""" + result = self.search(dn, filterstr=filterstr, scope="base", attrs=attrs) return result[dn] if dn in result else None - def paged_search(self, basedn, filterstr=None, attrs=None, scope=None, pagesize=None, - sizelimit=None): - """ Run a paged search on LDAP server """ + def paged_search( + self, basedn, filterstr=None, attrs=None, scope=None, pagesize=None, sizelimit=None + ): + """Run a paged search on LDAP server""" assert not self.v2, "Paged search is not available on LDAP version 2" assert self.con or self.connect() # Set parameters default values (if not defined) - filterstr = filterstr if filterstr else '(objectClass=*)' + filterstr = filterstr if filterstr else "(objectClass=*)" attrs = attrs if attrs else [] - scope = scope if scope else 'sub' + scope = scope if scope else "sub" pagesize = pagesize if pagesize else 500 # Initialize SimplePagedResultsControl object page_control = SimplePagedResultsControl( - True, - size=pagesize, - cookie='' # Start without cookie + True, size=pagesize, cookie="" # Start without cookie ) ret = {} pages_count = 0 self.logger.debug( - "LdapServer - Paged search with base DN '%s', filter '%s', scope '%s', pagesize=%d and attrs=%s", + "LdapServer - Paged search with base DN '%s', filter '%s', scope '%s', pagesize=%d" + " and attrs=%s", basedn, filterstr, scope, pagesize, - attrs + attrs, ) while True: pages_count += 1 self.logger.debug( - "LdapServer - Paged search: request page %d with a maximum of %d objects (current total count: %d)", + "LdapServer - Paged search: request page %d with a maximum of %d objects" + " (current total count: %d)", pages_count, pagesize, - len(ret) + len(ret), ) try: res_id = self.con.search_ext( - basedn, - self.get_scope(scope), - filterstr, - attrs, - serverctrls=[page_control] + basedn, self.get_scope(scope), filterstr, attrs, serverctrls=[page_control] ) except ldap.LDAPError as e: # pylint: disable=no-member self._error( - f'LdapServer - Error running paged search on LDAP server: {e}', - logging.CRITICAL) + f"LdapServer - Error running paged search on LDAP server: {e}", logging.CRITICAL + ) return False try: - rtype, rdata, rmsgid, rctrls = self.con.result3(res_id) # pylint: disable=unused-variable + # pylint: disable=unused-variable + rtype, rdata, rmsgid, rctrls = self.con.result3(res_id) except ldap.LDAPError as e: # pylint: disable=no-member self._error( - f'LdapServer - Error pulling paged search result from LDAP server: {e}', - logging.CRITICAL) + f"LdapServer - Error pulling paged search result from LDAP server: {e}", + logging.CRITICAL, + ) return False # Detect and catch PagedResultsControl answer from rctrls @@ -214,8 +214,9 @@ class LdapServer: # If PagedResultsControl answer not detected, paged serach if not result_page_control: self._error( - 'LdapServer - Server ignores RFC2696 control, paged search can not works', - logging.CRITICAL) + "LdapServer - Server ignores RFC2696 control, paged search can not works", + logging.CRITICAL, + ) return False # Store results of this page @@ -236,11 +237,16 @@ class LdapServer: # Otherwise, set cookie for the next search page_control.cookie = result_page_control.cookie - self.logger.debug("LdapServer - Paged search end: %d object(s) retreived in %d page(s) of %d object(s)", len(ret), pages_count, pagesize) + self.logger.debug( + "LdapServer - Paged search end: %d object(s) retreived in %d page(s) of %d object(s)", + len(ret), + pages_count, + pagesize, + ) return ret def add_object(self, dn, attrs, encode=False): - """ Add an object in LDAP directory """ + """Add an object in LDAP directory""" ldif = modlist.addModlist(encode_ldap_value(attrs) if encode else attrs) assert self.con or self.connect() try: @@ -248,17 +254,17 @@ class LdapServer: self.con.add_s(dn, ldif) return True except ldap.LDAPError as e: # pylint: disable=no-member - self._error(f'LdapServer - Error adding {dn}: {e}', logging.ERROR) + self._error(f"LdapServer - Error adding {dn}: {e}", logging.ERROR) return False def update_object(self, dn, old, new, ignore_attrs=None, relax=False, encode=False): - """ Update an object in LDAP directory """ + """Update an object in LDAP directory""" assert not relax or not self.v2, "Relax modification is not available on LDAP version 2" ldif = modlist.modifyModlist( encode_ldap_value(old) if encode else old, encode_ldap_value(new) if encode else new, - ignore_attr_types=ignore_attrs if ignore_attrs else [] + ignore_attr_types=ignore_attrs if ignore_attrs else [], ) if not ldif: return True @@ -271,17 +277,17 @@ class LdapServer: return True except ldap.LDAPError as e: # pylint: disable=no-member self._error( - f'LdapServer - Error updating {dn} : {e}\nOld: {old}\nNew: {new}', - logging.ERROR) + f"LdapServer - Error updating {dn} : {e}\nOld: {old}\nNew: {new}", logging.ERROR + ) return False @staticmethod def update_need(old, new, ignore_attrs=None, encode=False): - """ Check if an update is need on a LDAP object based on its old and new attributes values """ + """Check if an update is need on a LDAP object based on its old and new attributes values""" ldif = modlist.modifyModlist( encode_ldap_value(old) if encode else old, encode_ldap_value(new) if encode else new, - ignore_attr_types=ignore_attrs if ignore_attrs else [] + ignore_attr_types=ignore_attrs if ignore_attrs else [], ) if not ldif: return False @@ -289,48 +295,54 @@ class LdapServer: @staticmethod def get_changes(old, new, ignore_attrs=None, encode=False): - """ Retrieve changes (as modlist) on an object based on its old and new attributes values """ + """Retrieve changes (as modlist) on an object based on its old and new attributes values""" return modlist.modifyModlist( encode_ldap_value(old) if encode else old, encode_ldap_value(new) if encode else new, - ignore_attr_types=ignore_attrs if ignore_attrs else [] + ignore_attr_types=ignore_attrs if ignore_attrs else [], ) @staticmethod def format_changes(old, new, ignore_attrs=None, prefix=None, encode=False): - """ Format changes (modlist) on an object based on its old and new attributes values to display/log it """ + """ + Format changes (modlist) on an object based on its old and new attributes values to + display/log it + """ msg = [] - prefix = prefix if prefix else '' - for (op, attr, val) in modlist.modifyModlist( + prefix = prefix if prefix else "" + for op, attr, val in modlist.modifyModlist( encode_ldap_value(old) if encode else old, encode_ldap_value(new) if encode else new, - ignore_attr_types=ignore_attrs if ignore_attrs else [] + ignore_attr_types=ignore_attrs if ignore_attrs else [], ): if op == ldap.MOD_ADD: # pylint: disable=no-member - op = 'ADD' + op = "ADD" elif op == ldap.MOD_DELETE: # pylint: disable=no-member - op = 'DELETE' + op = "DELETE" elif op == ldap.MOD_REPLACE: # pylint: disable=no-member - op = 'REPLACE' + op = "REPLACE" else: - op = f'UNKNOWN (={op})' - if val is None and op == 'DELETE': - msg.append(f'{prefix} - {op} {attr}') + op = f"UNKNOWN (={op})" + if val is None and op == "DELETE": + msg.append(f"{prefix} - {op} {attr}") else: - msg.append(f'{prefix} - {op} {attr}: {val}') - return '\n'.join(msg) + msg.append(f"{prefix} - {op} {attr}: {val}") + return "\n".join(msg) def rename_object(self, dn, new_rdn, new_sup=None, delete_old=True): - """ Rename an object in LDAP directory """ + """Rename an object in LDAP directory""" # If new_rdn is a complete DN, split new RDN and new superior DN if len(explode_dn(new_rdn)) > 1: self.logger.debug( - "LdapServer - Rename with a full new DN detected (%s): split new RDN and new superior DN", - new_rdn + "LdapServer - Rename with a full new DN detected (%s): split new RDN and new" + " superior DN", + new_rdn, ) - assert new_sup is None, "You can't provide a complete DN as new_rdn and also provide new_sup parameter" + assert ( + new_sup is None + ), "You can't provide a complete DN as new_rdn and also provide new_sup parameter" new_dn_parts = explode_dn(new_rdn) - new_sup = ','.join(new_dn_parts[1:]) + new_sup = ",".join(new_dn_parts[1:]) new_rdn = new_dn_parts[0] assert self.con or self.connect() try: @@ -339,41 +351,40 @@ class LdapServer: dn, new_rdn, "same" if new_sup is None else new_sup, - delete_old + delete_old, ) self.con.rename_s(dn, new_rdn, newsuperior=new_sup, delold=delete_old) return True except ldap.LDAPError as e: # pylint: disable=no-member self._error( - f'LdapServer - Error renaming {dn} in {new_rdn} ' + f"LdapServer - Error renaming {dn} in {new_rdn} " f'(new superior: {"same" if new_sup is None else new_sup}, ' - f'delete old: {delete_old}): {e}', - logging.ERROR + f"delete old: {delete_old}): {e}", + logging.ERROR, ) return False def drop_object(self, dn): - """ Drop an object in LDAP directory """ + """Drop an object in LDAP directory""" assert self.con or self.connect() try: self.logger.debug("LdapServer - Delete %s", dn) self.con.delete_s(dn) return True except ldap.LDAPError as e: # pylint: disable=no-member - self._error( - f'LdapServer - Error deleting {dn}: {e}', logging.ERROR) + self._error(f"LdapServer - Error deleting {dn}: {e}", logging.ERROR) return False @staticmethod def get_dn(obj): - """ Retreive an on object DN from its entry in LDAP search result """ + """Retreive an on object DN from its entry in LDAP search result""" return obj[0][0] @staticmethod def get_attr(obj, attr, all_values=None, default=None, decode=False): - """ Retreive an on object attribute value(s) from the object entry in LDAP search result """ + """Retreive an on object attribute value(s) from the object entry in LDAP search result""" if attr not in obj: for k in obj: if k.lower() == attr.lower(): @@ -389,14 +400,14 @@ class LdapServer: class LdapServerException(BaseException): - """ Generic exception raised by LdapServer """ + """Generic exception raised by LdapServer""" def __init__(self, msg): BaseException.__init__(self, msg) class LdapClientException(LdapServerException): - """ Generic exception raised by LdapServer """ + """Generic exception raised by LdapServer""" def __init__(self, msg): LdapServerException.__init__(self, msg) @@ -404,7 +415,7 @@ class LdapClientException(LdapServerException): class LdapClient: - """ LDAP Client (based on python-mylib.LdapServer) """ + """LDAP Client (based on python-mylib.LdapServer)""" _options = {} _config = None @@ -416,104 +427,112 @@ class LdapClient: # Cache objects _cached_objects = None - def __init__(self, options=None, options_prefix=None, config=None, config_section=None, initialize=False): + def __init__( + self, options=None, options_prefix=None, config=None, config_section=None, initialize=False + ): self._options = options if options else {} - self._options_prefix = options_prefix if options_prefix else 'ldap_' + self._options_prefix = options_prefix if options_prefix else "ldap_" self._config = config if config else None - self._config_section = config_section if config_section else 'ldap' + self._config_section = config_section if config_section else "ldap" self._cached_objects = {} if initialize: self.initialize() def _get_option(self, option, default=None, required=False): - """ Retreive option value """ + """Retreive option value""" if self._options and hasattr(self._options, self._options_prefix + option): return getattr(self._options, self._options_prefix + option) if self._config and self._config.defined(self._config_section, option): return self._config.get(self._config_section, option) - assert not required, f'Options {option} not defined' + assert not required, f"Options {option} not defined" return default @property def _just_try(self): - """ Check if just-try mode is enabled """ + """Check if just-try mode is enabled""" return self._get_option( - 'just_try', default=( - self._config.get_option('just_try') if self._config - else False - ) + "just_try", default=(self._config.get_option("just_try") if self._config else False) ) def configure(self, comment=None, **kwargs): - """ Configure options on registered mylib.Config object """ - assert self._config, "mylib.Config object not registered. Must be passed to __init__ as config keyword argument." + """Configure options on registered mylib.Config object""" + assert self._config, ( + "mylib.Config object not registered. Must be passed to __init__ as config keyword" + " argument." + ) # Load configuration option types only here to avoid global # dependency of ldap module with config one. # pylint: disable=import-outside-toplevel - from mylib.config import BooleanOption, StringOption, PasswordOption + from mylib.config import BooleanOption, PasswordOption, StringOption section = self._config.add_section( self._config_section, - comment=comment if comment else 'LDAP connection', - loaded_callback=self.initialize, **kwargs) + comment=comment if comment else "LDAP connection", + loaded_callback=self.initialize, + **kwargs, + ) section.add_option( - StringOption, 'uri', default='ldap://localhost', - comment='LDAP server URI') + StringOption, "uri", default="ldap://localhost", comment="LDAP server URI" + ) + section.add_option(StringOption, "binddn", comment="LDAP Bind DN") section.add_option( - StringOption, 'binddn', comment='LDAP Bind DN') - section.add_option( - PasswordOption, 'bindpwd', + PasswordOption, + "bindpwd", comment='LDAP Bind password (set to "keyring" to use XDG keyring)', - username_option='binddn', keyring_value='keyring') + username_option="binddn", + keyring_value="keyring", + ) section.add_option( - BooleanOption, 'checkcert', default=True, - comment='Check LDAP certificate') + BooleanOption, "checkcert", default=True, comment="Check LDAP certificate" + ) section.add_option( - BooleanOption, 'disablereferral', default=False, - comment='Disable referral following') + BooleanOption, "disablereferral", default=False, comment="Disable referral following" + ) return section def initialize(self, loaded_config=None): - """ Initialize LDAP connection """ + """Initialize LDAP connection""" if loaded_config: self.config = loaded_config - uri = self._get_option('uri', required=True) - binddn = self._get_option('binddn') - log.info("Connect to LDAP server %s as %s", uri, binddn if binddn else 'annonymous') + uri = self._get_option("uri", required=True) + binddn = self._get_option("binddn") + log.info("Connect to LDAP server %s as %s", uri, binddn if binddn else "annonymous") self._conn = LdapServer( - uri, dn=binddn, pwd=self._get_option('bindpwd'), - checkCert=self._get_option('checkcert'), - disableReferral=self._get_option('disablereferral'), - raiseOnError=True + uri, + dn=binddn, + pwd=self._get_option("bindpwd"), + checkCert=self._get_option("checkcert"), + disableReferral=self._get_option("disablereferral"), + raiseOnError=True, ) # Reset cache self._cached_objects = {} return self._conn.connect() def decode(self, value): - """ Decode LDAP attribute value """ + """Decode LDAP attribute value""" if isinstance(value, list): return [self.decode(v) for v in value] if isinstance(value, str): return value return value.decode( - self._get_option('encoding', default=DEFAULT_ENCODING), - self._get_option('encoding_error_policy', default='ignore') + self._get_option("encoding", default=DEFAULT_ENCODING), + self._get_option("encoding_error_policy", default="ignore"), ) def encode(self, value): - """ Encode LDAP attribute value """ + """Encode LDAP attribute value""" if isinstance(value, list): return [self.encode(v) for v in value] if isinstance(value, bytes): return value - return value.encode(self._get_option('encoding', default=DEFAULT_ENCODING)) + return value.encode(self._get_option("encoding", default=DEFAULT_ENCODING)) def _get_obj(self, dn, attrs): """ @@ -548,8 +567,17 @@ class LdapClient: return vals if all_values else vals[0] return default if default or not all_values else [] - def get_objects(self, name, filterstr, basedn, attrs, key_attr=None, warn=True, - paged_search=False, pagesize=None): + def get_objects( + self, + name, + filterstr, + basedn, + attrs, + key_attr=None, + warn=True, + paged_search=False, + pagesize=None, + ): """ Retrieve objects from LDAP @@ -568,25 +596,28 @@ class LdapClient: (optional, default: see LdapServer.paged_search) """ if name in self._cached_objects: - log.debug('Retreived %s objects from cache', name) + log.debug("Retreived %s objects from cache", name) else: assert self._conn or self.initialize() - log.debug('Looking for LDAP %s with (filter="%s" / basedn="%s")', name, filterstr, basedn) + log.debug( + 'Looking for LDAP %s with (filter="%s" / basedn="%s")', name, filterstr, basedn + ) if paged_search: ldap_data = self._conn.paged_search( - basedn=basedn, filterstr=filterstr, attrs=attrs, - pagesize=pagesize + basedn=basedn, filterstr=filterstr, attrs=attrs, pagesize=pagesize ) else: ldap_data = self._conn.search( - basedn=basedn, filterstr=filterstr, attrs=attrs, + basedn=basedn, + filterstr=filterstr, + attrs=attrs, ) if not ldap_data: if warn: - log.warning('No %s found in LDAP', name) + log.warning("No %s found in LDAP", name) else: - log.debug('No %s found in LDAP', name) + log.debug("No %s found in LDAP", name) return {} objects = {} @@ -596,12 +627,12 @@ class LdapClient: continue objects[obj_dn] = self._get_obj(obj_dn, obj_attrs) self._cached_objects[name] = objects - if not key_attr or key_attr == 'dn': + if not key_attr or key_attr == "dn": return self._cached_objects[name] - return dict( - (self.get_attr(self._cached_objects[name][dn], key_attr), self._cached_objects[name][dn]) + return { + self.get_attr(self._cached_objects[name][dn], key_attr): self._cached_objects[name][dn] for dn in self._cached_objects[name] - ) + } def get_object(self, type_name, object_name, filterstr, basedn, attrs, warn=True): """ @@ -620,11 +651,14 @@ class LdapClient: (optional, default: True) """ assert self._conn or self.initialize() - log.debug('Looking for LDAP %s "%s" with (filter="%s" / basedn="%s")', type_name, object_name, filterstr, basedn) - ldap_data = self._conn.search( - basedn=basedn, filterstr=filterstr, - attrs=attrs + log.debug( + 'Looking for LDAP %s "%s" with (filter="%s" / basedn="%s")', + type_name, + object_name, + filterstr, + basedn, ) + ldap_data = self._conn.search(basedn=basedn, filterstr=filterstr, attrs=attrs) if not ldap_data: if warn: @@ -635,7 +669,8 @@ class LdapClient: if len(ldap_data) > 1: raise LdapClientException( - f'More than one {type_name} "{object_name}": {" / ".join(ldap_data.keys())}') + f'More than one {type_name} "{object_name}": {" / ".join(ldap_data.keys())}' + ) dn = next(iter(ldap_data)) return self._get_obj(dn, ldap_data[dn]) @@ -659,9 +694,9 @@ class LdapClient: populate_cache_method() if type_name not in self._cached_objects: if warn: - log.warning('No %s found in LDAP', type_name) + log.warning("No %s found in LDAP", type_name) else: - log.debug('No %s found in LDAP', type_name) + log.debug("No %s found in LDAP", type_name) return None if dn not in self._cached_objects[type_name]: if warn: @@ -686,7 +721,9 @@ class LdapClient: return value in cls.get_attr(obj, attr, all_values=True) return value.lower() in [v.lower() for v in cls.get_attr(obj, attr, all_values=True)] - def get_object_by_attr(self, type_name, attr, value, populate_cache_method=None, case_sensitive=False, warn=True): + def get_object_by_attr( + self, type_name, attr, value, populate_cache_method=None, case_sensitive=False, warn=True + ): """ Retrieve an LDAP object specified by one of its attribute @@ -708,15 +745,15 @@ class LdapClient: populate_cache_method() if type_name not in self._cached_objects: if warn: - log.warning('No %s found in LDAP', type_name) + log.warning("No %s found in LDAP", type_name) else: - log.debug('No %s found in LDAP', type_name) + log.debug("No %s found in LDAP", type_name) return None - matched = dict( - (dn, obj) + matched = { + dn: obj for dn, obj in self._cached_objects[type_name].items() if self.object_attr_mached(obj, attr, value, case_sensitive=case_sensitive) - ) + } if not matched: if warn: log.warning('No %s found with %s="%s"', type_name, attr, value) @@ -726,7 +763,8 @@ class LdapClient: if len(matched) > 1: raise LdapClientException( f'More than one {type_name} with {attr}="{value}" found: ' - f'{" / ".join(matched.keys())}') + f'{" / ".join(matched.keys())}' + ) dn = next(iter(matched)) return matched[dn] @@ -742,7 +780,7 @@ class LdapClient: old = {} new = {} protected_attrs = [a.lower() for a in protected_attrs or []] - protected_attrs.append('dn') + protected_attrs.append("dn") # New/updated attributes for attr in attrs: if protected_attrs and attr.lower() in protected_attrs: @@ -755,7 +793,11 @@ class LdapClient: # Deleted attributes for attr in ldap_obj: - if (not protected_attrs or attr.lower() not in protected_attrs) and ldap_obj[attr] and attr not in attrs: + if ( + (not protected_attrs or attr.lower() not in protected_attrs) + and ldap_obj[attr] + and attr not in attrs + ): old[attr] = self.encode(ldap_obj[attr]) if old == new: return None @@ -771,8 +813,7 @@ class LdapClient: """ assert self._conn or self.initialize() return self._conn.format_changes( - changes[0], changes[1], - ignore_attrs=protected_attrs, prefix=prefix + changes[0], changes[1], ignore_attrs=protected_attrs, prefix=prefix ) def update_need(self, changes, protected_attrs=None): @@ -784,10 +825,7 @@ class LdapClient: if changes is None: return False assert self._conn or self.initialize() - return self._conn.update_need( - changes[0], changes[1], - ignore_attrs=protected_attrs - ) + return self._conn.update_need(changes[0], changes[1], ignore_attrs=protected_attrs) def add_object(self, dn, attrs): """ @@ -796,21 +834,19 @@ class LdapClient: :param dn: The LDAP object DN :param attrs: The LDAP object attributes (as dict) """ - attrs = dict( - (attr, self.encode(values)) - for attr, values in attrs.items() - if attr != 'dn' - ) + attrs = {attr: self.encode(values) for attr, values in attrs.items() if attr != "dn"} try: if self._just_try: - log.debug('Just-try mode : do not really add object in LDAP') + log.debug("Just-try mode : do not really add object in LDAP") return True assert self._conn or self.initialize() return self._conn.add_object(dn, attrs) except LdapServerException: log.error( "An error occurred adding object %s in LDAP:\n%s\n", - dn, pretty_format_dict(attrs), exc_info=True + dn, + pretty_format_dict(attrs), + exc_info=True, ) return False @@ -823,35 +859,42 @@ class LdapClient: :param protected_attrs: An optional list of protected attributes :param rdn_attr: The LDAP object RDN attribute (to detect renaming, default: auto-detected) """ - assert isinstance(changes, (list, tuple)) and len(changes) == 2 and isinstance(changes[0], dict) and isinstance(changes[1], dict), f'changes parameter must be a result of get_changes() method ({type(changes)} given)' + assert ( + isinstance(changes, (list, tuple)) + and len(changes) == 2 + and isinstance(changes[0], dict) + and isinstance(changes[1], dict) + ), f"changes parameter must be a result of get_changes() method ({type(changes)} given)" # In case of RDN change, we need to modify passed changes, copy it to make it unchanged in # this case _changes = copy.deepcopy(changes) if not rdn_attr: - rdn_attr = ldap_obj['dn'].split('=')[0] - log.debug('Auto-detected RDN attribute from DN: %s => %s', ldap_obj['dn'], rdn_attr) + rdn_attr = ldap_obj["dn"].split("=")[0] + log.debug("Auto-detected RDN attribute from DN: %s => %s", ldap_obj["dn"], rdn_attr) old_rdn_values = self.get_attr(_changes[0], rdn_attr, all_values=True) new_rdn_values = self.get_attr(_changes[1], rdn_attr, all_values=True) if old_rdn_values or new_rdn_values: if not new_rdn_values: log.error( "%s : Attribute %s can't be deleted because it's used as RDN.", - ldap_obj['dn'], rdn_attr + ldap_obj["dn"], + rdn_attr, ) return False log.debug( - '%s: Changes detected on %s RDN attribute: must rename object before updating it', - ldap_obj['dn'], rdn_attr + "%s: Changes detected on %s RDN attribute: must rename object before updating it", + ldap_obj["dn"], + rdn_attr, ) # Compute new object DN - dn_parts = explode_dn(self.decode(ldap_obj['dn'])) - basedn = ','.join(dn_parts[1:]) - new_rdn = f'{rdn_attr}={escape_dn_chars(self.decode(new_rdn_values[0]))}' - new_dn = f'{new_rdn},{basedn}' + dn_parts = explode_dn(self.decode(ldap_obj["dn"])) + basedn = ",".join(dn_parts[1:]) + new_rdn = f"{rdn_attr}={escape_dn_chars(self.decode(new_rdn_values[0]))}" + new_dn = f"{new_rdn},{basedn}" # Rename object - log.debug('%s: Rename to %s', ldap_obj['dn'], new_dn) + log.debug("%s: Rename to %s", ldap_obj["dn"], new_dn) if not self.move_object(ldap_obj, new_dn): return False @@ -865,30 +908,29 @@ class LdapClient: # Check that there are other changes if not _changes[0] and not _changes[1]: - log.debug('%s: No other change after renaming', new_dn) + log.debug("%s: No other change after renaming", new_dn) return True # Otherwise, update object DN - ldap_obj['dn'] = new_dn + ldap_obj["dn"] = new_dn else: - log.debug('%s: No change detected on RDN attibute %s', ldap_obj['dn'], rdn_attr) + log.debug("%s: No change detected on RDN attibute %s", ldap_obj["dn"], rdn_attr) try: if self._just_try: - log.debug('Just-try mode : do not really update object in LDAP') + log.debug("Just-try mode : do not really update object in LDAP") return True assert self._conn or self.initialize() return self._conn.update_object( - ldap_obj['dn'], - _changes[0], - _changes[1], - ignore_attrs=protected_attrs + ldap_obj["dn"], _changes[0], _changes[1], ignore_attrs=protected_attrs ) except LdapServerException: log.error( "An error occurred updating object %s in LDAP:\n%s\n -> \n%s\n\n", - ldap_obj['dn'], pretty_format_dict(_changes[0]), pretty_format_dict(_changes[1]), - exc_info=True + ldap_obj["dn"], + pretty_format_dict(_changes[0]), + pretty_format_dict(_changes[1]), + exc_info=True, ) return False @@ -901,14 +943,16 @@ class LdapClient: """ try: if self._just_try: - log.debug('Just-try mode : do not really move object in LDAP') + log.debug("Just-try mode : do not really move object in LDAP") return True assert self._conn or self.initialize() - return self._conn.rename_object(ldap_obj['dn'], new_dn_or_rdn) + return self._conn.rename_object(ldap_obj["dn"], new_dn_or_rdn) except LdapServerException: log.error( "An error occurred moving object %s in LDAP (destination: %s)", - ldap_obj['dn'], new_dn_or_rdn, exc_info=True + ldap_obj["dn"], + new_dn_or_rdn, + exc_info=True, ) return False @@ -920,15 +964,12 @@ class LdapClient: """ try: if self._just_try: - log.debug('Just-try mode : do not really drop object in LDAP') + log.debug("Just-try mode : do not really drop object in LDAP") return True assert self._conn or self.initialize() - return self._conn.drop_object(ldap_obj['dn']) + return self._conn.drop_object(ldap_obj["dn"]) except LdapServerException: - log.error( - "An error occurred removing object %s in LDAP", - ldap_obj['dn'], exc_info=True - ) + log.error("An error occurred removing object %s in LDAP", ldap_obj["dn"], exc_info=True) return False @@ -943,20 +984,29 @@ def parse_datetime(value, to_timezone=None, default_timezone=None, naive=None): :param value: The LDAP date string to convert :param to_timezone: If specified, the return datetime will be converted to this - specific timezone (optional, default : timezone of the LDAP date string) + specific timezone (optional, default : timezone of the LDAP date + string) :param default_timezone: The timezone used if LDAP date string does not specified the timezone (optional, default : server local timezone) - :param naive: Use naive datetime : return naive datetime object (without timezone conversion from LDAP) + :param naive: Use naive datetime : return naive datetime object (without timezone + conversion from LDAP) """ - assert to_timezone is None or isinstance(to_timezone, (datetime.tzinfo, str)), f'to_timezone must be None, a datetime.tzinfo object or a string (not {type(to_timezone)})' - assert default_timezone is None or isinstance(default_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)), f'default_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a datetime.tzinfo object (not {type(default_timezone)})' + assert to_timezone is None or isinstance( + to_timezone, (datetime.tzinfo, str) + ), f"to_timezone must be None, a datetime.tzinfo object or a string (not {type(to_timezone)})" + assert default_timezone is None or isinstance( + default_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str) + ), ( + "default_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a" + f" datetime.tzinfo object (not {type(default_timezone)})" + ) date = dateutil.parser.parse(value, dayfirst=False) if not date.tzinfo: if naive: return date if not default_timezone: default_timezone = pytz.utc - elif default_timezone == 'local': + elif default_timezone == "local": default_timezone = dateutil.tz.tzlocal() elif isinstance(default_timezone, str): default_timezone = pytz.timezone(default_timezone) @@ -969,7 +1019,7 @@ def parse_datetime(value, to_timezone=None, default_timezone=None, naive=None): elif naive: return date.replace(tzinfo=None) if to_timezone: - if to_timezone == 'local': + if to_timezone == "local": to_timezone = dateutil.tz.tzlocal() elif isinstance(to_timezone, str): to_timezone = pytz.timezone(to_timezone) @@ -983,7 +1033,8 @@ def parse_date(value, to_timezone=None, default_timezone=None, naive=True): :param value: The LDAP date string to convert :param to_timezone: If specified, the return datetime will be converted to this - specific timezone (optional, default : timezone of the LDAP date string) + specific timezone (optional, default : timezone of the LDAP date + string) :param default_timezone: The timezone used if LDAP date string does not specified the timezone (optional, default : server local timezone) :param naive: Use naive datetime : do not handle timezone conversion from LDAP @@ -999,13 +1050,23 @@ def format_datetime(value, from_timezone=None, to_timezone=None, naive=None): :param from_timezone: The timezone used if datetime.datetime object is naive (no tzinfo) (optional, default : server local timezone) :param to_timezone: The timezone used in LDAP (optional, default : UTC) - :param naive: Use naive datetime : datetime store as UTC in LDAP (without conversion) + :param naive: Use naive datetime : datetime store as UTC in LDAP (without + conversion) """ - assert isinstance(value, datetime.datetime), f'First parameter must be an datetime.datetime object (not {type(value)})' - assert from_timezone is None or isinstance(from_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)), f'from_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a datetime.tzinfo object (not {type(from_timezone)})' - assert to_timezone is None or isinstance(to_timezone, (datetime.tzinfo, str)), f'to_timezone must be None, a datetime.tzinfo object or a string (not {type(to_timezone)})' + assert isinstance( + value, datetime.datetime + ), f"First parameter must be an datetime.datetime object (not {type(value)})" + assert from_timezone is None or isinstance( + from_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str) + ), ( + "from_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a" + f" datetime.tzinfo object (not {type(from_timezone)})" + ) + assert to_timezone is None or isinstance( + to_timezone, (datetime.tzinfo, str) + ), f"to_timezone must be None, a datetime.tzinfo object or a string (not {type(to_timezone)})" if not value.tzinfo and not naive: - if not from_timezone or from_timezone == 'local': + if not from_timezone or from_timezone == "local": from_timezone = dateutil.tz.tzlocal() elif isinstance(from_timezone, str): from_timezone = pytz.timezone(from_timezone) @@ -1021,14 +1082,14 @@ def format_datetime(value, from_timezone=None, to_timezone=None, naive=None): from_value = copy.deepcopy(value) if not to_timezone: to_timezone = pytz.utc - elif to_timezone == 'local': + elif to_timezone == "local": to_timezone = dateutil.tz.tzlocal() elif isinstance(to_timezone, str): to_timezone = pytz.timezone(to_timezone) to_value = from_value.astimezone(to_timezone) if not naive else from_value - datestring = to_value.strftime('%Y%m%d%H%M%S%z') - if datestring.endswith('+0000'): - datestring = datestring.replace('+0000', 'Z') + datestring = to_value.strftime("%Y%m%d%H%M%S%z") + if datestring.endswith("+0000"): + datestring = datestring.replace("+0000", "Z") return datestring @@ -1040,8 +1101,16 @@ def format_date(value, from_timezone=None, to_timezone=None, naive=True): :param from_timezone: The timezone used if datetime.datetime object is naive (no tzinfo) (optional, default : server local timezone) :param to_timezone: The timezone used in LDAP (optional, default : UTC) - :param naive: Use naive datetime : do not handle timezone conversion before formating - and return datetime as UTC (because LDAP required a timezone) + :param naive: Use naive datetime : do not handle timezone conversion before + formating and return datetime as UTC (because LDAP required a + timezone) """ - assert isinstance(value, datetime.date), f'First parameter must be an datetime.date object (not {type(value)})' - return format_datetime(datetime.datetime.combine(value, datetime.datetime.min.time()), from_timezone, to_timezone, naive) + assert isinstance( + value, datetime.date + ), f"First parameter must be an datetime.date object (not {type(value)})" + return format_datetime( + datetime.datetime.combine(value, datetime.datetime.min.time()), + from_timezone, + to_timezone, + naive, + ) diff --git a/mylib/mapping.py b/mylib/mapping.py index 19622a3..4ced4e6 100644 --- a/mylib/mapping.py +++ b/mylib/mapping.py @@ -48,19 +48,18 @@ Return format : import logging import re - log = logging.getLogger(__name__) def clean_value(value): - """ Clean value as encoded string """ + """Clean value as encoded string""" if isinstance(value, int): value = str(value) return value def get_values(dst, dst_key, src, m): - """ Extract sources values """ + """Extract sources values""" values = [] if "other_key" in m: if m["other_key"] in dst: diff --git a/mylib/mysql.py b/mylib/mysql.py index 0823fcd..9f5a1cf 100644 --- a/mylib/mysql.py +++ b/mylib/mysql.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """ MySQL client """ import logging @@ -8,15 +6,13 @@ import sys import MySQLdb from MySQLdb._exceptions import Error -from mylib.db import DB -from mylib.db import DBFailToConnect - +from mylib.db import DB, DBFailToConnect log = logging.getLogger(__name__) class MyDB(DB): - """ MySQL client """ + """MySQL client""" _host = None _user = None @@ -28,25 +24,33 @@ class MyDB(DB): self._user = user self._pwd = pwd self._db = db - self._charset = charset if charset else 'utf8' + self._charset = charset if charset else "utf8" super().__init__(**kwargs) def connect(self, exit_on_error=True): - """ Connect to MySQL server """ + """Connect to MySQL server""" if self._conn is None: try: self._conn = MySQLdb.connect( - host=self._host, user=self._user, passwd=self._pwd, - db=self._db, charset=self._charset, use_unicode=True) + host=self._host, + user=self._user, + passwd=self._pwd, + db=self._db, + charset=self._charset, + use_unicode=True, + ) except Error as err: log.fatal( - 'An error occured during MySQL database connection (%s@%s:%s).', - self._user, self._host, self._db, exc_info=1 + "An error occured during MySQL database connection (%s@%s:%s).", + self._user, + self._host, + self._db, + exc_info=1, ) if exit_on_error: sys.exit(1) else: - raise DBFailToConnect(f'{self._user}@{self._host}:{self._db}') from err + raise DBFailToConnect(f"{self._user}@{self._host}:{self._db}") from err return True def doSQL(self, sql, params=None): @@ -88,10 +92,7 @@ class MyDB(DB): cursor = self._conn.cursor() cursor.execute(sql, params) return [ - dict( - (field[0], row[idx]) - for idx, field in enumerate(cursor.description) - ) + {field[0]: row[idx] for idx, field in enumerate(cursor.description)} for row in cursor.fetchall() ] except Error: @@ -100,14 +101,12 @@ class MyDB(DB): @staticmethod def _quote_table_name(table): - """ Quote table name """ - return '`{0}`'.format( # pylint: disable=consider-using-f-string - '`.`'.join( - table.split('.') - ) + """Quote table name""" + return "`{}`".format( # pylint: disable=consider-using-f-string + "`.`".join(table.split(".")) ) @staticmethod def _quote_field_name(field): - """ Quote table name """ - return f'`{field}`' + """Quote table name""" + return f"`{field}`" diff --git a/mylib/opening_hours.py b/mylib/opening_hours.py index 228a97f..69d2078 100644 --- a/mylib/opening_hours.py +++ b/mylib/opening_hours.py @@ -1,22 +1,20 @@ -# -*- coding: utf-8 -*- - """ Opening hours helpers """ import datetime +import logging import re import time -import logging log = logging.getLogger(__name__) -week_days = ['lundi', 'mardi', 'mercredi', 'jeudi', 'vendredi', 'samedi', 'dimanche'] -date_format = '%d/%m/%Y' -date_pattern = re.compile('^([0-9]{2})/([0-9]{2})/([0-9]{4})$') -time_pattern = re.compile('^([0-9]{1,2})h([0-9]{2})?$') +week_days = ["lundi", "mardi", "mercredi", "jeudi", "vendredi", "samedi", "dimanche"] +date_format = "%d/%m/%Y" +date_pattern = re.compile("^([0-9]{2})/([0-9]{2})/([0-9]{4})$") +time_pattern = re.compile("^([0-9]{1,2})h([0-9]{2})?$") def easter_date(year): - """ Compute easter date for the specified year """ + """Compute easter date for the specified year""" a = year // 100 b = year % 100 c = (3 * (a + 25)) // 4 @@ -36,30 +34,30 @@ def easter_date(year): def nonworking_french_public_days_of_the_year(year=None): - """ Compute dict of nonworking french public days for the specified year """ + """Compute dict of nonworking french public days for the specified year""" if year is None: year = datetime.date.today().year dp = easter_date(year) return { - '1janvier': datetime.date(year, 1, 1), - 'paques': dp, - 'lundi_paques': (dp + datetime.timedelta(1)), - '1mai': datetime.date(year, 5, 1), - '8mai': datetime.date(year, 5, 8), - 'jeudi_ascension': (dp + datetime.timedelta(39)), - 'pentecote': (dp + datetime.timedelta(49)), - 'lundi_pentecote': (dp + datetime.timedelta(50)), - '14juillet': datetime.date(year, 7, 14), - '15aout': datetime.date(year, 8, 15), - '1novembre': datetime.date(year, 11, 1), - '11novembre': datetime.date(year, 11, 11), - 'noel': datetime.date(year, 12, 25), - 'saint_etienne': datetime.date(year, 12, 26), + "1janvier": datetime.date(year, 1, 1), + "paques": dp, + "lundi_paques": (dp + datetime.timedelta(1)), + "1mai": datetime.date(year, 5, 1), + "8mai": datetime.date(year, 5, 8), + "jeudi_ascension": (dp + datetime.timedelta(39)), + "pentecote": (dp + datetime.timedelta(49)), + "lundi_pentecote": (dp + datetime.timedelta(50)), + "14juillet": datetime.date(year, 7, 14), + "15aout": datetime.date(year, 8, 15), + "1novembre": datetime.date(year, 11, 1), + "11novembre": datetime.date(year, 11, 11), + "noel": datetime.date(year, 12, 25), + "saint_etienne": datetime.date(year, 12, 26), } def parse_exceptional_closures(values): - """ Parse exceptional closures values """ + """Parse exceptional closures values""" exceptional_closures = [] for value in values: days = [] @@ -68,7 +66,7 @@ def parse_exceptional_closures(values): for word in words: if not word: continue - parts = word.split('-') + parts = word.split("-") if len(parts) == 1: # ex: 31/02/2017 ptime = time.strptime(word, date_format) @@ -82,7 +80,7 @@ def parse_exceptional_closures(values): pstart = time.strptime(parts[0], date_format) pstop = time.strptime(parts[1], date_format) if pstop <= pstart: - raise ValueError(f'Day {parts[1]} <= {parts[0]}') + raise ValueError(f"Day {parts[1]} <= {parts[0]}") date = datetime.date(pstart.tm_year, pstart.tm_mon, pstart.tm_mday) stop_date = datetime.date(pstop.tm_year, pstop.tm_mon, pstop.tm_mday) @@ -99,18 +97,18 @@ def parse_exceptional_closures(values): hstart = datetime.time(int(mstart.group(1)), int(mstart.group(2) or 0)) hstop = datetime.time(int(mstop.group(1)), int(mstop.group(2) or 0)) if hstop <= hstart: - raise ValueError(f'Time {parts[1]} <= {parts[0]}') - hours_periods.append({'start': hstart, 'stop': hstop}) + raise ValueError(f"Time {parts[1]} <= {parts[0]}") + hours_periods.append({"start": hstart, "stop": hstop}) else: raise ValueError(f'Invalid number of part in this word: "{word}"') if not days: raise ValueError(f'No days found in value "{value}"') - exceptional_closures.append({'days': days, 'hours_periods': hours_periods}) + exceptional_closures.append({"days": days, "hours_periods": hours_periods}) return exceptional_closures def parse_normal_opening_hours(values): - """ Parse normal opening hours """ + """Parse normal opening hours""" normal_opening_hours = [] for value in values: days = [] @@ -119,7 +117,7 @@ def parse_normal_opening_hours(values): for word in words: if not word: continue - parts = word.split('-') + parts = word.split("-") if len(parts) == 1: # ex: jeudi if word not in week_days: @@ -150,40 +148,51 @@ def parse_normal_opening_hours(values): hstart = datetime.time(int(mstart.group(1)), int(mstart.group(2) or 0)) hstop = datetime.time(int(mstop.group(1)), int(mstop.group(2) or 0)) if hstop <= hstart: - raise ValueError(f'Time {parts[1]} <= {parts[0]}') - hours_periods.append({'start': hstart, 'stop': hstop}) + raise ValueError(f"Time {parts[1]} <= {parts[0]}") + hours_periods.append({"start": hstart, "stop": hstop}) else: raise ValueError(f'Invalid number of part in this word: "{word}"') if not days and not hours_periods: raise ValueError(f'No days or hours period found in this value: "{value}"') - normal_opening_hours.append({'days': days, 'hours_periods': hours_periods}) + normal_opening_hours.append({"days": days, "hours_periods": hours_periods}) return normal_opening_hours def is_closed( - normal_opening_hours_values=None, exceptional_closures_values=None, - nonworking_public_holidays_values=None, exceptional_closure_on_nonworking_public_days=False, - when=None, on_error='raise' + normal_opening_hours_values=None, + exceptional_closures_values=None, + nonworking_public_holidays_values=None, + exceptional_closure_on_nonworking_public_days=False, + when=None, + on_error="raise", ): - """ Check if closed """ + """Check if closed""" if not when: when = datetime.datetime.now() when_date = when.date() when_time = when.time() when_weekday = week_days[when.timetuple().tm_wday] on_error_result = None - if on_error == 'closed': + if on_error == "closed": on_error_result = { - 'closed': True, 'exceptional_closure': False, - 'exceptional_closure_all_day': False} - elif on_error == 'opened': + "closed": True, + "exceptional_closure": False, + "exceptional_closure_all_day": False, + } + elif on_error == "opened": on_error_result = { - 'closed': False, 'exceptional_closure': False, - 'exceptional_closure_all_day': False} + "closed": False, + "exceptional_closure": False, + "exceptional_closure_all_day": False, + } log.debug( "When = %s => date = %s / time = %s / week day = %s", - when, when_date, when_time, when_weekday) + when, + when_date, + when_time, + when_weekday, + ) if nonworking_public_holidays_values: log.debug("Nonworking public holidays: %s", nonworking_public_holidays_values) nonworking_days = nonworking_french_public_days_of_the_year() @@ -191,65 +200,69 @@ def is_closed( if day in nonworking_days and when_date == nonworking_days[day]: log.debug("Non working day: %s", day) return { - 'closed': True, - 'exceptional_closure': exceptional_closure_on_nonworking_public_days, - 'exceptional_closure_all_day': exceptional_closure_on_nonworking_public_days + "closed": True, + "exceptional_closure": exceptional_closure_on_nonworking_public_days, + "exceptional_closure_all_day": exceptional_closure_on_nonworking_public_days, } if exceptional_closures_values: try: exceptional_closures = parse_exceptional_closures(exceptional_closures_values) - log.debug('Exceptional closures: %s', exceptional_closures) + log.debug("Exceptional closures: %s", exceptional_closures) except ValueError as e: log.error("Fail to parse exceptional closures, consider as closed", exc_info=True) if on_error_result is None: raise e from e return on_error_result for cl in exceptional_closures: - if when_date not in cl['days']: - log.debug("when_date (%s) no in days (%s)", when_date, cl['days']) + if when_date not in cl["days"]: + log.debug("when_date (%s) no in days (%s)", when_date, cl["days"]) continue - if not cl['hours_periods']: + if not cl["hours_periods"]: # All day exceptional closure return { - 'closed': True, 'exceptional_closure': True, - 'exceptional_closure_all_day': True} - for hp in cl['hours_periods']: - if hp['start'] <= when_time <= hp['stop']: + "closed": True, + "exceptional_closure": True, + "exceptional_closure_all_day": True, + } + for hp in cl["hours_periods"]: + if hp["start"] <= when_time <= hp["stop"]: return { - 'closed': True, 'exceptional_closure': True, - 'exceptional_closure_all_day': False} + "closed": True, + "exceptional_closure": True, + "exceptional_closure_all_day": False, + } if normal_opening_hours_values: try: normal_opening_hours = parse_normal_opening_hours(normal_opening_hours_values) - log.debug('Normal opening hours: %s', normal_opening_hours) + log.debug("Normal opening hours: %s", normal_opening_hours) except ValueError as e: # pylint: disable=broad-except log.error("Fail to parse normal opening hours, consider as closed", exc_info=True) if on_error_result is None: raise e from e return on_error_result for oh in normal_opening_hours: - if oh['days'] and when_weekday not in oh['days']: - log.debug("when_weekday (%s) no in days (%s)", when_weekday, oh['days']) + if oh["days"] and when_weekday not in oh["days"]: + log.debug("when_weekday (%s) no in days (%s)", when_weekday, oh["days"]) continue - if not oh['hours_periods']: + if not oh["hours_periods"]: # All day opened return { - 'closed': False, 'exceptional_closure': False, - 'exceptional_closure_all_day': False} - for hp in oh['hours_periods']: - if hp['start'] <= when_time <= hp['stop']: + "closed": False, + "exceptional_closure": False, + "exceptional_closure_all_day": False, + } + for hp in oh["hours_periods"]: + if hp["start"] <= when_time <= hp["stop"]: return { - 'closed': False, 'exceptional_closure': False, - 'exceptional_closure_all_day': False} + "closed": False, + "exceptional_closure": False, + "exceptional_closure_all_day": False, + } log.debug("Not in normal opening hours => closed") - return { - 'closed': True, 'exceptional_closure': False, - 'exceptional_closure_all_day': False} + return {"closed": True, "exceptional_closure": False, "exceptional_closure_all_day": False} # Not a nonworking day, not during exceptional closure and no normal opening # hours defined => Opened - return { - 'closed': False, 'exceptional_closure': False, - 'exceptional_closure_all_day': False} + return {"closed": False, "exceptional_closure": False, "exceptional_closure_all_day": False} diff --git a/mylib/oracle.py b/mylib/oracle.py index 12feebc..6559ea9 100644 --- a/mylib/oracle.py +++ b/mylib/oracle.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """ Oracle client """ import logging @@ -7,14 +5,13 @@ import sys import cx_Oracle -from mylib.db import DB -from mylib.db import DBFailToConnect +from mylib.db import DB, DBFailToConnect log = logging.getLogger(__name__) class OracleDB(DB): - """ Oracle client """ + """Oracle client""" _dsn = None _user = None @@ -27,24 +24,22 @@ class OracleDB(DB): super().__init__(**kwargs) def connect(self, exit_on_error=True): - """ Connect to Oracle server """ + """Connect to Oracle server""" if self._conn is None: - log.info('Connect on Oracle server with DSN %s as %s', self._dsn, self._user) + log.info("Connect on Oracle server with DSN %s as %s", self._dsn, self._user) try: - self._conn = cx_Oracle.connect( - user=self._user, - password=self._pwd, - dsn=self._dsn - ) + self._conn = cx_Oracle.connect(user=self._user, password=self._pwd, dsn=self._dsn) except cx_Oracle.Error as err: log.fatal( - 'An error occured during Oracle database connection (%s@%s).', - self._user, self._dsn, exc_info=1 + "An error occured during Oracle database connection (%s@%s).", + self._user, + self._dsn, + exc_info=1, ) if exit_on_error: sys.exit(1) else: - raise DBFailToConnect(f'{self._user}@{self._dsn}') from err + raise DBFailToConnect(f"{self._user}@{self._dsn}") from err return True def doSQL(self, sql, params=None): @@ -107,5 +102,5 @@ class OracleDB(DB): @staticmethod def format_param(param): - """ Format SQL query parameter for prepared query """ - return f':{param}' + """Format SQL query parameter for prepared query""" + return f":{param}" diff --git a/mylib/pbar.py b/mylib/pbar.py index 77efa1d..f060b8b 100644 --- a/mylib/pbar.py +++ b/mylib/pbar.py @@ -1,10 +1,8 @@ -# coding: utf8 - """ Progress bar """ import logging -import progressbar +import progressbar log = logging.getLogger(__name__) @@ -25,15 +23,15 @@ class Pbar: # pylint: disable=useless-object-inheritance self.__count = 0 self.__pbar = progressbar.ProgressBar( widgets=[ - name + ': ', + name + ": ", progressbar.Percentage(), - ' ', + " ", progressbar.Bar(), - ' ', + " ", progressbar.SimpleProgress(), - progressbar.ETA() + progressbar.ETA(), ], - maxval=maxval + maxval=maxval, ).start() else: log.info(name) @@ -49,6 +47,6 @@ class Pbar: # pylint: disable=useless-object-inheritance self.__pbar.update(self.__count) def finish(self): - """ Finish the progress bar """ + """Finish the progress bar""" if self.__pbar: self.__pbar.finish() diff --git a/mylib/pgsql.py b/mylib/pgsql.py index 882e8a2..c960318 100644 --- a/mylib/pgsql.py +++ b/mylib/pgsql.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """ PostgreSQL client """ import datetime @@ -14,15 +12,15 @@ log = logging.getLogger(__name__) class PgDB(DB): - """ PostgreSQL client """ + """PostgreSQL client""" _host = None _user = None _pwd = None _db = None - date_format = '%Y-%m-%d' - datetime_format = '%Y-%m-%d %H:%M:%S' + date_format = "%Y-%m-%d" + datetime_format = "%Y-%m-%d %H:%M:%S" def __init__(self, host, user, pwd, db, **kwargs): self._host = host @@ -32,37 +30,40 @@ class PgDB(DB): super().__init__(**kwargs) def connect(self, exit_on_error=True): - """ Connect to PostgreSQL server """ + """Connect to PostgreSQL server""" if self._conn is None: try: log.info( - 'Connect on PostgreSQL server %s as %s on database %s', - self._host, self._user, self._db) + "Connect on PostgreSQL server %s as %s on database %s", + self._host, + self._user, + self._db, + ) self._conn = psycopg2.connect( - dbname=self._db, - user=self._user, - host=self._host, - password=self._pwd + dbname=self._db, user=self._user, host=self._host, password=self._pwd ) except psycopg2.Error as err: log.fatal( - 'An error occured during Postgresql database connection (%s@%s, database=%s).', - self._user, self._host, self._db, exc_info=1 + "An error occured during Postgresql database connection (%s@%s, database=%s).", + self._user, + self._host, + self._db, + exc_info=1, ) if exit_on_error: sys.exit(1) else: - raise DBFailToConnect(f'{self._user}@{self._host}:{self._db}') from err + raise DBFailToConnect(f"{self._user}@{self._host}:{self._db}") from err return True def close(self): - """ Close connection with PostgreSQL server (if opened) """ + """Close connection with PostgreSQL server (if opened)""" if self._conn: self._conn.close() self._conn = None def setEncoding(self, enc): - """ Set connection encoding """ + """Set connection encoding""" if self._conn: try: self._conn.set_client_encoding(enc) @@ -70,7 +71,8 @@ class PgDB(DB): except psycopg2.Error: log.error( 'An error occured setting Postgresql database connection encoding to "%s"', - enc, exc_info=1 + enc, + exc_info=1, ) return False @@ -124,10 +126,7 @@ class PgDB(DB): @staticmethod def _map_row_fields_by_index(fields, row): - return dict( - (field, row[idx]) - for idx, field in enumerate(fields) - ) + return {field: row[idx] for idx, field in enumerate(fields)} # # Depreated helpers @@ -135,9 +134,9 @@ class PgDB(DB): @classmethod def _quote_value(cls, value): - """ Quote a value for SQL query """ + """Quote a value for SQL query""" if value is None: - return 'NULL' + return "NULL" if isinstance(value, (int, float)): return str(value) @@ -148,26 +147,26 @@ class PgDB(DB): value = cls._format_date(value) # pylint: disable=consider-using-f-string - return "'{0}'".format(value.replace("'", "''")) + return "'{}'".format(value.replace("'", "''")) @classmethod def _format_datetime(cls, value): - """ Format datetime object as string """ + """Format datetime object as string""" assert isinstance(value, datetime.datetime) return value.strftime(cls.datetime_format) @classmethod def _format_date(cls, value): - """ Format date object as string """ + """Format date object as string""" assert isinstance(value, (datetime.date, datetime.datetime)) return value.strftime(cls.date_format) @classmethod def time2datetime(cls, time): - """ Convert timestamp to datetime string """ + """Convert timestamp to datetime string""" return cls._format_datetime(datetime.datetime.fromtimestamp(int(time))) @classmethod def time2date(cls, time): - """ Convert timestamp to date string """ + """Convert timestamp to date string""" return cls._format_date(datetime.date.fromtimestamp(int(time))) diff --git a/mylib/report.py b/mylib/report.py index ec3d676..ede6517 100644 --- a/mylib/report.py +++ b/mylib/report.py @@ -1,28 +1,24 @@ -# coding: utf8 - """ Report """ import atexit import logging -from mylib.config import ConfigurableObject -from mylib.config import StringOption +from mylib.config import ConfigurableObject, StringOption from mylib.email import EmailClient - log = logging.getLogger(__name__) class Report(ConfigurableObject): # pylint: disable=useless-object-inheritance - """ Logging report """ + """Logging report""" - _config_name = 'report' - _config_comment = 'Email report' + _config_name = "report" + _config_comment = "Email report" _defaults = { - 'recipient': None, - 'subject': 'Report', - 'loglevel': 'WARNING', - 'logformat': '%(asctime)s - %(levelname)s - %(message)s', + "recipient": None, + "subject": "Report", + "loglevel": "WARNING", + "logformat": "%(asctime)s - %(levelname)s - %(message)s", } content = [] @@ -40,20 +36,28 @@ class Report(ConfigurableObject): # pylint: disable=useless-object-inheritance self.initialize() def configure(self, **kwargs): # pylint: disable=arguments-differ - """ Configure options on registered mylib.Config object """ + """Configure options on registered mylib.Config object""" section = super().configure(**kwargs) + section.add_option(StringOption, "recipient", comment="Report recipient email address") section.add_option( - StringOption, 'recipient', comment='Report recipient email address') + StringOption, + "subject", + default=self._defaults["subject"], + comment="Report email subject", + ) section.add_option( - StringOption, 'subject', default=self._defaults['subject'], - comment='Report email subject') + StringOption, + "loglevel", + default=self._defaults["loglevel"], + comment='Report log level (as accept by python logging, for instance "INFO")', + ) section.add_option( - StringOption, 'loglevel', default=self._defaults['loglevel'], - comment='Report log level (as accept by python logging, for instance "INFO")') - section.add_option( - StringOption, 'logformat', default=self._defaults['logformat'], - comment='Report log level (as accept by python logging, for instance "INFO")') + StringOption, + "logformat", + default=self._defaults["logformat"], + comment='Report log level (as accept by python logging, for instance "INFO")', + ) if not self.email_client: self.email_client = EmailClient(config=self._config) @@ -62,66 +66,70 @@ class Report(ConfigurableObject): # pylint: disable=useless-object-inheritance return section def initialize(self, loaded_config=None): - """ Configuration initialized hook """ + """Configuration initialized hook""" super().initialize(loaded_config=loaded_config) self.handler = logging.StreamHandler(self) - loglevel = self._get_option('loglevel').upper() - assert hasattr(logging, loglevel), ( - f'Invalid report loglevel {loglevel}') + loglevel = self._get_option("loglevel").upper() + assert hasattr(logging, loglevel), f"Invalid report loglevel {loglevel}" self.handler.setLevel(getattr(logging, loglevel)) - self.formatter = logging.Formatter(self._get_option('logformat')) + self.formatter = logging.Formatter(self._get_option("logformat")) self.handler.setFormatter(self.formatter) def get_handler(self): - """ Retreive logging handler """ + """Retreive logging handler""" return self.handler def write(self, msg): - """ Write a message """ + """Write a message""" self.content.append(msg) def get_content(self): - """ Read the report content """ + """Read the report content""" return "".join(self.content) def add_attachment_file(self, filepath): - """ Add attachment file """ + """Add attachment file""" self._attachment_files.append(filepath) def add_attachment_payload(self, payload): - """ Add attachment payload """ + """Add attachment payload""" self._attachment_payloads.append(payload) def send(self, subject=None, rcpt_to=None, email_client=None, just_try=False): - """ Send report using an EmailClient """ + """Send report using an EmailClient""" if rcpt_to is None: - rcpt_to = self._get_option('recipient') + rcpt_to = self._get_option("recipient") if not rcpt_to: - log.debug('No report recipient, do not send report') + log.debug("No report recipient, do not send report") return True if subject is None: - subject = self._get_option('subject') + subject = self._get_option("subject") assert subject, "You must provide report subject using Report.__init__ or Report.send" if email_client is None: email_client = self.email_client assert email_client, ( - "You must provide an email client __init__(), send() or send_at_exit() methods argument email_client") + "You must provide an email client __init__(), send() or send_at_exit() methods argument" + " email_client" + ) content = self.get_content() if not content: - log.debug('Report is empty, do not send it') + log.debug("Report is empty, do not send it") return True msg = email_client.forge_message( - rcpt_to, subject=subject, text_body=content, + rcpt_to, + subject=subject, + text_body=content, attachment_files=self._attachment_files, - attachment_payloads=self._attachment_payloads) + attachment_payloads=self._attachment_payloads, + ) if email_client.send(rcpt_to, msg=msg, just_try=just_try): - log.debug('Report sent to %s', rcpt_to) + log.debug("Report sent to %s", rcpt_to) return True - log.error('Fail to send report to %s', rcpt_to) + log.error("Fail to send report to %s", rcpt_to) return False def send_at_exit(self, **kwargs): - """ Send report at exit """ + """Send report at exit""" atexit.register(self.send, **kwargs) diff --git a/mylib/scripts/email_test.py b/mylib/scripts/email_test.py index affb953..44bef4f 100644 --- a/mylib/scripts/email_test.py +++ b/mylib/scripts/email_test.py @@ -1,22 +1,18 @@ -# -*- coding: utf-8 -*- - """ Test Email client """ import datetime +import getpass import logging import sys -import getpass from mako.template import Template as MakoTemplate -from mylib.scripts.helpers import get_opts_parser, add_email_opts -from mylib.scripts.helpers import init_logging, init_email_client +from mylib.scripts.helpers import add_email_opts, get_opts_parser, init_email_client, init_logging - -log = logging.getLogger('mylib.scripts.email_test') +log = logging.getLogger("mylib.scripts.email_test") def main(argv=None): # pylint: disable=too-many-locals,too-many-statements - """ Script main """ + """Script main""" if argv is None: argv = sys.argv[1:] @@ -24,10 +20,11 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements parser = get_opts_parser(just_try=True) add_email_opts(parser) - test_opts = parser.add_argument_group('Test email options') + test_opts = parser.add_argument_group("Test email options") test_opts.add_argument( - '-t', '--to', + "-t", + "--to", action="store", type=str, dest="test_to", @@ -35,7 +32,8 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements ) test_opts.add_argument( - '-m', '--mako', + "-m", + "--mako", action="store_true", dest="test_mako", help="Test mako templating", @@ -44,14 +42,14 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements options = parser.parse_args() if not options.test_to: - parser.error('You must specify test email recipient using -t/--to parameter') + parser.error("You must specify test email recipient using -t/--to parameter") sys.exit(1) # Initialize logs - init_logging(options, 'Test EmailClient') + init_logging(options, "Test EmailClient") if options.email_smtp_user and not options.email_smtp_password: - options.email_smtp_password = getpass.getpass('Please enter SMTP password: ') + options.email_smtp_password = getpass.getpass("Please enter SMTP password: ") email_client = init_email_client( options, @@ -59,20 +57,24 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements test=dict( subject="Test email", text=( - "Just a test email sent at {sent_date}." if not options.test_mako else - MakoTemplate("Just a test email sent at ${sent_date}.") + "Just a test email sent at {sent_date}." + if not options.test_mako + else MakoTemplate("Just a test email sent at ${sent_date}.") ), html=( - "Just a test email. (sent at {sent_date})" if not options.test_mako else - MakoTemplate("Just a test email. (sent at ${sent_date})") - ) + "Just a test email. (sent at {sent_date})" + if not options.test_mako + else MakoTemplate( + "Just a test email. (sent at ${sent_date})" + ) + ), ) - ) + ), ) - log.info('Send a test email to %s', options.test_to) - if email_client.send(options.test_to, template='test', sent_date=datetime.datetime.now()): - log.info('Test email sent') + log.info("Send a test email to %s", options.test_to) + if email_client.send(options.test_to, template="test", sent_date=datetime.datetime.now()): + log.info("Test email sent") sys.exit(0) - log.error('Fail to send test email') + log.error("Fail to send test email") sys.exit(1) diff --git a/mylib/scripts/email_test_with_config.py b/mylib/scripts/email_test_with_config.py index 639127b..9e966d3 100644 --- a/mylib/scripts/email_test_with_config.py +++ b/mylib/scripts/email_test_with_config.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """ Test Email client using mylib.config.Config for configuration """ import datetime import logging @@ -10,16 +8,15 @@ from mako.template import Template as MakoTemplate from mylib.config import Config from mylib.email import EmailClient - log = logging.getLogger(__name__) def main(argv=None): # pylint: disable=too-many-locals,too-many-statements - """ Script main """ + """Script main""" if argv is None: argv = sys.argv[1:] - config = Config(__doc__, __name__.replace('.', '_')) + config = Config(__doc__, __name__.replace(".", "_")) email_client = EmailClient(config=config) email_client.configure() @@ -27,10 +24,11 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements # Options parser parser = config.get_arguments_parser(description=__doc__) - test_opts = parser.add_argument_group('Test email options') + test_opts = parser.add_argument_group("Test email options") test_opts.add_argument( - '-t', '--to', + "-t", + "--to", action="store", type=str, dest="test_to", @@ -38,7 +36,8 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements ) test_opts.add_argument( - '-m', '--mako', + "-m", + "--mako", action="store_true", dest="test_mako", help="Test mako templating", @@ -47,26 +46,30 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements options = config.parse_arguments_options() if not options.test_to: - parser.error('You must specify test email recipient using -t/--to parameter') + parser.error("You must specify test email recipient using -t/--to parameter") sys.exit(1) email_client.templates = dict( test=dict( subject="Test email", text=( - "Just a test email sent at {sent_date}." if not options.test_mako else - MakoTemplate("Just a test email sent at ${sent_date}.") + "Just a test email sent at {sent_date}." + if not options.test_mako + else MakoTemplate("Just a test email sent at ${sent_date}.") ), html=( - "Just a test email. (sent at {sent_date})" if not options.test_mako else - MakoTemplate("Just a test email. (sent at ${sent_date})") - ) + "Just a test email. (sent at {sent_date})" + if not options.test_mako + else MakoTemplate( + "Just a test email. (sent at ${sent_date})" + ) + ), ) ) - logging.info('Send a test email to %s', options.test_to) - if email_client.send(options.test_to, template='test', sent_date=datetime.datetime.now()): - logging.info('Test email sent') + logging.info("Send a test email to %s", options.test_to) + if email_client.send(options.test_to, template="test", sent_date=datetime.datetime.now()): + logging.info("Test email sent") sys.exit(0) - logging.error('Fail to send test email') + logging.error("Fail to send test email") sys.exit(1) diff --git a/mylib/scripts/helpers.py b/mylib/scripts/helpers.py index b3b6b60..e12a00d 100644 --- a/mylib/scripts/helpers.py +++ b/mylib/scripts/helpers.py @@ -1,20 +1,18 @@ -# coding: utf8 - """ Scripts helpers """ import argparse import getpass import logging +import os.path import socket import sys -import os.path log = logging.getLogger(__name__) def init_logging(options, name, report=None): - """ Initialize logging from calling script options """ - logformat = f'%(asctime)s - {name} - %(levelname)s - %(message)s' + """Initialize logging from calling script options""" + logformat = f"%(asctime)s - {name} - %(levelname)s - %(message)s" if options.debug: loglevel = logging.DEBUG elif options.verbose: @@ -33,193 +31,201 @@ def init_logging(options, name, report=None): def get_default_opt_value(config, default_config, key): - """ Retreive default option value from config or default config dictionaries """ + """Retreive default option value from config or default config dictionaries""" if config and key in config: return config[key] return default_config.get(key) def get_opts_parser(desc=None, just_try=False, just_one=False, progress=False, config=None): - """ Retrieve options parser """ + """Retrieve options parser""" default_config = dict(logfile=None) parser = argparse.ArgumentParser(description=desc) parser.add_argument( - '-v', '--verbose', - action="store_true", - dest="verbose", - help="Enable verbose mode" + "-v", "--verbose", action="store_true", dest="verbose", help="Enable verbose mode" ) parser.add_argument( - '-d', '--debug', - action="store_true", - dest="debug", - help="Enable debug mode" + "-d", "--debug", action="store_true", dest="debug", help="Enable debug mode" ) parser.add_argument( - '-l', '--log-file', + "-l", + "--log-file", action="store", type=str, dest="logfile", - help=( - 'Log file path (default: ' - f'{get_default_opt_value(config, default_config, "logfile")})'), - default=get_default_opt_value(config, default_config, 'logfile') + help=f'Log file path (default: {get_default_opt_value(config, default_config, "logfile")})', + default=get_default_opt_value(config, default_config, "logfile"), ) parser.add_argument( - '-C', '--console', + "-C", + "--console", action="store_true", dest="console", - help="Always log on console (even if log file is configured)" + help="Always log on console (even if log file is configured)", ) if just_try: parser.add_argument( - '-j', '--just-try', - action="store_true", - dest="just_try", - help="Enable just-try mode" + "-j", "--just-try", action="store_true", dest="just_try", help="Enable just-try mode" ) if just_one: parser.add_argument( - '-J', '--just-one', - action="store_true", - dest="just_one", - help="Enable just-one mode" + "-J", "--just-one", action="store_true", dest="just_one", help="Enable just-one mode" ) if progress: parser.add_argument( - '-p', '--progress', - action="store_true", - dest="progress", - help="Enable progress bar" + "-p", "--progress", action="store_true", dest="progress", help="Enable progress bar" ) return parser def add_email_opts(parser, config=None): - """ Add email options """ - email_opts = parser.add_argument_group('Email options') + """Add email options""" + email_opts = parser.add_argument_group("Email options") default_config = dict( - smtp_host="127.0.0.1", smtp_port=25, smtp_ssl=False, smtp_tls=False, smtp_user=None, - smtp_password=None, smtp_debug=False, email_encoding=sys.getdefaultencoding(), - sender_name=getpass.getuser(), sender_email=f'{getpass.getuser()}@{socket.gethostname()}', - catch_all=False + smtp_host="127.0.0.1", + smtp_port=25, + smtp_ssl=False, + smtp_tls=False, + smtp_user=None, + smtp_password=None, + smtp_debug=False, + email_encoding=sys.getdefaultencoding(), + sender_name=getpass.getuser(), + sender_email=f"{getpass.getuser()}@{socket.gethostname()}", + catch_all=False, ) email_opts.add_argument( - '--smtp-host', + "--smtp-host", action="store", type=str, dest="email_smtp_host", - help=( - 'SMTP host (default: ' - f'{get_default_opt_value(config, default_config, "smtp_host")})'), - default=get_default_opt_value(config, default_config, 'smtp_host') + help=f'SMTP host (default: {get_default_opt_value(config, default_config, "smtp_host")})', + default=get_default_opt_value(config, default_config, "smtp_host"), ) email_opts.add_argument( - '--smtp-port', + "--smtp-port", action="store", type=int, dest="email_smtp_port", help=f'SMTP port (default: {get_default_opt_value(config, default_config, "smtp_port")})', - default=get_default_opt_value(config, default_config, 'smtp_port') + default=get_default_opt_value(config, default_config, "smtp_port"), ) email_opts.add_argument( - '--smtp-ssl', + "--smtp-ssl", action="store_true", dest="email_smtp_ssl", help=f'Use SSL (default: {get_default_opt_value(config, default_config, "smtp_ssl")})', - default=get_default_opt_value(config, default_config, 'smtp_ssl') + default=get_default_opt_value(config, default_config, "smtp_ssl"), ) email_opts.add_argument( - '--smtp-tls', + "--smtp-tls", action="store_true", dest="email_smtp_tls", help=f'Use TLS (default: {get_default_opt_value(config, default_config, "smtp_tls")})', - default=get_default_opt_value(config, default_config, 'smtp_tls') + default=get_default_opt_value(config, default_config, "smtp_tls"), ) email_opts.add_argument( - '--smtp-user', + "--smtp-user", action="store", type=str, dest="email_smtp_user", - help=f'SMTP username (default: {get_default_opt_value(config, default_config, "smtp_user")})', - default=get_default_opt_value(config, default_config, 'smtp_user') + help=( + f'SMTP username (default: {get_default_opt_value(config, default_config, "smtp_user")})' + ), + default=get_default_opt_value(config, default_config, "smtp_user"), ) email_opts.add_argument( - '--smtp-password', + "--smtp-password", action="store", type=str, dest="email_smtp_password", - help=f'SMTP password (default: {get_default_opt_value(config, default_config, "smtp_password")})', - default=get_default_opt_value(config, default_config, 'smtp_password') + help=( + "SMTP password (default:" + f' {get_default_opt_value(config, default_config, "smtp_password")})' + ), + default=get_default_opt_value(config, default_config, "smtp_password"), ) email_opts.add_argument( - '--smtp-debug', + "--smtp-debug", action="store_true", dest="email_smtp_debug", - help=f'Debug SMTP connection (default: {get_default_opt_value(config, default_config, "smtp_debug")})', - default=get_default_opt_value(config, default_config, 'smtp_debug') + help=( + "Debug SMTP connection (default:" + f' {get_default_opt_value(config, default_config, "smtp_debug")})' + ), + default=get_default_opt_value(config, default_config, "smtp_debug"), ) email_opts.add_argument( - '--email-encoding', + "--email-encoding", action="store", type=str, dest="email_encoding", - help=f'SMTP encoding (default: {get_default_opt_value(config, default_config, "email_encoding")})', - default=get_default_opt_value(config, default_config, 'email_encoding') + help=( + "SMTP encoding (default:" + f' {get_default_opt_value(config, default_config, "email_encoding")})' + ), + default=get_default_opt_value(config, default_config, "email_encoding"), ) email_opts.add_argument( - '--sender-name', + "--sender-name", action="store", type=str, dest="email_sender_name", - help=f'Sender name (default: {get_default_opt_value(config, default_config, "sender_name")})', - default=get_default_opt_value(config, default_config, 'sender_name') + help=( + f'Sender name (default: {get_default_opt_value(config, default_config, "sender_name")})' + ), + default=get_default_opt_value(config, default_config, "sender_name"), ) email_opts.add_argument( - '--sender-email', + "--sender-email", action="store", type=str, dest="email_sender_email", - help=f'Sender email (default: {get_default_opt_value(config, default_config, "sender_email")})', - default=get_default_opt_value(config, default_config, 'sender_email') + help=( + "Sender email (default:" + f' {get_default_opt_value(config, default_config, "sender_email")})' + ), + default=get_default_opt_value(config, default_config, "sender_email"), ) email_opts.add_argument( - '--catch-all', + "--catch-all", action="store", type=str, dest="email_catch_all", help=( - 'Catch all sent email: specify catch recipient email address ' - f'(default: {get_default_opt_value(config, default_config, "catch_all")})'), - default=get_default_opt_value(config, default_config, 'catch_all') + "Catch all sent email: specify catch recipient email address " + f'(default: {get_default_opt_value(config, default_config, "catch_all")})' + ), + default=get_default_opt_value(config, default_config, "catch_all"), ) def init_email_client(options, **kwargs): - """ Initialize email client from calling script options """ + """Initialize email client from calling script options""" from mylib.email import EmailClient # pylint: disable=import-outside-toplevel - log.info('Initialize Email client') + + log.info("Initialize Email client") return EmailClient( smtp_host=options.email_smtp_host, smtp_port=options.email_smtp_port, @@ -231,64 +237,62 @@ def init_email_client(options, **kwargs): sender_name=options.email_sender_name, sender_email=options.email_sender_email, catch_all_addr=options.email_catch_all, - just_try=options.just_try if hasattr(options, 'just_try') else False, + just_try=options.just_try if hasattr(options, "just_try") else False, encoding=options.email_encoding, - **kwargs + **kwargs, ) def add_sftp_opts(parser): - """ Add SFTP options to argpase.ArgumentParser """ + """Add SFTP options to argpase.ArgumentParser""" sftp_opts = parser.add_argument_group("SFTP options") sftp_opts.add_argument( - '-H', '--sftp-host', + "-H", + "--sftp-host", action="store", type=str, dest="sftp_host", help="SFTP Host (default: localhost)", - default='localhost' + default="localhost", ) sftp_opts.add_argument( - '--sftp-port', + "--sftp-port", action="store", type=int, dest="sftp_port", help="SFTP Port (default: 22)", - default=22 + default=22, ) sftp_opts.add_argument( - '-u', '--sftp-user', - action="store", - type=str, - dest="sftp_user", - help="SFTP User" + "-u", "--sftp-user", action="store", type=str, dest="sftp_user", help="SFTP User" ) sftp_opts.add_argument( - '-P', '--sftp-password', + "-P", + "--sftp-password", action="store", type=str, dest="sftp_password", - help="SFTP Password" + help="SFTP Password", ) sftp_opts.add_argument( - '--sftp-known-hosts', + "--sftp-known-hosts", action="store", type=str, dest="sftp_known_hosts", help="SFTP known_hosts file path (default: ~/.ssh/known_hosts)", - default=os.path.expanduser('~/.ssh/known_hosts') + default=os.path.expanduser("~/.ssh/known_hosts"), ) sftp_opts.add_argument( - '--sftp-auto-add-unknown-host-key', + "--sftp-auto-add-unknown-host-key", action="store_true", dest="sftp_auto_add_unknown_host_key", - help="Auto-add unknown SSH host key" + help="Auto-add unknown SSH host key", ) return sftp_opts diff --git a/mylib/scripts/ldap_test.py b/mylib/scripts/ldap_test.py index 6be1185..d1cf521 100644 --- a/mylib/scripts/ldap_test.py +++ b/mylib/scripts/ldap_test.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """ Test LDAP """ import datetime import logging @@ -8,16 +6,14 @@ import sys import dateutil.tz import pytz -from mylib.ldap import format_datetime, format_date, parse_datetime, parse_date -from mylib.scripts.helpers import get_opts_parser -from mylib.scripts.helpers import init_logging +from mylib.ldap import format_date, format_datetime, parse_date, parse_datetime +from mylib.scripts.helpers import get_opts_parser, init_logging - -log = logging.getLogger('mylib.scripts.ldap_test') +log = logging.getLogger("mylib.scripts.ldap_test") def main(argv=None): # pylint: disable=too-many-locals,too-many-statements - """ Script main """ + """Script main""" if argv is None: argv = sys.argv[1:] @@ -26,52 +22,121 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements options = parser.parse_args() # Initialize logs - init_logging(options, 'Test LDAP helpers') + init_logging(options, "Test LDAP helpers") now = datetime.datetime.now().replace(tzinfo=dateutil.tz.tzlocal()) - print(f'Now = {now}') + print(f"Now = {now}") datestring_now = format_datetime(now) - print(f'format_datetime : {datestring_now}') - print(f'format_datetime (from_timezone=utc) : {format_datetime(now.replace(tzinfo=None), from_timezone=pytz.utc)}') - print(f'format_datetime (from_timezone=local) : {format_datetime(now.replace(tzinfo=None), from_timezone=dateutil.tz.tzlocal())}') - print(f'format_datetime (from_timezone=local) : {format_datetime(now.replace(tzinfo=None), from_timezone="local")}') - print(f'format_datetime (from_timezone=Paris) : {format_datetime(now.replace(tzinfo=None), from_timezone="Europe/Paris")}') - print(f'format_datetime (to_timezone=utc) : {format_datetime(now, to_timezone=pytz.utc)}') - print(f'format_datetime (to_timezone=local) : {format_datetime(now, to_timezone=dateutil.tz.tzlocal())}') + print(f"format_datetime : {datestring_now}") + print( + "format_datetime (from_timezone=utc) :" + f" {format_datetime(now.replace(tzinfo=None), from_timezone=pytz.utc)}" + ) + print( + "format_datetime (from_timezone=local) :" + f" {format_datetime(now.replace(tzinfo=None), from_timezone=dateutil.tz.tzlocal())}" + ) + print( + "format_datetime (from_timezone=local) :" + f' {format_datetime(now.replace(tzinfo=None), from_timezone="local")}' + ) + print( + "format_datetime (from_timezone=Paris) :" + f' {format_datetime(now.replace(tzinfo=None), from_timezone="Europe/Paris")}' + ) + print(f"format_datetime (to_timezone=utc) : {format_datetime(now, to_timezone=pytz.utc)}") + print( + "format_datetime (to_timezone=local) :" + f" {format_datetime(now, to_timezone=dateutil.tz.tzlocal())}" + ) print(f'format_datetime (to_timezone=local) : {format_datetime(now, to_timezone="local")}') print(f'format_datetime (to_timezone=Tokyo) : {format_datetime(now, to_timezone="Asia/Tokyo")}') - print(f'format_datetime (naive=True) : {format_datetime(now, naive=True)}') + print(f"format_datetime (naive=True) : {format_datetime(now, naive=True)}") - print(f'format_date : {format_date(now)}') - print(f'format_date (from_timezone=utc) : {format_date(now.replace(tzinfo=None), from_timezone=pytz.utc)}') - print(f'format_date (from_timezone=local) : {format_date(now.replace(tzinfo=None), from_timezone=dateutil.tz.tzlocal())}') - print(f'format_date (from_timezone=local) : {format_date(now.replace(tzinfo=None), from_timezone="local")}') - print(f'format_date (from_timezone=Paris) : {format_date(now.replace(tzinfo=None), from_timezone="Europe/Paris")}') - print(f'format_date (to_timezone=utc) : {format_date(now, to_timezone=pytz.utc)}') - print(f'format_date (to_timezone=local) : {format_date(now, to_timezone=dateutil.tz.tzlocal())}') + print(f"format_date : {format_date(now)}") + print( + "format_date (from_timezone=utc) :" + f" {format_date(now.replace(tzinfo=None), from_timezone=pytz.utc)}" + ) + print( + "format_date (from_timezone=local) :" + f" {format_date(now.replace(tzinfo=None), from_timezone=dateutil.tz.tzlocal())}" + ) + print( + "format_date (from_timezone=local) :" + f' {format_date(now.replace(tzinfo=None), from_timezone="local")}' + ) + print( + "format_date (from_timezone=Paris) :" + f' {format_date(now.replace(tzinfo=None), from_timezone="Europe/Paris")}' + ) + print(f"format_date (to_timezone=utc) : {format_date(now, to_timezone=pytz.utc)}") + print( + f"format_date (to_timezone=local) : {format_date(now, to_timezone=dateutil.tz.tzlocal())}" + ) print(f'format_date (to_timezone=local) : {format_date(now, to_timezone="local")}') print(f'format_date (to_timezone=Tokyo) : {format_date(now, to_timezone="Asia/Tokyo")}') - print(f'format_date (naive=True) : {format_date(now, naive=True)}') + print(f"format_date (naive=True) : {format_date(now, naive=True)}") - print(f'parse_datetime : {parse_datetime(datestring_now)}') - print(f'parse_datetime (default_timezone=utc) : {parse_datetime(datestring_now[0:-1], default_timezone=pytz.utc)}') - print(f'parse_datetime (default_timezone=local) : {parse_datetime(datestring_now[0:-1], default_timezone=dateutil.tz.tzlocal())}') - print(f'parse_datetime (default_timezone=local) : {parse_datetime(datestring_now[0:-1], default_timezone="local")}') - print(f'parse_datetime (default_timezone=Paris) : {parse_datetime(datestring_now[0:-1], default_timezone="Europe/Paris")}') - print(f'parse_datetime (to_timezone=utc) : {parse_datetime(datestring_now, to_timezone=pytz.utc)}') - print(f'parse_datetime (to_timezone=local) : {parse_datetime(datestring_now, to_timezone=dateutil.tz.tzlocal())}') - print(f'parse_datetime (to_timezone=local) : {parse_datetime(datestring_now, to_timezone="local")}') - print(f'parse_datetime (to_timezone=Tokyo) : {parse_datetime(datestring_now, to_timezone="Asia/Tokyo")}') - print(f'parse_datetime (naive=True) : {parse_datetime(datestring_now, naive=True)}') + print(f"parse_datetime : {parse_datetime(datestring_now)}") + print( + "parse_datetime (default_timezone=utc) :" + f" {parse_datetime(datestring_now[0:-1], default_timezone=pytz.utc)}" + ) + print( + "parse_datetime (default_timezone=local) :" + f" {parse_datetime(datestring_now[0:-1], default_timezone=dateutil.tz.tzlocal())}" + ) + print( + "parse_datetime (default_timezone=local) :" + f' {parse_datetime(datestring_now[0:-1], default_timezone="local")}' + ) + print( + "parse_datetime (default_timezone=Paris) :" + f' {parse_datetime(datestring_now[0:-1], default_timezone="Europe/Paris")}' + ) + print( + f"parse_datetime (to_timezone=utc) : {parse_datetime(datestring_now, to_timezone=pytz.utc)}" + ) + print( + "parse_datetime (to_timezone=local) :" + f" {parse_datetime(datestring_now, to_timezone=dateutil.tz.tzlocal())}" + ) + print( + "parse_datetime (to_timezone=local) :" + f' {parse_datetime(datestring_now, to_timezone="local")}' + ) + print( + "parse_datetime (to_timezone=Tokyo) :" + f' {parse_datetime(datestring_now, to_timezone="Asia/Tokyo")}' + ) + print(f"parse_datetime (naive=True) : {parse_datetime(datestring_now, naive=True)}") - print(f'parse_date : {parse_date(datestring_now)}') - print(f'parse_date (default_timezone=utc) : {parse_date(datestring_now[0:-1], default_timezone=pytz.utc)}') - print(f'parse_date (default_timezone=local) : {parse_date(datestring_now[0:-1], default_timezone=dateutil.tz.tzlocal())}') - print(f'parse_date (default_timezone=local) : {parse_date(datestring_now[0:-1], default_timezone="local")}') - print(f'parse_date (default_timezone=Paris) : {parse_date(datestring_now[0:-1], default_timezone="Europe/Paris")}') - print(f'parse_date (to_timezone=utc) : {parse_date(datestring_now, to_timezone=pytz.utc)}') - print(f'parse_date (to_timezone=local) : {parse_date(datestring_now, to_timezone=dateutil.tz.tzlocal())}') + print(f"parse_date : {parse_date(datestring_now)}") + print( + "parse_date (default_timezone=utc) :" + f" {parse_date(datestring_now[0:-1], default_timezone=pytz.utc)}" + ) + print( + "parse_date (default_timezone=local) :" + f" {parse_date(datestring_now[0:-1], default_timezone=dateutil.tz.tzlocal())}" + ) + print( + "parse_date (default_timezone=local) :" + f' {parse_date(datestring_now[0:-1], default_timezone="local")}' + ) + print( + "parse_date (default_timezone=Paris) :" + f' {parse_date(datestring_now[0:-1], default_timezone="Europe/Paris")}' + ) + print(f"parse_date (to_timezone=utc) : {parse_date(datestring_now, to_timezone=pytz.utc)}") + print( + "parse_date (to_timezone=local) :" + f" {parse_date(datestring_now, to_timezone=dateutil.tz.tzlocal())}" + ) print(f'parse_date (to_timezone=local) : {parse_date(datestring_now, to_timezone="local")}') - print(f'parse_date (to_timezone=Tokyo) : {parse_date(datestring_now, to_timezone="Asia/Tokyo")}') - print(f'parse_date (naive=True) : {parse_date(datestring_now, naive=True)}') + print( + f'parse_date (to_timezone=Tokyo) : {parse_date(datestring_now, to_timezone="Asia/Tokyo")}' + ) + print(f"parse_date (naive=True) : {parse_date(datestring_now, naive=True)}") diff --git a/mylib/scripts/map_test.py b/mylib/scripts/map_test.py index 28c27de..3fe4bdc 100644 --- a/mylib/scripts/map_test.py +++ b/mylib/scripts/map_test.py @@ -64,6 +64,6 @@ def main(argv=None): "mail": {"order": 12, "key": "email", "convert": lambda x: x.lower().strip()}, } - print('Mapping source:\n' + pretty_format_value(src)) - print('Mapping config:\n' + pretty_format_value(map_c)) - print('Mapping result:\n' + pretty_format_value(map_hash(map_c, src))) + print("Mapping source:\n" + pretty_format_value(src)) + print("Mapping config:\n" + pretty_format_value(map_c)) + print("Mapping result:\n" + pretty_format_value(map_hash(map_c, src))) diff --git a/mylib/scripts/pbar_test.py b/mylib/scripts/pbar_test.py index c1d6f27..2267514 100644 --- a/mylib/scripts/pbar_test.py +++ b/mylib/scripts/pbar_test.py @@ -1,20 +1,16 @@ -# -*- coding: utf-8 -*- - """ Test Progress bar """ import logging -import time import sys +import time from mylib.pbar import Pbar -from mylib.scripts.helpers import get_opts_parser -from mylib.scripts.helpers import init_logging +from mylib.scripts.helpers import get_opts_parser, init_logging - -log = logging.getLogger('mylib.scripts.pbar_test') +log = logging.getLogger("mylib.scripts.pbar_test") def main(argv=None): # pylint: disable=too-many-locals,too-many-statements - """ Script main """ + """Script main""" if argv is None: argv = sys.argv[1:] @@ -23,20 +19,21 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements parser = get_opts_parser(progress=True) parser.add_argument( - '-c', '--count', + "-c", + "--count", action="store", type=int, dest="count", - help=f'Progress bar max value (default: {default_max_val})', - default=default_max_val + help=f"Progress bar max value (default: {default_max_val})", + default=default_max_val, ) options = parser.parse_args() # Initialize logs - init_logging(options, 'Test Pbar') + init_logging(options, "Test Pbar") - pbar = Pbar('Test', options.count, enabled=options.progress) + pbar = Pbar("Test", options.count, enabled=options.progress) for idx in range(0, options.count): # pylint: disable=unused-variable pbar.increment() diff --git a/mylib/scripts/report_test.py b/mylib/scripts/report_test.py index a65c27a..36540fd 100644 --- a/mylib/scripts/report_test.py +++ b/mylib/scripts/report_test.py @@ -1,19 +1,15 @@ -# -*- coding: utf-8 -*- - """ Test report """ import logging import sys from mylib.report import Report -from mylib.scripts.helpers import get_opts_parser, add_email_opts -from mylib.scripts.helpers import init_logging, init_email_client +from mylib.scripts.helpers import add_email_opts, get_opts_parser, init_email_client, init_logging - -log = logging.getLogger('mylib.scripts.report_test') +log = logging.getLogger("mylib.scripts.report_test") def main(argv=None): # pylint: disable=too-many-locals,too-many-statements - """ Script main """ + """Script main""" if argv is None: argv = sys.argv[1:] @@ -21,14 +17,10 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements parser = get_opts_parser(just_try=True) add_email_opts(parser) - report_opts = parser.add_argument_group('Report options') + report_opts = parser.add_argument_group("Report options") report_opts.add_argument( - '-t', '--to', - action="store", - type=str, - dest="report_rcpt", - help="Send report to this email" + "-t", "--to", action="store", type=str, dest="report_rcpt", help="Send report to this email" ) options = parser.parse_args() @@ -37,13 +29,13 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements parser.error("You must specify a report recipient using -t/--to parameter") # Initialize logs - report = Report(rcpt_to=options.report_rcpt, subject='Test report') - init_logging(options, 'Test Report', report=report) + report = Report(rcpt_to=options.report_rcpt, subject="Test report") + init_logging(options, "Test Report", report=report) email_client = init_email_client(options) report.send_at_exit(email_client=email_client) - logging.debug('Test debug message') - logging.info('Test info message') - logging.warning('Test warning message') - logging.error('Test error message') + logging.debug("Test debug message") + logging.info("Test info message") + logging.warning("Test warning message") + logging.error("Test error message") diff --git a/mylib/scripts/sftp_test.py b/mylib/scripts/sftp_test.py index 1e10933..69067f5 100644 --- a/mylib/scripts/sftp_test.py +++ b/mylib/scripts/sftp_test.py @@ -1,26 +1,21 @@ -# -*- coding: utf-8 -*- - """ Test SFTP client """ import atexit -import tempfile +import getpass import logging -import sys import os import random import string +import sys +import tempfile -import getpass - +from mylib.scripts.helpers import add_sftp_opts, get_opts_parser, init_logging from mylib.sftp import SFTPClient -from mylib.scripts.helpers import get_opts_parser, add_sftp_opts -from mylib.scripts.helpers import init_logging - -log = logging.getLogger('mylib.scripts.sftp_test') +log = logging.getLogger("mylib.scripts.sftp_test") def main(argv=None): # pylint: disable=too-many-locals,too-many-statements - """ Script main """ + """Script main""" if argv is None: argv = sys.argv[1:] @@ -28,10 +23,11 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements parser = get_opts_parser(just_try=True) add_sftp_opts(parser) - test_opts = parser.add_argument_group('Test SFTP options') + test_opts = parser.add_argument_group("Test SFTP options") test_opts.add_argument( - '-p', '--remote-upload-path', + "-p", + "--remote-upload-path", action="store", type=str, dest="upload_path", @@ -41,66 +37,68 @@ def main(argv=None): # pylint: disable=too-many-locals,too-many-statements options = parser.parse_args() # Initialize logs - init_logging(options, 'Test SFTP client') + init_logging(options, "Test SFTP client") if options.sftp_user and not options.sftp_password: - options.sftp_password = getpass.getpass('Please enter SFTP password: ') + options.sftp_password = getpass.getpass("Please enter SFTP password: ") - log.info('Initialize Email client') + log.info("Initialize Email client") sftp = SFTPClient(options=options, just_try=options.just_try) sftp.connect() atexit.register(sftp.close) - log.debug('Create tempory file') - test_content = b'Juste un test.' + log.debug("Create tempory file") + test_content = b"Juste un test." tmp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with tmp_file = os.path.join( - tmp_dir.name, - f'tmp{"".join(random.choice(string.ascii_lowercase) for i in range(8))}' + tmp_dir.name, f'tmp{"".join(random.choice(string.ascii_lowercase) for i in range(8))}' ) log.debug('Temporary file path: "%s"', tmp_file) - with open(tmp_file, 'wb') as file_desc: + with open(tmp_file, "wb") as file_desc: file_desc.write(test_content) log.debug( - 'Upload file %s to SFTP server (in %s)', tmp_file, - options.upload_path if options.upload_path else "remote initial connection directory") + "Upload file %s to SFTP server (in %s)", + tmp_file, + options.upload_path if options.upload_path else "remote initial connection directory", + ) if not sftp.upload_file(tmp_file, options.upload_path): - log.error('Fail to upload test file on SFTP server') + log.error("Fail to upload test file on SFTP server") sys.exit(1) - log.info('Test file uploaded on SFTP server') + log.info("Test file uploaded on SFTP server") remote_filepath = ( os.path.join(options.upload_path, os.path.basename(tmp_file)) - if options.upload_path else os.path.basename(tmp_file) + if options.upload_path + else os.path.basename(tmp_file) ) with tempfile.NamedTemporaryFile() as tmp_file2: - log.info('Retrieve test file to %s', tmp_file2.name) + log.info("Retrieve test file to %s", tmp_file2.name) if not sftp.get_file(remote_filepath, tmp_file2.name): - log.error('Fail to retrieve test file') + log.error("Fail to retrieve test file") else: - with open(tmp_file2.name, 'rb') as file_desc: + with open(tmp_file2.name, "rb") as file_desc: content = file_desc.read() - log.debug('Read content: %s', content) + log.debug("Read content: %s", content) if test_content == content: - log.info('Content file retrieved match with uploaded one') + log.info("Content file retrieved match with uploaded one") else: - log.error('Content file retrieved doest not match with uploaded one') + log.error("Content file retrieved doest not match with uploaded one") try: - log.info('Remotly open test file %s', remote_filepath) + log.info("Remotly open test file %s", remote_filepath) file_desc = sftp.open_file(remote_filepath) content = file_desc.read() - log.debug('Read content: %s', content) + log.debug("Read content: %s", content) if test_content == content: - log.info('Content of remote file match with uploaded one') + log.info("Content of remote file match with uploaded one") else: - log.error('Content of remote file doest not match with uploaded one') + log.error("Content of remote file doest not match with uploaded one") except Exception: # pylint: disable=broad-except - log.exception('An exception occurred remotly opening test file %s', remote_filepath) + log.exception("An exception occurred remotly opening test file %s", remote_filepath) if sftp.remove_file(remote_filepath): - log.info('Test file removed on SFTP server') + log.info("Test file removed on SFTP server") else: - log.error('Fail to remove test file on SFTP server') + log.error("Fail to remove test file on SFTP server") diff --git a/mylib/sftp.py b/mylib/sftp.py index b32912f..3b98567 100644 --- a/mylib/sftp.py +++ b/mylib/sftp.py @@ -1,17 +1,17 @@ -# -*- coding: utf-8 -*- - """ SFTP client """ import logging import os -from paramiko import SSHClient, AutoAddPolicy, SFTPAttributes +from paramiko import AutoAddPolicy, SFTPAttributes, SSHClient -from mylib.config import ConfigurableObject -from mylib.config import BooleanOption -from mylib.config import IntegerOption -from mylib.config import PasswordOption -from mylib.config import StringOption +from mylib.config import ( + BooleanOption, + ConfigurableObject, + IntegerOption, + PasswordOption, + StringOption, +) log = logging.getLogger(__name__) @@ -23,16 +23,16 @@ class SFTPClient(ConfigurableObject): This class abstract all interactions with the SFTP server. """ - _config_name = 'sftp' - _config_comment = 'SFTP' + _config_name = "sftp" + _config_comment = "SFTP" _defaults = { - 'host': 'localhost', - 'port': 22, - 'user': None, - 'password': None, - 'known_hosts': os.path.expanduser('~/.ssh/known_hosts'), - 'auto_add_unknown_host_key': False, - 'just_try': False, + "host": "localhost", + "port": 22, + "user": None, + "password": None, + "known_hosts": os.path.expanduser("~/.ssh/known_hosts"), + "auto_add_unknown_host_key": False, + "just_try": False, } ssh_client = None @@ -41,58 +41,77 @@ class SFTPClient(ConfigurableObject): # pylint: disable=arguments-differ,arguments-renamed def configure(self, just_try=True, **kwargs): - """ Configure options on registered mylib.Config object """ + """Configure options on registered mylib.Config object""" section = super().configure(**kwargs) section.add_option( - StringOption, 'host', default=self._defaults['host'], - comment='SFTP server hostname/IP address') + StringOption, + "host", + default=self._defaults["host"], + comment="SFTP server hostname/IP address", + ) section.add_option( - IntegerOption, 'port', default=self._defaults['port'], - comment='SFTP server port') + IntegerOption, "port", default=self._defaults["port"], comment="SFTP server port" + ) section.add_option( - StringOption, 'user', default=self._defaults['user'], - comment='SFTP authentication username') + StringOption, + "user", + default=self._defaults["user"], + comment="SFTP authentication username", + ) section.add_option( - PasswordOption, 'password', default=self._defaults['password'], + PasswordOption, + "password", + default=self._defaults["password"], comment='SFTP authentication password (set to "keyring" to use XDG keyring)', - username_option='user', keyring_value='keyring') + username_option="user", + keyring_value="keyring", + ) section.add_option( - StringOption, 'known_hosts', default=self._defaults['known_hosts'], - comment='SFTP known_hosts filepath') + StringOption, + "known_hosts", + default=self._defaults["known_hosts"], + comment="SFTP known_hosts filepath", + ) section.add_option( - BooleanOption, 'auto_add_unknown_host_key', - default=self._defaults['auto_add_unknown_host_key'], - comment='Auto add unknown host key') + BooleanOption, + "auto_add_unknown_host_key", + default=self._defaults["auto_add_unknown_host_key"], + comment="Auto add unknown host key", + ) if just_try: section.add_option( - BooleanOption, 'just_try', default=self._defaults['just_try'], - comment='Just-try mode: do not really make change on remote SFTP host') + BooleanOption, + "just_try", + default=self._defaults["just_try"], + comment="Just-try mode: do not really make change on remote SFTP host", + ) return section def initialize(self, loaded_config=None): - """ Configuration initialized hook """ + """Configuration initialized hook""" super().__init__(loaded_config=loaded_config) def connect(self): - """ Connect to SFTP server """ + """Connect to SFTP server""" if self.ssh_client: return - host = self._get_option('host') - port = self._get_option('port') + host = self._get_option("host") + port = self._get_option("port") log.info("Connect to SFTP server %s:%d", host, port) self.ssh_client = SSHClient() - if self._get_option('known_hosts'): - self.ssh_client.load_host_keys(self._get_option('known_hosts')) - if self._get_option('auto_add_unknown_host_key'): - log.debug('Set missing host key policy to auto-add') + if self._get_option("known_hosts"): + self.ssh_client.load_host_keys(self._get_option("known_hosts")) + if self._get_option("auto_add_unknown_host_key"): + log.debug("Set missing host key policy to auto-add") self.ssh_client.set_missing_host_key_policy(AutoAddPolicy()) self.ssh_client.connect( - host, port=port, - username=self._get_option('user'), - password=self._get_option('password') + host, + port=port, + username=self._get_option("user"), + password=self._get_option("password"), ) self.sftp_client = self.ssh_client.open_sftp() self.initial_directory = self.sftp_client.getcwd() @@ -103,43 +122,43 @@ class SFTPClient(ConfigurableObject): self.initial_directory = "" def get_file(self, remote_filepath, local_filepath): - """ Retrieve a file from SFTP server """ + """Retrieve a file from SFTP server""" self.connect() log.debug("Retreive file '%s' to '%s'", remote_filepath, local_filepath) return self.sftp_client.get(remote_filepath, local_filepath) is None - def open_file(self, remote_filepath, mode='r'): - """ Remotly open a file on SFTP server """ + def open_file(self, remote_filepath, mode="r"): + """Remotly open a file on SFTP server""" self.connect() log.debug("Remotly open file '%s'", remote_filepath) return self.sftp_client.open(remote_filepath, mode=mode) def upload_file(self, filepath, remote_directory=None): - """ Upload a file on SFTP server """ + """Upload a file on SFTP server""" self.connect() remote_filepath = os.path.join( remote_directory if remote_directory else self.initial_directory, - os.path.basename(filepath) + os.path.basename(filepath), ) log.debug("Upload file '%s' to '%s'", filepath, remote_filepath) - if self._get_option('just_try'): + if self._get_option("just_try"): log.debug( - "Just-try mode: do not really upload file '%s' to '%s'", - filepath, remote_filepath) + "Just-try mode: do not really upload file '%s' to '%s'", filepath, remote_filepath + ) return True result = self.sftp_client.put(filepath, remote_filepath) return isinstance(result, SFTPAttributes) def remove_file(self, filepath): - """ Remove a file on SFTP server """ + """Remove a file on SFTP server""" self.connect() log.debug("Remove file '%s'", filepath) - if self._get_option('just_try'): + if self._get_option("just_try"): log.debug("Just - try mode: do not really remove file '%s'", filepath) return True return self.sftp_client.remove(filepath) is None def close(self): - """ Close SSH/SFTP connection """ + """Close SSH/SFTP connection""" log.debug("Close connection") self.ssh_client.close() diff --git a/mylib/telltale.py b/mylib/telltale.py index b29148c..75e5c9c 100644 --- a/mylib/telltale.py +++ b/mylib/telltale.py @@ -8,45 +8,43 @@ log = logging.getLogger(__name__) class TelltaleFile: - """ Telltale file helper class """ + """Telltale file helper class""" def __init__(self, filepath=None, filename=None, dirpath=None): assert filepath or filename, "filename or filepath is required" if filepath: - assert not filename or os.path.basename(filepath) == filename, "filepath and filename does not match" - assert not dirpath or os.path.dirname(filepath) == dirpath, "filepath and dirpath does not match" + assert ( + not filename or os.path.basename(filepath) == filename + ), "filepath and filename does not match" + assert ( + not dirpath or os.path.dirname(filepath) == dirpath + ), "filepath and dirpath does not match" self.filename = filename if filename else os.path.basename(filepath) self.dirpath = ( - dirpath if dirpath - else ( - os.path.dirname(filepath) if filepath - else os.getcwd() - ) + dirpath if dirpath else (os.path.dirname(filepath) if filepath else os.getcwd()) ) self.filepath = filepath if filepath else os.path.join(self.dirpath, self.filename) @property def last_update(self): - """ Retreive last update datetime of the telltall file """ + """Retreive last update datetime of the telltall file""" try: - return datetime.datetime.fromtimestamp( - os.stat(self.filepath).st_mtime - ) + return datetime.datetime.fromtimestamp(os.stat(self.filepath).st_mtime) except FileNotFoundError: - log.info('Telltale file not found (%s)', self.filepath) + log.info("Telltale file not found (%s)", self.filepath) return None def update(self): - """ Update the telltale file """ - log.info('Update telltale file (%s)', self.filepath) + """Update the telltale file""" + log.info("Update telltale file (%s)", self.filepath) try: os.utime(self.filepath, None) except FileNotFoundError: # pylint: disable=consider-using-with - open(self.filepath, 'a', encoding="utf-8").close() + open(self.filepath, "a", encoding="utf-8").close() def remove(self): - """ Remove the telltale file """ + """Remove the telltale file""" try: os.remove(self.filepath) return True diff --git a/setup.cfg b/setup.cfg index a1cce8e..1899e22 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,3 @@ [flake8] ignore = E501,W503 +max-line-length = 100 diff --git a/setup.py b/setup.py index 42aaf7f..a900f98 100644 --- a/setup.py +++ b/setup.py @@ -1,76 +1,74 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- - -from setuptools import find_packages -from setuptools import setup +"""Setuptools script""" +from setuptools import find_packages, setup extras_require = { - 'dev': [ - 'pytest', - 'mocker', - 'pytest-mock', - 'pylint', - 'flake8', + "dev": [ + "pytest", + "mocker", + "pytest-mock", + "pylint == 2.15.10", + "pre-commit", ], - 'config': [ - 'argcomplete', - 'keyring', - 'systemd-python', + "config": [ + "argcomplete", + "keyring", + "systemd-python", ], - 'ldap': [ - 'python-ldap', - 'python-dateutil', - 'pytz', + "ldap": [ + "python-ldap", + "python-dateutil", + "pytz", ], - 'email': [ - 'mako', + "email": [ + "mako", ], - 'pgsql': [ - 'psycopg2', + "pgsql": [ + "psycopg2", ], - 'oracle': [ - 'cx_Oracle', + "oracle": [ + "cx_Oracle", ], - 'mysql': [ - 'mysqlclient', + "mysql": [ + "mysqlclient", ], - 'sftp': [ - 'paramiko', + "sftp": [ + "paramiko", ], } -install_requires = ['progressbar'] +install_requires = ["progressbar"] for extra, deps in extras_require.items(): - if extra != 'dev': + if extra != "dev": install_requires.extend(deps) -version = '0.1' +version = "0.1" setup( name="mylib", version=version, - description='A set of helpers small libs to make common tasks easier in my script development', + description="A set of helpers small libs to make common tasks easier in my script development", classifiers=[ - 'Programming Language :: Python', + "Programming Language :: Python", ], install_requires=install_requires, extras_require=extras_require, - author='Benjamin Renard', - author_email='brenard@zionetrix.net', - url='https://gogs.zionetrix.net/bn8/python-mylib', + author="Benjamin Renard", + author_email="brenard@zionetrix.net", + url="https://gogs.zionetrix.net/bn8/python-mylib", packages=find_packages(), include_package_data=True, zip_safe=False, entry_points={ - 'console_scripts': [ - 'mylib-test-email = mylib.scripts.email_test:main', - 'mylib-test-email-with-config = mylib.scripts.email_test_with_config:main', - 'mylib-test-map = mylib.scripts.map_test:main', - 'mylib-test-pbar = mylib.scripts.pbar_test:main', - 'mylib-test-report = mylib.scripts.report_test:main', - 'mylib-test-ldap = mylib.scripts.ldap_test:main', - 'mylib-test-sftp = mylib.scripts.sftp_test:main', + "console_scripts": [ + "mylib-test-email = mylib.scripts.email_test:main", + "mylib-test-email-with-config = mylib.scripts.email_test_with_config:main", + "mylib-test-map = mylib.scripts.map_test:main", + "mylib-test-pbar = mylib.scripts.pbar_test:main", + "mylib-test-report = mylib.scripts.report_test:main", + "mylib-test-ldap = mylib.scripts.ldap_test:main", + "mylib-test-sftp = mylib.scripts.sftp_test:main", ], }, ) diff --git a/tests.sh b/tests.sh index ac08b88..0deaf4b 100755 --- a/tests.sh +++ b/tests.sh @@ -22,18 +22,11 @@ echo "Install package with dev dependencies using pip..." $VENV/bin/python3 -m pip install -e ".[dev]" $QUIET_ARG RES=0 -# Run tests -$VENV/bin/python3 -m pytest tests -[ $? -ne 0 ] && RES=1 -# Run pylint -echo "Run pylint..." -$VENV/bin/pylint --extension-pkg-whitelist=cx_Oracle mylib tests -[ $? -ne 0 ] && RES=1 - -# Run flake8 -echo "Run flake8..." -$VENV/bin/flake8 mylib tests +# Run pre-commit +echo "Run pre-commit..." +source $VENV/bin/activate +pre-commit run --all-files [ $? -ne 0 ] && RES=1 # Clean temporary venv diff --git a/tests/test_config.py b/tests/test_config.py index 326eb96..7789775 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,31 +2,29 @@ # pylint: disable=global-variable-not-assigned """ Tests on config lib """ +import configparser import logging import os -import configparser import pytest -from mylib.config import Config, ConfigSection -from mylib.config import BooleanOption -from mylib.config import StringOption +from mylib.config import BooleanOption, Config, ConfigSection, StringOption runned = {} def test_config_init_default_args(): - appname = 'Test app' + appname = "Test app" config = Config(appname) assert config.appname == appname - assert config.version == '0.0' - assert config.encoding == 'utf-8' + assert config.version == "0.0" + assert config.encoding == "utf-8" def test_config_init_custom_args(): - appname = 'Test app' - version = '1.43' - encoding = 'ISO-8859-1' + appname = "Test app" + version = "1.43" + encoding = "ISO-8859-1" config = Config(appname, version=version, encoding=encoding) assert config.appname == appname assert config.version == version @@ -34,8 +32,8 @@ def test_config_init_custom_args(): def test_add_section_default_args(): - config = Config('Test app') - name = 'test_section' + config = Config("Test app") + name = "test_section" section = config.add_section(name) assert isinstance(section, ConfigSection) assert config.sections[name] == section @@ -45,9 +43,9 @@ def test_add_section_default_args(): def test_add_section_custom_args(): - config = Config('Test app') - name = 'test_section' - comment = 'Test' + config = Config("Test app") + name = "test_section" + comment = "Test" order = 20 section = config.add_section(name, comment=comment, order=order) assert isinstance(section, ConfigSection) @@ -57,47 +55,47 @@ def test_add_section_custom_args(): def test_add_section_with_callback(): - config = Config('Test app') - name = 'test_section' + config = Config("Test app") + name = "test_section" global runned - runned['test_add_section_with_callback'] = False + runned["test_add_section_with_callback"] = False def test_callback(loaded_config): global runned assert loaded_config == config - assert runned['test_add_section_with_callback'] is False - runned['test_add_section_with_callback'] = True + assert runned["test_add_section_with_callback"] is False + runned["test_add_section_with_callback"] = True section = config.add_section(name, loaded_callback=test_callback) assert isinstance(section, ConfigSection) assert test_callback in config._loaded_callbacks - assert runned['test_add_section_with_callback'] is False + assert runned["test_add_section_with_callback"] is False config.parse_arguments_options(argv=[], create=False) - assert runned['test_add_section_with_callback'] is True + assert runned["test_add_section_with_callback"] is True assert test_callback in config._loaded_callbacks_executed # Try to execute again to verify callback is not runned again config._loaded() def test_add_section_with_callback_already_loaded(): - config = Config('Test app') - name = 'test_section' + config = Config("Test app") + name = "test_section" config.parse_arguments_options(argv=[], create=False) global runned - runned['test_add_section_with_callback_already_loaded'] = False + runned["test_add_section_with_callback_already_loaded"] = False def test_callback(loaded_config): global runned assert loaded_config == config - assert runned['test_add_section_with_callback_already_loaded'] is False - runned['test_add_section_with_callback_already_loaded'] = True + assert runned["test_add_section_with_callback_already_loaded"] is False + runned["test_add_section_with_callback_already_loaded"] = True section = config.add_section(name, loaded_callback=test_callback) assert isinstance(section, ConfigSection) - assert runned['test_add_section_with_callback_already_loaded'] is True + assert runned["test_add_section_with_callback_already_loaded"] is True assert test_callback in config._loaded_callbacks assert test_callback in config._loaded_callbacks_executed # Try to execute again to verify callback is not runned again @@ -105,10 +103,10 @@ def test_add_section_with_callback_already_loaded(): def test_add_option_default_args(): - config = Config('Test app') - section = config.add_section('my_section') + config = Config("Test app") + section = config.add_section("my_section") assert isinstance(section, ConfigSection) - name = 'my_option' + name = "my_option" option = section.add_option(StringOption, name) assert isinstance(option, StringOption) assert name in section.options and section.options[name] == option @@ -124,17 +122,17 @@ def test_add_option_default_args(): def test_add_option_custom_args(): - config = Config('Test app') - section = config.add_section('my_section') + config = Config("Test app") + section = config.add_section("my_section") assert isinstance(section, ConfigSection) - name = 'my_option' + name = "my_option" kwargs = dict( - default='default value', - comment='my comment', + default="default value", + comment="my comment", no_arg=True, - arg='--my-option', - short_arg='-M', - arg_help='My help' + arg="--my-option", + short_arg="-M", + arg_help="My help", ) option = section.add_option(StringOption, name, **kwargs) assert isinstance(option, StringOption) @@ -148,12 +146,12 @@ def test_add_option_custom_args(): def test_defined(): - config = Config('Test app') - section_name = 'my_section' - opt_name = 'my_option' + config = Config("Test app") + section_name = "my_section" + opt_name = "my_option" assert not config.defined(section_name, opt_name) - section = config.add_section('my_section') + section = config.add_section("my_section") assert isinstance(section, ConfigSection) section.add_option(StringOption, opt_name) @@ -161,29 +159,29 @@ def test_defined(): def test_isset(): - config = Config('Test app') - section_name = 'my_section' - opt_name = 'my_option' + config = Config("Test app") + section_name = "my_section" + opt_name = "my_option" assert not config.isset(section_name, opt_name) - section = config.add_section('my_section') + section = config.add_section("my_section") assert isinstance(section, ConfigSection) option = section.add_option(StringOption, opt_name) assert not config.isset(section_name, opt_name) - config.parse_arguments_options(argv=[option.parser_argument_name, 'value'], create=False) + config.parse_arguments_options(argv=[option.parser_argument_name, "value"], create=False) assert config.isset(section_name, opt_name) def test_not_isset(): - config = Config('Test app') - section_name = 'my_section' - opt_name = 'my_option' + config = Config("Test app") + section_name = "my_section" + opt_name = "my_option" assert not config.isset(section_name, opt_name) - section = config.add_section('my_section') + section = config.add_section("my_section") assert isinstance(section, ConfigSection) section.add_option(StringOption, opt_name) @@ -195,11 +193,11 @@ def test_not_isset(): def test_get(): - config = Config('Test app') - section_name = 'my_section' - opt_name = 'my_option' - opt_value = 'value' - section = config.add_section('my_section') + config = Config("Test app") + section_name = "my_section" + opt_name = "my_option" + opt_value = "value" + section = config.add_section("my_section") option = section.add_option(StringOption, opt_name) config.parse_arguments_options(argv=[option.parser_argument_name, opt_value], create=False) @@ -207,11 +205,11 @@ def test_get(): def test_get_default(): - config = Config('Test app') - section_name = 'my_section' - opt_name = 'my_option' - opt_default_value = 'value' - section = config.add_section('my_section') + config = Config("Test app") + section_name = "my_section" + opt_name = "my_option" + opt_default_value = "value" + section = config.add_section("my_section") section.add_option(StringOption, opt_name, default=opt_default_value) config.parse_arguments_options(argv=[], create=False) @@ -219,8 +217,8 @@ def test_get_default(): def test_logging_splited_stdout_stderr(capsys): - config = Config('Test app') - config.parse_arguments_options(argv=['-C', '-v'], create=False) + config = Config("Test app") + config.parse_arguments_options(argv=["-C", "-v"], create=False) info_msg = "[info]" err_msg = "[error]" logging.getLogger().info(info_msg) @@ -239,9 +237,9 @@ def test_logging_splited_stdout_stderr(capsys): @pytest.fixture() def config_with_file(tmpdir): - config = Config('Test app') - config_dir = tmpdir.mkdir('config') - config_file = config_dir.join('config.ini') + config = Config("Test app") + config_dir = tmpdir.mkdir("config") + config_file = config_dir.join("config.ini") config.save(os.path.join(config_file.dirname, config_file.basename)) return config @@ -250,6 +248,7 @@ def generate_mock_input(expected_prompt, input_value): def mock_input(self, prompt): # pylint: disable=unused-argument assert prompt == expected_prompt return input_value + return mock_input @@ -257,10 +256,9 @@ def generate_mock_input(expected_prompt, input_value): def test_boolean_option_from_config(config_with_file): - section = config_with_file.add_section('test') + section = config_with_file.add_section("test") default = True - option = section.add_option( - BooleanOption, 'test_bool', default=default) + option = section.add_option(BooleanOption, "test_bool", default=default) config_with_file.save() option.set(not default) @@ -273,74 +271,76 @@ def test_boolean_option_from_config(config_with_file): def test_boolean_option_ask_value(mocker): - config = Config('Test app') - section = config.add_section('test') - name = 'test_bool' - option = section.add_option( - BooleanOption, name, default=True) + config = Config("Test app") + section = config.add_section("test") + name = "test_bool" + option = section.add_option(BooleanOption, name, default=True) mocker.patch( - 'mylib.config.BooleanOption._get_user_input', - generate_mock_input(f'{name}: [Y/n] ', 'y') + "mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "y") ) assert option.ask_value(set_it=False) is True mocker.patch( - 'mylib.config.BooleanOption._get_user_input', - generate_mock_input(f'{name}: [Y/n] ', 'Y') + "mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "Y") ) assert option.ask_value(set_it=False) is True mocker.patch( - 'mylib.config.BooleanOption._get_user_input', - generate_mock_input(f'{name}: [Y/n] ', '') + "mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "") ) assert option.ask_value(set_it=False) is True mocker.patch( - 'mylib.config.BooleanOption._get_user_input', - generate_mock_input(f'{name}: [Y/n] ', 'n') + "mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "n") ) assert option.ask_value(set_it=False) is False mocker.patch( - 'mylib.config.BooleanOption._get_user_input', - generate_mock_input(f'{name}: [Y/n] ', 'N') + "mylib.config.BooleanOption._get_user_input", generate_mock_input(f"{name}: [Y/n] ", "N") ) assert option.ask_value(set_it=False) is False def test_boolean_option_to_config(): - config = Config('Test app') - section = config.add_section('test') + config = Config("Test app") + section = config.add_section("test") default = True - option = section.add_option(BooleanOption, 'test_bool', default=default) - assert option.to_config(True) == 'true' - assert option.to_config(False) == 'false' + option = section.add_option(BooleanOption, "test_bool", default=default) + assert option.to_config(True) == "true" + assert option.to_config(False) == "false" def test_boolean_option_export_to_config(config_with_file): - section = config_with_file.add_section('test') - name = 'test_bool' - comment = 'Test boolean' + section = config_with_file.add_section("test") + name = "test_bool" + comment = "Test boolean" default = True - option = section.add_option( - BooleanOption, name, default=default, comment=comment) + option = section.add_option(BooleanOption, name, default=default, comment=comment) - assert option.export_to_config() == f"""# {comment} + assert ( + option.export_to_config() + == f"""# {comment} # Default: {str(default).lower()} # {name} = """ + ) option.set(not default) - assert option.export_to_config() == f"""# {comment} + assert ( + option.export_to_config() + == f"""# {comment} # Default: {str(default).lower()} {name} = {str(not default).lower()} """ + ) option.set(default) - assert option.export_to_config() == f"""# {comment} + assert ( + option.export_to_config() + == f"""# {comment} # Default: {str(default).lower()} # {name} = """ + ) diff --git a/tests/test_mysql.py b/tests/test_mysql.py index 889071f..9a1a643 100644 --- a/tests/test_mysql.py +++ b/tests/test_mysql.py @@ -2,16 +2,17 @@ """ Tests on opening hours helpers """ import pytest - from MySQLdb._exceptions import Error from mylib.mysql import MyDB class FakeMySQLdbCursor: - """ Fake MySQLdb cursor """ + """Fake MySQLdb cursor""" - def __init__(self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception): + def __init__( + self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception + ): self.expected_sql = expected_sql self.expected_params = expected_params self.expected_return = expected_return @@ -20,13 +21,25 @@ class FakeMySQLdbCursor: def execute(self, sql, params=None): if self.expected_exception: - raise Error(f'{self}.execute({sql}, {params}): expected exception') - if self.expected_just_try and not sql.lower().startswith('select '): - assert False, f'{self}.execute({sql}, {params}) may not be executed in just try mode' + raise Error(f"{self}.execute({sql}, {params}): expected exception") + if self.expected_just_try and not sql.lower().startswith("select "): + assert False, f"{self}.execute({sql}, {params}) may not be executed in just try mode" # pylint: disable=consider-using-f-string - assert sql == self.expected_sql, "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (self, sql, self.expected_sql) + assert ( + sql == self.expected_sql + ), "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % ( + self, + sql, + self.expected_sql, + ) # pylint: disable=consider-using-f-string - assert params == self.expected_params, "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (self, params, self.expected_params) + assert ( + params == self.expected_params + ), "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % ( + self, + params, + self.expected_params, + ) return self.expected_return @property @@ -39,21 +52,19 @@ class FakeMySQLdbCursor: def fetchall(self): if isinstance(self.expected_return, list): return ( - list(row.values()) - if isinstance(row, dict) else row - for row in self.expected_return + list(row.values()) if isinstance(row, dict) else row for row in self.expected_return ) return self.expected_return def __repr__(self): return ( - f'FakeMySQLdbCursor({self.expected_sql}, {self.expected_params}, ' - f'{self.expected_return}, {self.expected_just_try})' + f"FakeMySQLdbCursor({self.expected_sql}, {self.expected_params}, " + f"{self.expected_return}, {self.expected_just_try})" ) class FakeMySQLdb: - """ Fake MySQLdb connection """ + """Fake MySQLdb connection""" expected_sql = None expected_params = None @@ -63,11 +74,14 @@ class FakeMySQLdb: just_try = False def __init__(self, **kwargs): - allowed_kwargs = dict(db=str, user=str, passwd=(str, None), host=str, charset=str, use_unicode=bool) + allowed_kwargs = dict( + db=str, user=str, passwd=(str, None), host=str, charset=str, use_unicode=bool + ) for arg, value in kwargs.items(): assert arg in allowed_kwargs, f'Invalid arg {arg}="{value}"' - assert isinstance(value, allowed_kwargs[arg]), \ - f'Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})' + assert isinstance( + value, allowed_kwargs[arg] + ), f"Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})" setattr(self, arg, value) def close(self): @@ -75,9 +89,11 @@ class FakeMySQLdb: def cursor(self): return FakeMySQLdbCursor( - self.expected_sql, self.expected_params, - self.expected_return, self.expected_just_try or self.just_try, - self.expected_exception + self.expected_sql, + self.expected_params, + self.expected_return, + self.expected_just_try or self.just_try, + self.expected_exception, ) def commit(self): @@ -105,19 +121,19 @@ def fake_mysqldb_connect_just_try(**kwargs): @pytest.fixture def test_mydb(): - return MyDB('127.0.0.1', 'user', 'password', 'dbname') + return MyDB("127.0.0.1", "user", "password", "dbname") @pytest.fixture def fake_mydb(mocker): - mocker.patch('MySQLdb.connect', fake_mysqldb_connect) - return MyDB('127.0.0.1', 'user', 'password', 'dbname') + mocker.patch("MySQLdb.connect", fake_mysqldb_connect) + return MyDB("127.0.0.1", "user", "password", "dbname") @pytest.fixture def fake_just_try_mydb(mocker): - mocker.patch('MySQLdb.connect', fake_mysqldb_connect_just_try) - return MyDB('127.0.0.1', 'user', 'password', 'dbname', just_try=True) + mocker.patch("MySQLdb.connect", fake_mysqldb_connect_just_try) + return MyDB("127.0.0.1", "user", "password", "dbname", just_try=True) @pytest.fixture @@ -132,13 +148,22 @@ def fake_connected_just_try_mydb(fake_just_try_mydb): return fake_just_try_mydb -def generate_mock_args(expected_args=(), expected_kwargs={}, expected_return=True): # pylint: disable=dangerous-default-value +def generate_mock_args( + expected_args=(), expected_kwargs={}, expected_return=True +): # pylint: disable=dangerous-default-value def mock_args(*args, **kwargs): # pylint: disable=consider-using-f-string - assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (args, expected_args) + assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % ( + args, + expected_args, + ) # pylint: disable=consider-using-f-string - assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (kwargs, expected_kwargs) + assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % ( + kwargs, + expected_kwargs, + ) return expected_return + return mock_args @@ -146,13 +171,22 @@ def mock_doSQL_just_try(self, sql, params=None): # pylint: disable=unused-argum assert False, "doSQL() may not be executed in just try mode" -def generate_mock_doSQL(expected_sql, expected_params={}, expected_return=True): # pylint: disable=dangerous-default-value +def generate_mock_doSQL( + expected_sql, expected_params={}, expected_return=True +): # pylint: disable=dangerous-default-value def mock_doSQL(self, sql, params=None): # pylint: disable=unused-argument # pylint: disable=consider-using-f-string - assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (sql, expected_sql) + assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % ( + sql, + expected_sql, + ) # pylint: disable=consider-using-f-string - assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (params, expected_params) + assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % ( + params, + expected_params, + ) return expected_return + return mock_doSQL @@ -166,15 +200,11 @@ mock_doSelect_just_try = mock_doSQL_just_try def test_combine_params_with_to_add_parameter(): - assert MyDB._combine_params(dict(test1=1), dict(test2=2)) == dict( - test1=1, test2=2 - ) + assert MyDB._combine_params(dict(test1=1), dict(test2=2)) == dict(test1=1, test2=2) def test_combine_params_with_kargs(): - assert MyDB._combine_params(dict(test1=1), test2=2) == dict( - test1=1, test2=2 - ) + assert MyDB._combine_params(dict(test1=1), test2=2) == dict(test1=1, test2=2) def test_combine_params_with_kargs_and_to_add_parameter(): @@ -184,47 +214,40 @@ def test_combine_params_with_kargs_and_to_add_parameter(): def test_format_where_clauses_params_are_preserved(): - args = ('test = test', dict(test1=1)) + args = ("test = test", dict(test1=1)) assert MyDB._format_where_clauses(*args) == args def test_format_where_clauses_raw(): - assert MyDB._format_where_clauses('test = test') == (('test = test'), {}) + assert MyDB._format_where_clauses("test = test") == ("test = test", {}) def test_format_where_clauses_tuple_clause_with_params(): - where_clauses = ( - 'test1 = %(test1)s AND test2 = %(test2)s', - dict(test1=1, test2=2) - ) + where_clauses = ("test1 = %(test1)s AND test2 = %(test2)s", dict(test1=1, test2=2)) assert MyDB._format_where_clauses(where_clauses) == where_clauses def test_format_where_clauses_dict(): where_clauses = dict(test1=1, test2=2) assert MyDB._format_where_clauses(where_clauses) == ( - '`test1` = %(test1)s AND `test2` = %(test2)s', - where_clauses + "`test1` = %(test1)s AND `test2` = %(test2)s", + where_clauses, ) def test_format_where_clauses_combined_types(): - where_clauses = ( - 'test1 = 1', - ('test2 LIKE %(test2)s', dict(test2=2)), - dict(test3=3, test4=4) - ) + where_clauses = ("test1 = 1", ("test2 LIKE %(test2)s", dict(test2=2)), dict(test3=3, test4=4)) assert MyDB._format_where_clauses(where_clauses) == ( - 'test1 = 1 AND test2 LIKE %(test2)s AND `test3` = %(test3)s AND `test4` = %(test4)s', - dict(test2=2, test3=3, test4=4) + "test1 = 1 AND test2 LIKE %(test2)s AND `test3` = %(test3)s AND `test4` = %(test4)s", + dict(test2=2, test3=3, test4=4), ) def test_format_where_clauses_with_where_op(): where_clauses = dict(test1=1, test2=2) - assert MyDB._format_where_clauses(where_clauses, where_op='OR') == ( - '`test1` = %(test1)s OR `test2` = %(test2)s', - where_clauses + assert MyDB._format_where_clauses(where_clauses, where_op="OR") == ( + "`test1` = %(test1)s OR `test2` = %(test2)s", + where_clauses, ) @@ -232,8 +255,8 @@ def test_add_where_clauses(): sql = "SELECT * FROM table" where_clauses = dict(test1=1, test2=2) assert MyDB._add_where_clauses(sql, None, where_clauses) == ( - sql + ' WHERE `test1` = %(test1)s AND `test2` = %(test2)s', - where_clauses + sql + " WHERE `test1` = %(test1)s AND `test2` = %(test2)s", + where_clauses, ) @@ -242,106 +265,102 @@ def test_add_where_clauses_preserved_params(): where_clauses = dict(test1=1, test2=2) params = dict(fake1=1) assert MyDB._add_where_clauses(sql, params.copy(), where_clauses) == ( - sql + ' WHERE `test1` = %(test1)s AND `test2` = %(test2)s', - dict(**where_clauses, **params) + sql + " WHERE `test1` = %(test1)s AND `test2` = %(test2)s", + dict(**where_clauses, **params), ) def test_add_where_clauses_with_op(): sql = "SELECT * FROM table" - where_clauses = ('test1=1', 'test2=2') - assert MyDB._add_where_clauses(sql, None, where_clauses, where_op='OR') == ( - sql + ' WHERE test1=1 OR test2=2', - {} + where_clauses = ("test1=1", "test2=2") + assert MyDB._add_where_clauses(sql, None, where_clauses, where_op="OR") == ( + sql + " WHERE test1=1 OR test2=2", + {}, ) def test_add_where_clauses_with_duplicated_field(): sql = "UPDATE table SET test1=%(test1)s" - params = dict(test1='new_value') - where_clauses = dict(test1='where_value') + params = dict(test1="new_value") + where_clauses = dict(test1="where_value") assert MyDB._add_where_clauses(sql, params, where_clauses) == ( - sql + ' WHERE `test1` = %(test1_1)s', - dict(test1='new_value', test1_1='where_value') + sql + " WHERE `test1` = %(test1_1)s", + dict(test1="new_value", test1_1="where_value"), ) def test_quote_table_name(): - assert MyDB._quote_table_name("mytable") == '`mytable`' - assert MyDB._quote_table_name("myschema.mytable") == '`myschema`.`mytable`' + assert MyDB._quote_table_name("mytable") == "`mytable`" + assert MyDB._quote_table_name("myschema.mytable") == "`myschema`.`mytable`" def test_insert(mocker, test_mydb): values = dict(test1=1, test2=2) mocker.patch( - 'mylib.mysql.MyDB.doSQL', + "mylib.mysql.MyDB.doSQL", generate_mock_doSQL( - 'INSERT INTO `mytable` (`test1`, `test2`) VALUES (%(test1)s, %(test2)s)', - values - ) + "INSERT INTO `mytable` (`test1`, `test2`) VALUES (%(test1)s, %(test2)s)", values + ), ) - assert test_mydb.insert('mytable', values) + assert test_mydb.insert("mytable", values) def test_insert_just_try(mocker, test_mydb): - mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSQL_just_try) - assert test_mydb.insert('mytable', dict(test1=1, test2=2), just_try=True) + mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSQL_just_try) + assert test_mydb.insert("mytable", dict(test1=1, test2=2), just_try=True) def test_update(mocker, test_mydb): values = dict(test1=1, test2=2) where_clauses = dict(test3=3, test4=4) mocker.patch( - 'mylib.mysql.MyDB.doSQL', + "mylib.mysql.MyDB.doSQL", generate_mock_doSQL( - 'UPDATE `mytable` SET `test1` = %(test1)s, `test2` = %(test2)s WHERE `test3` = %(test3)s AND `test4` = %(test4)s', - dict(**values, **where_clauses) - ) + "UPDATE `mytable` SET `test1` = %(test1)s, `test2` = %(test2)s WHERE `test3` =" + " %(test3)s AND `test4` = %(test4)s", + dict(**values, **where_clauses), + ), ) - assert test_mydb.update('mytable', values, where_clauses) + assert test_mydb.update("mytable", values, where_clauses) def test_update_just_try(mocker, test_mydb): - mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSQL_just_try) - assert test_mydb.update('mytable', dict(test1=1, test2=2), None, just_try=True) + mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSQL_just_try) + assert test_mydb.update("mytable", dict(test1=1, test2=2), None, just_try=True) def test_delete(mocker, test_mydb): where_clauses = dict(test1=1, test2=2) mocker.patch( - 'mylib.mysql.MyDB.doSQL', + "mylib.mysql.MyDB.doSQL", generate_mock_doSQL( - 'DELETE FROM `mytable` WHERE `test1` = %(test1)s AND `test2` = %(test2)s', - where_clauses - ) + "DELETE FROM `mytable` WHERE `test1` = %(test1)s AND `test2` = %(test2)s", where_clauses + ), ) - assert test_mydb.delete('mytable', where_clauses) + assert test_mydb.delete("mytable", where_clauses) def test_delete_just_try(mocker, test_mydb): - mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSQL_just_try) - assert test_mydb.delete('mytable', None, just_try=True) + mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSQL_just_try) + assert test_mydb.delete("mytable", None, just_try=True) def test_truncate(mocker, test_mydb): - mocker.patch( - 'mylib.mysql.MyDB.doSQL', - generate_mock_doSQL('TRUNCATE TABLE `mytable`', None) - ) + mocker.patch("mylib.mysql.MyDB.doSQL", generate_mock_doSQL("TRUNCATE TABLE `mytable`", None)) - assert test_mydb.truncate('mytable') + assert test_mydb.truncate("mytable") def test_truncate_just_try(mocker, test_mydb): - mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSelect_just_try) - assert test_mydb.truncate('mytable', just_try=True) + mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSelect_just_try) + assert test_mydb.truncate("mytable", just_try=True) def test_select(mocker, test_mydb): - fields = ('field1', 'field2') + fields = ("field1", "field2") where_clauses = dict(test3=3, test4=4) expected_return = [ dict(field1=1, field2=2), @@ -349,30 +368,28 @@ def test_select(mocker, test_mydb): ] order_by = "field1, DESC" mocker.patch( - 'mylib.mysql.MyDB.doSelect', + "mylib.mysql.MyDB.doSelect", generate_mock_doSQL( - 'SELECT `field1`, `field2` FROM `mytable` WHERE `test3` = %(test3)s AND `test4` = %(test4)s ORDER BY ' + order_by, - where_clauses, expected_return - ) + "SELECT `field1`, `field2` FROM `mytable` WHERE `test3` = %(test3)s AND `test4` =" + " %(test4)s ORDER BY " + order_by, + where_clauses, + expected_return, + ), ) - assert test_mydb.select('mytable', where_clauses, fields, order_by=order_by) == expected_return + assert test_mydb.select("mytable", where_clauses, fields, order_by=order_by) == expected_return def test_select_without_field_and_order_by(mocker, test_mydb): - mocker.patch( - 'mylib.mysql.MyDB.doSelect', - generate_mock_doSQL( - 'SELECT * FROM `mytable`' - ) - ) + mocker.patch("mylib.mysql.MyDB.doSelect", generate_mock_doSQL("SELECT * FROM `mytable`")) - assert test_mydb.select('mytable') + assert test_mydb.select("mytable") def test_select_just_try(mocker, test_mydb): - mocker.patch('mylib.mysql.MyDB.doSQL', mock_doSelect_just_try) - assert test_mydb.select('mytable', None, None, just_try=True) + mocker.patch("mylib.mysql.MyDB.doSQL", mock_doSelect_just_try) + assert test_mydb.select("mytable", None, None, just_try=True) + # # Tests on main methods @@ -389,12 +406,7 @@ def test_connect(mocker, test_mydb): use_unicode=True, ) - mocker.patch( - 'MySQLdb.connect', - generate_mock_args( - expected_kwargs=expected_kwargs - ) - ) + mocker.patch("MySQLdb.connect", generate_mock_args(expected_kwargs=expected_kwargs)) assert test_mydb.connect() @@ -408,48 +420,61 @@ def test_close_connected(fake_connected_mydb): def test_doSQL(fake_connected_mydb): - fake_connected_mydb._conn.expected_sql = 'DELETE FROM table WHERE test1 = %(test1)s' + fake_connected_mydb._conn.expected_sql = "DELETE FROM table WHERE test1 = %(test1)s" fake_connected_mydb._conn.expected_params = dict(test1=1) - fake_connected_mydb.doSQL(fake_connected_mydb._conn.expected_sql, fake_connected_mydb._conn.expected_params) + fake_connected_mydb.doSQL( + fake_connected_mydb._conn.expected_sql, fake_connected_mydb._conn.expected_params + ) def test_doSQL_without_params(fake_connected_mydb): - fake_connected_mydb._conn.expected_sql = 'DELETE FROM table' + fake_connected_mydb._conn.expected_sql = "DELETE FROM table" fake_connected_mydb.doSQL(fake_connected_mydb._conn.expected_sql) def test_doSQL_just_try(fake_connected_just_try_mydb): - assert fake_connected_just_try_mydb.doSQL('DELETE FROM table') + assert fake_connected_just_try_mydb.doSQL("DELETE FROM table") def test_doSQL_on_exception(fake_connected_mydb): fake_connected_mydb._conn.expected_exception = True - assert fake_connected_mydb.doSQL('DELETE FROM table') is False + assert fake_connected_mydb.doSQL("DELETE FROM table") is False def test_doSelect(fake_connected_mydb): - fake_connected_mydb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = %(test1)s' + fake_connected_mydb._conn.expected_sql = "SELECT * FROM table WHERE test1 = %(test1)s" fake_connected_mydb._conn.expected_params = dict(test1=1) fake_connected_mydb._conn.expected_return = [dict(test1=1)] - assert fake_connected_mydb.doSelect(fake_connected_mydb._conn.expected_sql, fake_connected_mydb._conn.expected_params) == fake_connected_mydb._conn.expected_return + assert ( + fake_connected_mydb.doSelect( + fake_connected_mydb._conn.expected_sql, fake_connected_mydb._conn.expected_params + ) + == fake_connected_mydb._conn.expected_return + ) def test_doSelect_without_params(fake_connected_mydb): - fake_connected_mydb._conn.expected_sql = 'SELECT * FROM table' + fake_connected_mydb._conn.expected_sql = "SELECT * FROM table" fake_connected_mydb._conn.expected_return = [dict(test1=1)] - assert fake_connected_mydb.doSelect(fake_connected_mydb._conn.expected_sql) == fake_connected_mydb._conn.expected_return + assert ( + fake_connected_mydb.doSelect(fake_connected_mydb._conn.expected_sql) + == fake_connected_mydb._conn.expected_return + ) def test_doSelect_on_exception(fake_connected_mydb): fake_connected_mydb._conn.expected_exception = True - assert fake_connected_mydb.doSelect('SELECT * FROM table') is False + assert fake_connected_mydb.doSelect("SELECT * FROM table") is False def test_doSelect_just_try(fake_connected_just_try_mydb): - fake_connected_just_try_mydb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = %(test1)s' + fake_connected_just_try_mydb._conn.expected_sql = "SELECT * FROM table WHERE test1 = %(test1)s" fake_connected_just_try_mydb._conn.expected_params = dict(test1=1) fake_connected_just_try_mydb._conn.expected_return = [dict(test1=1)] - assert fake_connected_just_try_mydb.doSelect( - fake_connected_just_try_mydb._conn.expected_sql, - fake_connected_just_try_mydb._conn.expected_params - ) == fake_connected_just_try_mydb._conn.expected_return + assert ( + fake_connected_just_try_mydb.doSelect( + fake_connected_just_try_mydb._conn.expected_sql, + fake_connected_just_try_mydb._conn.expected_params, + ) + == fake_connected_just_try_mydb._conn.expected_return + ) diff --git a/tests/test_opening_hours.py b/tests/test_opening_hours.py index 1311cce..7e578d3 100644 --- a/tests/test_opening_hours.py +++ b/tests/test_opening_hours.py @@ -2,6 +2,7 @@ """ Tests on opening hours helpers """ import datetime + import pytest from mylib import opening_hours @@ -12,14 +13,16 @@ from mylib import opening_hours def test_parse_exceptional_closures_one_day_without_time_period(): - assert opening_hours.parse_exceptional_closures(["22/09/2017"]) == [{'days': [datetime.date(2017, 9, 22)], 'hours_periods': []}] + assert opening_hours.parse_exceptional_closures(["22/09/2017"]) == [ + {"days": [datetime.date(2017, 9, 22)], "hours_periods": []} + ] def test_parse_exceptional_closures_one_day_with_time_period(): assert opening_hours.parse_exceptional_closures(["26/11/2017 9h30-12h30"]) == [ { - 'days': [datetime.date(2017, 11, 26)], - 'hours_periods': [{'start': datetime.time(9, 30), 'stop': datetime.time(12, 30)}] + "days": [datetime.date(2017, 11, 26)], + "hours_periods": [{"start": datetime.time(9, 30), "stop": datetime.time(12, 30)}], } ] @@ -27,11 +30,11 @@ def test_parse_exceptional_closures_one_day_with_time_period(): def test_parse_exceptional_closures_one_day_with_multiple_time_periods(): assert opening_hours.parse_exceptional_closures(["26/11/2017 9h30-12h30 14h-18h"]) == [ { - 'days': [datetime.date(2017, 11, 26)], - 'hours_periods': [ - {'start': datetime.time(9, 30), 'stop': datetime.time(12, 30)}, - {'start': datetime.time(14, 0), 'stop': datetime.time(18, 0)}, - ] + "days": [datetime.date(2017, 11, 26)], + "hours_periods": [ + {"start": datetime.time(9, 30), "stop": datetime.time(12, 30)}, + {"start": datetime.time(14, 0), "stop": datetime.time(18, 0)}, + ], } ] @@ -39,8 +42,12 @@ def test_parse_exceptional_closures_one_day_with_multiple_time_periods(): def test_parse_exceptional_closures_full_days_period(): assert opening_hours.parse_exceptional_closures(["20/09/2017-22/09/2017"]) == [ { - 'days': [datetime.date(2017, 9, 20), datetime.date(2017, 9, 21), datetime.date(2017, 9, 22)], - 'hours_periods': [] + "days": [ + datetime.date(2017, 9, 20), + datetime.date(2017, 9, 21), + datetime.date(2017, 9, 22), + ], + "hours_periods": [], } ] @@ -53,8 +60,12 @@ def test_parse_exceptional_closures_invalid_days_period(): def test_parse_exceptional_closures_days_period_with_time_period(): assert opening_hours.parse_exceptional_closures(["20/09/2017-22/09/2017 9h-12h"]) == [ { - 'days': [datetime.date(2017, 9, 20), datetime.date(2017, 9, 21), datetime.date(2017, 9, 22)], - 'hours_periods': [{'start': datetime.time(9, 0), 'stop': datetime.time(12, 0)}] + "days": [ + datetime.date(2017, 9, 20), + datetime.date(2017, 9, 21), + datetime.date(2017, 9, 22), + ], + "hours_periods": [{"start": datetime.time(9, 0), "stop": datetime.time(12, 0)}], } ] @@ -70,31 +81,38 @@ def test_parse_exceptional_closures_invalid_time_period(): def test_parse_exceptional_closures_multiple_periods(): - assert opening_hours.parse_exceptional_closures(["20/09/2017 25/11/2017-26/11/2017 9h30-12h30 14h-18h"]) == [ + assert opening_hours.parse_exceptional_closures( + ["20/09/2017 25/11/2017-26/11/2017 9h30-12h30 14h-18h"] + ) == [ { - 'days': [ + "days": [ datetime.date(2017, 9, 20), datetime.date(2017, 11, 25), datetime.date(2017, 11, 26), ], - 'hours_periods': [ - {'start': datetime.time(9, 30), 'stop': datetime.time(12, 30)}, - {'start': datetime.time(14, 0), 'stop': datetime.time(18, 0)}, - ] + "hours_periods": [ + {"start": datetime.time(9, 30), "stop": datetime.time(12, 30)}, + {"start": datetime.time(14, 0), "stop": datetime.time(18, 0)}, + ], } ] + # # Tests on parse_normal_opening_hours() # def test_parse_normal_opening_hours_one_day(): - assert opening_hours.parse_normal_opening_hours(["jeudi"]) == [{'days': ["jeudi"], 'hours_periods': []}] + assert opening_hours.parse_normal_opening_hours(["jeudi"]) == [ + {"days": ["jeudi"], "hours_periods": []} + ] def test_parse_normal_opening_hours_multiple_days(): - assert opening_hours.parse_normal_opening_hours(["lundi jeudi"]) == [{'days': ["lundi", "jeudi"], 'hours_periods': []}] + assert opening_hours.parse_normal_opening_hours(["lundi jeudi"]) == [ + {"days": ["lundi", "jeudi"], "hours_periods": []} + ] def test_parse_normal_opening_hours_invalid_day(): @@ -104,13 +122,17 @@ def test_parse_normal_opening_hours_invalid_day(): def test_parse_normal_opening_hours_one_days_period(): assert opening_hours.parse_normal_opening_hours(["lundi-jeudi"]) == [ - {'days': ["lundi", "mardi", "mercredi", "jeudi"], 'hours_periods': []} + {"days": ["lundi", "mardi", "mercredi", "jeudi"], "hours_periods": []} ] def test_parse_normal_opening_hours_one_day_with_one_time_period(): assert opening_hours.parse_normal_opening_hours(["jeudi 9h-12h"]) == [ - {'days': ["jeudi"], 'hours_periods': [{'start': datetime.time(9, 0), 'stop': datetime.time(12, 0)}]}] + { + "days": ["jeudi"], + "hours_periods": [{"start": datetime.time(9, 0), "stop": datetime.time(12, 0)}], + } + ] def test_parse_normal_opening_hours_invalid_days_period(): @@ -122,7 +144,10 @@ def test_parse_normal_opening_hours_invalid_days_period(): def test_parse_normal_opening_hours_one_time_period(): assert opening_hours.parse_normal_opening_hours(["9h-18h30"]) == [ - {'days': [], 'hours_periods': [{'start': datetime.time(9, 0), 'stop': datetime.time(18, 30)}]} + { + "days": [], + "hours_periods": [{"start": datetime.time(9, 0), "stop": datetime.time(18, 30)}], + } ] @@ -132,48 +157,60 @@ def test_parse_normal_opening_hours_invalid_time_period(): def test_parse_normal_opening_hours_multiple_periods(): - assert opening_hours.parse_normal_opening_hours(["lundi-vendredi 9h30-12h30 14h-18h", "samedi 9h30-18h", "dimanche 9h30-12h"]) == [ + assert opening_hours.parse_normal_opening_hours( + ["lundi-vendredi 9h30-12h30 14h-18h", "samedi 9h30-18h", "dimanche 9h30-12h"] + ) == [ { - 'days': ['lundi', 'mardi', 'mercredi', 'jeudi', 'vendredi'], - 'hours_periods': [ - {'start': datetime.time(9, 30), 'stop': datetime.time(12, 30)}, - {'start': datetime.time(14, 0), 'stop': datetime.time(18, 0)}, - ] + "days": ["lundi", "mardi", "mercredi", "jeudi", "vendredi"], + "hours_periods": [ + {"start": datetime.time(9, 30), "stop": datetime.time(12, 30)}, + {"start": datetime.time(14, 0), "stop": datetime.time(18, 0)}, + ], }, { - 'days': ['samedi'], - 'hours_periods': [ - {'start': datetime.time(9, 30), 'stop': datetime.time(18, 0)}, - ] + "days": ["samedi"], + "hours_periods": [ + {"start": datetime.time(9, 30), "stop": datetime.time(18, 0)}, + ], }, { - 'days': ['dimanche'], - 'hours_periods': [ - {'start': datetime.time(9, 30), 'stop': datetime.time(12, 0)}, - ] + "days": ["dimanche"], + "hours_periods": [ + {"start": datetime.time(9, 30), "stop": datetime.time(12, 0)}, + ], }, ] + # # Tests on is_closed # -exceptional_closures = ["22/09/2017", "20/09/2017-22/09/2017", "20/09/2017-22/09/2017 18/09/2017", "25/11/2017", "26/11/2017 9h30-12h30"] -normal_opening_hours = ["lundi-mardi jeudi 9h30-12h30 14h-16h30", "mercredi vendredi 9h30-12h30 14h-17h"] +exceptional_closures = [ + "22/09/2017", + "20/09/2017-22/09/2017", + "20/09/2017-22/09/2017 18/09/2017", + "25/11/2017", + "26/11/2017 9h30-12h30", +] +normal_opening_hours = [ + "lundi-mardi jeudi 9h30-12h30 14h-16h30", + "mercredi vendredi 9h30-12h30 14h-17h", +] nonworking_public_holidays = [ - '1janvier', - 'paques', - 'lundi_paques', - '1mai', - '8mai', - 'jeudi_ascension', - 'lundi_pentecote', - '14juillet', - '15aout', - '1novembre', - '11novembre', - 'noel', + "1janvier", + "paques", + "lundi_paques", + "1mai", + "8mai", + "jeudi_ascension", + "lundi_pentecote", + "14juillet", + "15aout", + "1novembre", + "11novembre", + "noel", ] @@ -182,12 +219,8 @@ def test_is_closed_when_normaly_closed_by_hour(): normal_opening_hours_values=normal_opening_hours, exceptional_closures_values=exceptional_closures, nonworking_public_holidays_values=nonworking_public_holidays, - when=datetime.datetime(2017, 5, 1, 20, 15) - ) == { - 'closed': True, - 'exceptional_closure': False, - 'exceptional_closure_all_day': False - } + when=datetime.datetime(2017, 5, 1, 20, 15), + ) == {"closed": True, "exceptional_closure": False, "exceptional_closure_all_day": False} def test_is_closed_on_exceptional_closure_full_day(): @@ -195,12 +228,8 @@ def test_is_closed_on_exceptional_closure_full_day(): normal_opening_hours_values=normal_opening_hours, exceptional_closures_values=exceptional_closures, nonworking_public_holidays_values=nonworking_public_holidays, - when=datetime.datetime(2017, 9, 22, 14, 15) - ) == { - 'closed': True, - 'exceptional_closure': True, - 'exceptional_closure_all_day': True - } + when=datetime.datetime(2017, 9, 22, 14, 15), + ) == {"closed": True, "exceptional_closure": True, "exceptional_closure_all_day": True} def test_is_closed_on_exceptional_closure_day(): @@ -208,12 +237,8 @@ def test_is_closed_on_exceptional_closure_day(): normal_opening_hours_values=normal_opening_hours, exceptional_closures_values=exceptional_closures, nonworking_public_holidays_values=nonworking_public_holidays, - when=datetime.datetime(2017, 11, 26, 10, 30) - ) == { - 'closed': True, - 'exceptional_closure': True, - 'exceptional_closure_all_day': False - } + when=datetime.datetime(2017, 11, 26, 10, 30), + ) == {"closed": True, "exceptional_closure": True, "exceptional_closure_all_day": False} def test_is_closed_on_nonworking_public_holidays(): @@ -221,12 +246,8 @@ def test_is_closed_on_nonworking_public_holidays(): normal_opening_hours_values=normal_opening_hours, exceptional_closures_values=exceptional_closures, nonworking_public_holidays_values=nonworking_public_holidays, - when=datetime.datetime(2017, 1, 1, 10, 30) - ) == { - 'closed': True, - 'exceptional_closure': False, - 'exceptional_closure_all_day': False - } + when=datetime.datetime(2017, 1, 1, 10, 30), + ) == {"closed": True, "exceptional_closure": False, "exceptional_closure_all_day": False} def test_is_closed_when_normaly_closed_by_day(): @@ -234,12 +255,8 @@ def test_is_closed_when_normaly_closed_by_day(): normal_opening_hours_values=normal_opening_hours, exceptional_closures_values=exceptional_closures, nonworking_public_holidays_values=nonworking_public_holidays, - when=datetime.datetime(2017, 5, 6, 14, 15) - ) == { - 'closed': True, - 'exceptional_closure': False, - 'exceptional_closure_all_day': False - } + when=datetime.datetime(2017, 5, 6, 14, 15), + ) == {"closed": True, "exceptional_closure": False, "exceptional_closure_all_day": False} def test_is_closed_when_normaly_opened(): @@ -247,12 +264,8 @@ def test_is_closed_when_normaly_opened(): normal_opening_hours_values=normal_opening_hours, exceptional_closures_values=exceptional_closures, nonworking_public_holidays_values=nonworking_public_holidays, - when=datetime.datetime(2017, 5, 2, 15, 15) - ) == { - 'closed': False, - 'exceptional_closure': False, - 'exceptional_closure_all_day': False - } + when=datetime.datetime(2017, 5, 2, 15, 15), + ) == {"closed": False, "exceptional_closure": False, "exceptional_closure_all_day": False} def test_easter_date(): @@ -272,18 +285,18 @@ def test_easter_date(): def test_nonworking_french_public_days_of_the_year(): assert opening_hours.nonworking_french_public_days_of_the_year(2021) == { - '1janvier': datetime.date(2021, 1, 1), - 'paques': datetime.date(2021, 4, 4), - 'lundi_paques': datetime.date(2021, 4, 5), - '1mai': datetime.date(2021, 5, 1), - '8mai': datetime.date(2021, 5, 8), - 'jeudi_ascension': datetime.date(2021, 5, 13), - 'pentecote': datetime.date(2021, 5, 23), - 'lundi_pentecote': datetime.date(2021, 5, 24), - '14juillet': datetime.date(2021, 7, 14), - '15aout': datetime.date(2021, 8, 15), - '1novembre': datetime.date(2021, 11, 1), - '11novembre': datetime.date(2021, 11, 11), - 'noel': datetime.date(2021, 12, 25), - 'saint_etienne': datetime.date(2021, 12, 26) + "1janvier": datetime.date(2021, 1, 1), + "paques": datetime.date(2021, 4, 4), + "lundi_paques": datetime.date(2021, 4, 5), + "1mai": datetime.date(2021, 5, 1), + "8mai": datetime.date(2021, 5, 8), + "jeudi_ascension": datetime.date(2021, 5, 13), + "pentecote": datetime.date(2021, 5, 23), + "lundi_pentecote": datetime.date(2021, 5, 24), + "14juillet": datetime.date(2021, 7, 14), + "15aout": datetime.date(2021, 8, 15), + "1novembre": datetime.date(2021, 11, 1), + "11novembre": datetime.date(2021, 11, 11), + "noel": datetime.date(2021, 12, 25), + "saint_etienne": datetime.date(2021, 12, 26), } diff --git a/tests/test_oracle.py b/tests/test_oracle.py index 1d93759..1e1396a 100644 --- a/tests/test_oracle.py +++ b/tests/test_oracle.py @@ -8,9 +8,11 @@ from mylib.oracle import OracleDB class FakeCXOracleCursor: - """ Fake cx_Oracle cursor """ + """Fake cx_Oracle cursor""" - def __init__(self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception): + def __init__( + self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception + ): self.expected_sql = expected_sql self.expected_params = expected_params self.expected_return = expected_return @@ -21,13 +23,25 @@ class FakeCXOracleCursor: def execute(self, sql, **params): assert self.opened if self.expected_exception: - raise cx_Oracle.Error(f'{self}.execute({sql}, {params}): expected exception') - if self.expected_just_try and not sql.lower().startswith('select '): - assert False, f'{self}.execute({sql}, {params}) may not be executed in just try mode' + raise cx_Oracle.Error(f"{self}.execute({sql}, {params}): expected exception") + if self.expected_just_try and not sql.lower().startswith("select "): + assert False, f"{self}.execute({sql}, {params}) may not be executed in just try mode" # pylint: disable=consider-using-f-string - assert sql == self.expected_sql, "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (self, sql, self.expected_sql) + assert ( + sql == self.expected_sql + ), "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % ( + self, + sql, + self.expected_sql, + ) # pylint: disable=consider-using-f-string - assert params == self.expected_params, "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (self, params, self.expected_params) + assert ( + params == self.expected_params + ), "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % ( + self, + params, + self.expected_params, + ) return self.expected_return def fetchall(self): @@ -43,13 +57,13 @@ class FakeCXOracleCursor: def __repr__(self): return ( - f'FakeCXOracleCursor({self.expected_sql}, {self.expected_params}, ' - f'{self.expected_return}, {self.expected_just_try})' + f"FakeCXOracleCursor({self.expected_sql}, {self.expected_params}, " + f"{self.expected_return}, {self.expected_just_try})" ) class FakeCXOracle: - """ Fake cx_Oracle connection """ + """Fake cx_Oracle connection""" expected_sql = None expected_params = {} @@ -62,7 +76,9 @@ class FakeCXOracle: allowed_kwargs = dict(dsn=str, user=str, password=(str, None)) for arg, value in kwargs.items(): assert arg in allowed_kwargs, f"Invalid arg {arg}='{value}'" - assert isinstance(value, allowed_kwargs[arg]), f"Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})" + assert isinstance( + value, allowed_kwargs[arg] + ), f"Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})" setattr(self, arg, value) def close(self): @@ -70,9 +86,11 @@ class FakeCXOracle: def cursor(self): return FakeCXOracleCursor( - self.expected_sql, self.expected_params, - self.expected_return, self.expected_just_try or self.just_try, - self.expected_exception + self.expected_sql, + self.expected_params, + self.expected_return, + self.expected_just_try or self.just_try, + self.expected_exception, ) def commit(self): @@ -100,19 +118,19 @@ def fake_cxoracle_connect_just_try(**kwargs): @pytest.fixture def test_oracledb(): - return OracleDB('127.0.0.1/dbname', 'user', 'password') + return OracleDB("127.0.0.1/dbname", "user", "password") @pytest.fixture def fake_oracledb(mocker): - mocker.patch('cx_Oracle.connect', fake_cxoracle_connect) - return OracleDB('127.0.0.1/dbname', 'user', 'password') + mocker.patch("cx_Oracle.connect", fake_cxoracle_connect) + return OracleDB("127.0.0.1/dbname", "user", "password") @pytest.fixture def fake_just_try_oracledb(mocker): - mocker.patch('cx_Oracle.connect', fake_cxoracle_connect_just_try) - return OracleDB('127.0.0.1/dbname', 'user', 'password', just_try=True) + mocker.patch("cx_Oracle.connect", fake_cxoracle_connect_just_try) + return OracleDB("127.0.0.1/dbname", "user", "password", just_try=True) @pytest.fixture @@ -127,13 +145,22 @@ def fake_connected_just_try_oracledb(fake_just_try_oracledb): return fake_just_try_oracledb -def generate_mock_args(expected_args=(), expected_kwargs={}, expected_return=True): # pylint: disable=dangerous-default-value +def generate_mock_args( + expected_args=(), expected_kwargs={}, expected_return=True +): # pylint: disable=dangerous-default-value def mock_args(*args, **kwargs): # pylint: disable=consider-using-f-string - assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (args, expected_args) + assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % ( + args, + expected_args, + ) # pylint: disable=consider-using-f-string - assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (kwargs, expected_kwargs) + assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % ( + kwargs, + expected_kwargs, + ) return expected_return + return mock_args @@ -141,13 +168,22 @@ def mock_doSQL_just_try(self, sql, params=None): # pylint: disable=unused-argum assert False, "doSQL() may not be executed in just try mode" -def generate_mock_doSQL(expected_sql, expected_params={}, expected_return=True): # pylint: disable=dangerous-default-value +def generate_mock_doSQL( + expected_sql, expected_params={}, expected_return=True +): # pylint: disable=dangerous-default-value def mock_doSQL(self, sql, params=None): # pylint: disable=unused-argument # pylint: disable=consider-using-f-string - assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (sql, expected_sql) + assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % ( + sql, + expected_sql, + ) # pylint: disable=consider-using-f-string - assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (params, expected_params) + assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % ( + params, + expected_params, + ) return expected_return + return mock_doSQL @@ -161,15 +197,11 @@ mock_doSelect_just_try = mock_doSQL_just_try def test_combine_params_with_to_add_parameter(): - assert OracleDB._combine_params(dict(test1=1), dict(test2=2)) == dict( - test1=1, test2=2 - ) + assert OracleDB._combine_params(dict(test1=1), dict(test2=2)) == dict(test1=1, test2=2) def test_combine_params_with_kargs(): - assert OracleDB._combine_params(dict(test1=1), test2=2) == dict( - test1=1, test2=2 - ) + assert OracleDB._combine_params(dict(test1=1), test2=2) == dict(test1=1, test2=2) def test_combine_params_with_kargs_and_to_add_parameter(): @@ -179,19 +211,16 @@ def test_combine_params_with_kargs_and_to_add_parameter(): def test_format_where_clauses_params_are_preserved(): - args = ('test = test', dict(test1=1)) + args = ("test = test", dict(test1=1)) assert OracleDB._format_where_clauses(*args) == args def test_format_where_clauses_raw(): - assert OracleDB._format_where_clauses('test = test') == (('test = test'), {}) + assert OracleDB._format_where_clauses("test = test") == ("test = test", {}) def test_format_where_clauses_tuple_clause_with_params(): - where_clauses = ( - 'test1 = :test1 AND test2 = :test2', - dict(test1=1, test2=2) - ) + where_clauses = ("test1 = :test1 AND test2 = :test2", dict(test1=1, test2=2)) assert OracleDB._format_where_clauses(where_clauses) == where_clauses @@ -199,27 +228,23 @@ def test_format_where_clauses_dict(): where_clauses = dict(test1=1, test2=2) assert OracleDB._format_where_clauses(where_clauses) == ( '"test1" = :test1 AND "test2" = :test2', - where_clauses + where_clauses, ) def test_format_where_clauses_combined_types(): - where_clauses = ( - 'test1 = 1', - ('test2 LIKE :test2', dict(test2=2)), - dict(test3=3, test4=4) - ) + where_clauses = ("test1 = 1", ("test2 LIKE :test2", dict(test2=2)), dict(test3=3, test4=4)) assert OracleDB._format_where_clauses(where_clauses) == ( 'test1 = 1 AND test2 LIKE :test2 AND "test3" = :test3 AND "test4" = :test4', - dict(test2=2, test3=3, test4=4) + dict(test2=2, test3=3, test4=4), ) def test_format_where_clauses_with_where_op(): where_clauses = dict(test1=1, test2=2) - assert OracleDB._format_where_clauses(where_clauses, where_op='OR') == ( + assert OracleDB._format_where_clauses(where_clauses, where_op="OR") == ( '"test1" = :test1 OR "test2" = :test2', - where_clauses + where_clauses, ) @@ -228,7 +253,7 @@ def test_add_where_clauses(): where_clauses = dict(test1=1, test2=2) assert OracleDB._add_where_clauses(sql, None, where_clauses) == ( sql + ' WHERE "test1" = :test1 AND "test2" = :test2', - where_clauses + where_clauses, ) @@ -238,26 +263,26 @@ def test_add_where_clauses_preserved_params(): params = dict(fake1=1) assert OracleDB._add_where_clauses(sql, params.copy(), where_clauses) == ( sql + ' WHERE "test1" = :test1 AND "test2" = :test2', - dict(**where_clauses, **params) + dict(**where_clauses, **params), ) def test_add_where_clauses_with_op(): sql = "SELECT * FROM table" - where_clauses = ('test1=1', 'test2=2') - assert OracleDB._add_where_clauses(sql, None, where_clauses, where_op='OR') == ( - sql + ' WHERE test1=1 OR test2=2', - {} + where_clauses = ("test1=1", "test2=2") + assert OracleDB._add_where_clauses(sql, None, where_clauses, where_op="OR") == ( + sql + " WHERE test1=1 OR test2=2", + {}, ) def test_add_where_clauses_with_duplicated_field(): sql = "UPDATE table SET test1=:test1" - params = dict(test1='new_value') - where_clauses = dict(test1='where_value') + params = dict(test1="new_value") + where_clauses = dict(test1="where_value") assert OracleDB._add_where_clauses(sql, params, where_clauses) == ( sql + ' WHERE "test1" = :test1_1', - dict(test1='new_value', test1_1='where_value') + dict(test1="new_value", test1_1="where_value"), ) @@ -269,74 +294,72 @@ def test_quote_table_name(): def test_insert(mocker, test_oracledb): values = dict(test1=1, test2=2) mocker.patch( - 'mylib.oracle.OracleDB.doSQL', + "mylib.oracle.OracleDB.doSQL", generate_mock_doSQL( - 'INSERT INTO "mytable" ("test1", "test2") VALUES (:test1, :test2)', - values - ) + 'INSERT INTO "mytable" ("test1", "test2") VALUES (:test1, :test2)', values + ), ) - assert test_oracledb.insert('mytable', values) + assert test_oracledb.insert("mytable", values) def test_insert_just_try(mocker, test_oracledb): - mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSQL_just_try) - assert test_oracledb.insert('mytable', dict(test1=1, test2=2), just_try=True) + mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSQL_just_try) + assert test_oracledb.insert("mytable", dict(test1=1, test2=2), just_try=True) def test_update(mocker, test_oracledb): values = dict(test1=1, test2=2) where_clauses = dict(test3=3, test4=4) mocker.patch( - 'mylib.oracle.OracleDB.doSQL', + "mylib.oracle.OracleDB.doSQL", generate_mock_doSQL( - 'UPDATE "mytable" SET "test1" = :test1, "test2" = :test2 WHERE "test3" = :test3 AND "test4" = :test4', - dict(**values, **where_clauses) - ) + 'UPDATE "mytable" SET "test1" = :test1, "test2" = :test2 WHERE "test3" = :test3 AND' + ' "test4" = :test4', + dict(**values, **where_clauses), + ), ) - assert test_oracledb.update('mytable', values, where_clauses) + assert test_oracledb.update("mytable", values, where_clauses) def test_update_just_try(mocker, test_oracledb): - mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSQL_just_try) - assert test_oracledb.update('mytable', dict(test1=1, test2=2), None, just_try=True) + mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSQL_just_try) + assert test_oracledb.update("mytable", dict(test1=1, test2=2), None, just_try=True) def test_delete(mocker, test_oracledb): where_clauses = dict(test1=1, test2=2) mocker.patch( - 'mylib.oracle.OracleDB.doSQL', + "mylib.oracle.OracleDB.doSQL", generate_mock_doSQL( - 'DELETE FROM "mytable" WHERE "test1" = :test1 AND "test2" = :test2', - where_clauses - ) + 'DELETE FROM "mytable" WHERE "test1" = :test1 AND "test2" = :test2', where_clauses + ), ) - assert test_oracledb.delete('mytable', where_clauses) + assert test_oracledb.delete("mytable", where_clauses) def test_delete_just_try(mocker, test_oracledb): - mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSQL_just_try) - assert test_oracledb.delete('mytable', None, just_try=True) + mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSQL_just_try) + assert test_oracledb.delete("mytable", None, just_try=True) def test_truncate(mocker, test_oracledb): mocker.patch( - 'mylib.oracle.OracleDB.doSQL', - generate_mock_doSQL('TRUNCATE TABLE "mytable"', None) + "mylib.oracle.OracleDB.doSQL", generate_mock_doSQL('TRUNCATE TABLE "mytable"', None) ) - assert test_oracledb.truncate('mytable') + assert test_oracledb.truncate("mytable") def test_truncate_just_try(mocker, test_oracledb): - mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSelect_just_try) - assert test_oracledb.truncate('mytable', just_try=True) + mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSelect_just_try) + assert test_oracledb.truncate("mytable", just_try=True) def test_select(mocker, test_oracledb): - fields = ('field1', 'field2') + fields = ("field1", "field2") where_clauses = dict(test3=3, test4=4) expected_return = [ dict(field1=1, field2=2), @@ -344,30 +367,30 @@ def test_select(mocker, test_oracledb): ] order_by = "field1, DESC" mocker.patch( - 'mylib.oracle.OracleDB.doSelect', + "mylib.oracle.OracleDB.doSelect", generate_mock_doSQL( - 'SELECT "field1", "field2" FROM "mytable" WHERE "test3" = :test3 AND "test4" = :test4 ORDER BY ' + order_by, - where_clauses, expected_return - ) + 'SELECT "field1", "field2" FROM "mytable" WHERE "test3" = :test3 AND "test4" = :test4' + " ORDER BY " + order_by, + where_clauses, + expected_return, + ), ) - assert test_oracledb.select('mytable', where_clauses, fields, order_by=order_by) == expected_return + assert ( + test_oracledb.select("mytable", where_clauses, fields, order_by=order_by) == expected_return + ) def test_select_without_field_and_order_by(mocker, test_oracledb): - mocker.patch( - 'mylib.oracle.OracleDB.doSelect', - generate_mock_doSQL( - 'SELECT * FROM "mytable"' - ) - ) + mocker.patch("mylib.oracle.OracleDB.doSelect", generate_mock_doSQL('SELECT * FROM "mytable"')) - assert test_oracledb.select('mytable') + assert test_oracledb.select("mytable") def test_select_just_try(mocker, test_oracledb): - mocker.patch('mylib.oracle.OracleDB.doSQL', mock_doSelect_just_try) - assert test_oracledb.select('mytable', None, None, just_try=True) + mocker.patch("mylib.oracle.OracleDB.doSQL", mock_doSelect_just_try) + assert test_oracledb.select("mytable", None, None, just_try=True) + # # Tests on main methods @@ -376,17 +399,10 @@ def test_select_just_try(mocker, test_oracledb): def test_connect(mocker, test_oracledb): expected_kwargs = dict( - dsn=test_oracledb._dsn, - user=test_oracledb._user, - password=test_oracledb._pwd + dsn=test_oracledb._dsn, user=test_oracledb._user, password=test_oracledb._pwd ) - mocker.patch( - 'cx_Oracle.connect', - generate_mock_args( - expected_kwargs=expected_kwargs - ) - ) + mocker.patch("cx_Oracle.connect", generate_mock_args(expected_kwargs=expected_kwargs)) assert test_oracledb.connect() @@ -400,50 +416,62 @@ def test_close_connected(fake_connected_oracledb): def test_doSQL(fake_connected_oracledb): - fake_connected_oracledb._conn.expected_sql = 'DELETE FROM table WHERE test1 = :test1' + fake_connected_oracledb._conn.expected_sql = "DELETE FROM table WHERE test1 = :test1" fake_connected_oracledb._conn.expected_params = dict(test1=1) - fake_connected_oracledb.doSQL(fake_connected_oracledb._conn.expected_sql, fake_connected_oracledb._conn.expected_params) + fake_connected_oracledb.doSQL( + fake_connected_oracledb._conn.expected_sql, fake_connected_oracledb._conn.expected_params + ) def test_doSQL_without_params(fake_connected_oracledb): - fake_connected_oracledb._conn.expected_sql = 'DELETE FROM table' + fake_connected_oracledb._conn.expected_sql = "DELETE FROM table" fake_connected_oracledb.doSQL(fake_connected_oracledb._conn.expected_sql) def test_doSQL_just_try(fake_connected_just_try_oracledb): - assert fake_connected_just_try_oracledb.doSQL('DELETE FROM table') + assert fake_connected_just_try_oracledb.doSQL("DELETE FROM table") def test_doSQL_on_exception(fake_connected_oracledb): fake_connected_oracledb._conn.expected_exception = True - assert fake_connected_oracledb.doSQL('DELETE FROM table') is False + assert fake_connected_oracledb.doSQL("DELETE FROM table") is False def test_doSelect(fake_connected_oracledb): - fake_connected_oracledb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = :test1' + fake_connected_oracledb._conn.expected_sql = "SELECT * FROM table WHERE test1 = :test1" fake_connected_oracledb._conn.expected_params = dict(test1=1) fake_connected_oracledb._conn.expected_return = [dict(test1=1)] - assert fake_connected_oracledb.doSelect( - fake_connected_oracledb._conn.expected_sql, - fake_connected_oracledb._conn.expected_params) == fake_connected_oracledb._conn.expected_return + assert ( + fake_connected_oracledb.doSelect( + fake_connected_oracledb._conn.expected_sql, + fake_connected_oracledb._conn.expected_params, + ) + == fake_connected_oracledb._conn.expected_return + ) def test_doSelect_without_params(fake_connected_oracledb): - fake_connected_oracledb._conn.expected_sql = 'SELECT * FROM table' + fake_connected_oracledb._conn.expected_sql = "SELECT * FROM table" fake_connected_oracledb._conn.expected_return = [dict(test1=1)] - assert fake_connected_oracledb.doSelect(fake_connected_oracledb._conn.expected_sql) == fake_connected_oracledb._conn.expected_return + assert ( + fake_connected_oracledb.doSelect(fake_connected_oracledb._conn.expected_sql) + == fake_connected_oracledb._conn.expected_return + ) def test_doSelect_on_exception(fake_connected_oracledb): fake_connected_oracledb._conn.expected_exception = True - assert fake_connected_oracledb.doSelect('SELECT * FROM table') is False + assert fake_connected_oracledb.doSelect("SELECT * FROM table") is False def test_doSelect_just_try(fake_connected_just_try_oracledb): - fake_connected_just_try_oracledb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = :test1' + fake_connected_just_try_oracledb._conn.expected_sql = "SELECT * FROM table WHERE test1 = :test1" fake_connected_just_try_oracledb._conn.expected_params = dict(test1=1) fake_connected_just_try_oracledb._conn.expected_return = [dict(test1=1)] - assert fake_connected_just_try_oracledb.doSelect( - fake_connected_just_try_oracledb._conn.expected_sql, - fake_connected_just_try_oracledb._conn.expected_params - ) == fake_connected_just_try_oracledb._conn.expected_return + assert ( + fake_connected_just_try_oracledb.doSelect( + fake_connected_just_try_oracledb._conn.expected_sql, + fake_connected_just_try_oracledb._conn.expected_params, + ) + == fake_connected_just_try_oracledb._conn.expected_return + ) diff --git a/tests/test_pgsql.py b/tests/test_pgsql.py index 5c23593..ef92c2d 100644 --- a/tests/test_pgsql.py +++ b/tests/test_pgsql.py @@ -8,9 +8,11 @@ from mylib.pgsql import PgDB class FakePsycopg2Cursor: - """ Fake Psycopg2 cursor """ + """Fake Psycopg2 cursor""" - def __init__(self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception): + def __init__( + self, expected_sql, expected_params, expected_return, expected_just_try, expected_exception + ): self.expected_sql = expected_sql self.expected_params = expected_params self.expected_return = expected_return @@ -19,13 +21,25 @@ class FakePsycopg2Cursor: def execute(self, sql, params=None): if self.expected_exception: - raise psycopg2.Error(f'{self}.execute({sql}, {params}): expected exception') - if self.expected_just_try and not sql.lower().startswith('select '): - assert False, f'{self}.execute({sql}, {params}) may not be executed in just try mode' + raise psycopg2.Error(f"{self}.execute({sql}, {params}): expected exception") + if self.expected_just_try and not sql.lower().startswith("select "): + assert False, f"{self}.execute({sql}, {params}) may not be executed in just try mode" # pylint: disable=consider-using-f-string - assert sql == self.expected_sql, "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % (self, sql, self.expected_sql) + assert ( + sql == self.expected_sql + ), "%s.execute(): Invalid SQL query:\n '%s'\nMay be:\n '%s'" % ( + self, + sql, + self.expected_sql, + ) # pylint: disable=consider-using-f-string - assert params == self.expected_params, "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % (self, params, self.expected_params) + assert ( + params == self.expected_params + ), "%s.execute(): Invalid params:\n %s\nMay be:\n %s" % ( + self, + params, + self.expected_params, + ) return self.expected_return def fetchall(self): @@ -33,13 +47,13 @@ class FakePsycopg2Cursor: def __repr__(self): return ( - f'FakePsycopg2Cursor({self.expected_sql}, {self.expected_params}, ' - f'{self.expected_return}, {self.expected_just_try})' + f"FakePsycopg2Cursor({self.expected_sql}, {self.expected_params}, " + f"{self.expected_return}, {self.expected_just_try})" ) class FakePsycopg2: - """ Fake Psycopg2 connection """ + """Fake Psycopg2 connection""" expected_sql = None expected_params = None @@ -52,8 +66,9 @@ class FakePsycopg2: allowed_kwargs = dict(dbname=str, user=str, password=(str, None), host=str) for arg, value in kwargs.items(): assert arg in allowed_kwargs, f'Invalid arg {arg}="{value}"' - assert isinstance(value, allowed_kwargs[arg]), \ - f'Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})' + assert isinstance( + value, allowed_kwargs[arg] + ), f"Arg {arg} not a {allowed_kwargs[arg]} ({type(value)})" setattr(self, arg, value) def close(self): @@ -63,14 +78,16 @@ class FakePsycopg2: self._check_just_try() assert len(arg) == 1 and isinstance(arg[0], str) if self.expected_exception: - raise psycopg2.Error(f'set_client_encoding({arg[0]}): Expected exception') + raise psycopg2.Error(f"set_client_encoding({arg[0]}): Expected exception") return self.expected_return def cursor(self): return FakePsycopg2Cursor( - self.expected_sql, self.expected_params, - self.expected_return, self.expected_just_try or self.just_try, - self.expected_exception + self.expected_sql, + self.expected_params, + self.expected_return, + self.expected_just_try or self.just_try, + self.expected_exception, ) def commit(self): @@ -98,19 +115,19 @@ def fake_psycopg2_connect_just_try(**kwargs): @pytest.fixture def test_pgdb(): - return PgDB('127.0.0.1', 'user', 'password', 'dbname') + return PgDB("127.0.0.1", "user", "password", "dbname") @pytest.fixture def fake_pgdb(mocker): - mocker.patch('psycopg2.connect', fake_psycopg2_connect) - return PgDB('127.0.0.1', 'user', 'password', 'dbname') + mocker.patch("psycopg2.connect", fake_psycopg2_connect) + return PgDB("127.0.0.1", "user", "password", "dbname") @pytest.fixture def fake_just_try_pgdb(mocker): - mocker.patch('psycopg2.connect', fake_psycopg2_connect_just_try) - return PgDB('127.0.0.1', 'user', 'password', 'dbname', just_try=True) + mocker.patch("psycopg2.connect", fake_psycopg2_connect_just_try) + return PgDB("127.0.0.1", "user", "password", "dbname", just_try=True) @pytest.fixture @@ -125,13 +142,22 @@ def fake_connected_just_try_pgdb(fake_just_try_pgdb): return fake_just_try_pgdb -def generate_mock_args(expected_args=(), expected_kwargs={}, expected_return=True): # pylint: disable=dangerous-default-value +def generate_mock_args( + expected_args=(), expected_kwargs={}, expected_return=True +): # pylint: disable=dangerous-default-value def mock_args(*args, **kwargs): # pylint: disable=consider-using-f-string - assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % (args, expected_args) + assert args == expected_args, "Invalid call args:\n %s\nMay be:\n %s" % ( + args, + expected_args, + ) # pylint: disable=consider-using-f-string - assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % (kwargs, expected_kwargs) + assert kwargs == expected_kwargs, "Invalid call kwargs:\n %s\nMay be:\n %s" % ( + kwargs, + expected_kwargs, + ) return expected_return + return mock_args @@ -139,13 +165,22 @@ def mock_doSQL_just_try(self, sql, params=None): # pylint: disable=unused-argum assert False, "doSQL() may not be executed in just try mode" -def generate_mock_doSQL(expected_sql, expected_params={}, expected_return=True): # pylint: disable=dangerous-default-value +def generate_mock_doSQL( + expected_sql, expected_params={}, expected_return=True +): # pylint: disable=dangerous-default-value def mock_doSQL(self, sql, params=None): # pylint: disable=unused-argument # pylint: disable=consider-using-f-string - assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % (sql, expected_sql) + assert sql == expected_sql, "Invalid generated SQL query:\n '%s'\nMay be:\n '%s'" % ( + sql, + expected_sql, + ) # pylint: disable=consider-using-f-string - assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % (params, expected_params) + assert params == expected_params, "Invalid generated params:\n %s\nMay be:\n %s" % ( + params, + expected_params, + ) return expected_return + return mock_doSQL @@ -159,15 +194,11 @@ mock_doSelect_just_try = mock_doSQL_just_try def test_combine_params_with_to_add_parameter(): - assert PgDB._combine_params(dict(test1=1), dict(test2=2)) == dict( - test1=1, test2=2 - ) + assert PgDB._combine_params(dict(test1=1), dict(test2=2)) == dict(test1=1, test2=2) def test_combine_params_with_kargs(): - assert PgDB._combine_params(dict(test1=1), test2=2) == dict( - test1=1, test2=2 - ) + assert PgDB._combine_params(dict(test1=1), test2=2) == dict(test1=1, test2=2) def test_combine_params_with_kargs_and_to_add_parameter(): @@ -177,19 +208,16 @@ def test_combine_params_with_kargs_and_to_add_parameter(): def test_format_where_clauses_params_are_preserved(): - args = ('test = test', dict(test1=1)) + args = ("test = test", dict(test1=1)) assert PgDB._format_where_clauses(*args) == args def test_format_where_clauses_raw(): - assert PgDB._format_where_clauses('test = test') == (('test = test'), {}) + assert PgDB._format_where_clauses("test = test") == ("test = test", {}) def test_format_where_clauses_tuple_clause_with_params(): - where_clauses = ( - 'test1 = %(test1)s AND test2 = %(test2)s', - dict(test1=1, test2=2) - ) + where_clauses = ("test1 = %(test1)s AND test2 = %(test2)s", dict(test1=1, test2=2)) assert PgDB._format_where_clauses(where_clauses) == where_clauses @@ -197,27 +225,23 @@ def test_format_where_clauses_dict(): where_clauses = dict(test1=1, test2=2) assert PgDB._format_where_clauses(where_clauses) == ( '"test1" = %(test1)s AND "test2" = %(test2)s', - where_clauses + where_clauses, ) def test_format_where_clauses_combined_types(): - where_clauses = ( - 'test1 = 1', - ('test2 LIKE %(test2)s', dict(test2=2)), - dict(test3=3, test4=4) - ) + where_clauses = ("test1 = 1", ("test2 LIKE %(test2)s", dict(test2=2)), dict(test3=3, test4=4)) assert PgDB._format_where_clauses(where_clauses) == ( 'test1 = 1 AND test2 LIKE %(test2)s AND "test3" = %(test3)s AND "test4" = %(test4)s', - dict(test2=2, test3=3, test4=4) + dict(test2=2, test3=3, test4=4), ) def test_format_where_clauses_with_where_op(): where_clauses = dict(test1=1, test2=2) - assert PgDB._format_where_clauses(where_clauses, where_op='OR') == ( + assert PgDB._format_where_clauses(where_clauses, where_op="OR") == ( '"test1" = %(test1)s OR "test2" = %(test2)s', - where_clauses + where_clauses, ) @@ -226,7 +250,7 @@ def test_add_where_clauses(): where_clauses = dict(test1=1, test2=2) assert PgDB._add_where_clauses(sql, None, where_clauses) == ( sql + ' WHERE "test1" = %(test1)s AND "test2" = %(test2)s', - where_clauses + where_clauses, ) @@ -236,26 +260,26 @@ def test_add_where_clauses_preserved_params(): params = dict(fake1=1) assert PgDB._add_where_clauses(sql, params.copy(), where_clauses) == ( sql + ' WHERE "test1" = %(test1)s AND "test2" = %(test2)s', - dict(**where_clauses, **params) + dict(**where_clauses, **params), ) def test_add_where_clauses_with_op(): sql = "SELECT * FROM table" - where_clauses = ('test1=1', 'test2=2') - assert PgDB._add_where_clauses(sql, None, where_clauses, where_op='OR') == ( - sql + ' WHERE test1=1 OR test2=2', - {} + where_clauses = ("test1=1", "test2=2") + assert PgDB._add_where_clauses(sql, None, where_clauses, where_op="OR") == ( + sql + " WHERE test1=1 OR test2=2", + {}, ) def test_add_where_clauses_with_duplicated_field(): sql = "UPDATE table SET test1=%(test1)s" - params = dict(test1='new_value') - where_clauses = dict(test1='where_value') + params = dict(test1="new_value") + where_clauses = dict(test1="where_value") assert PgDB._add_where_clauses(sql, params, where_clauses) == ( sql + ' WHERE "test1" = %(test1_1)s', - dict(test1='new_value', test1_1='where_value') + dict(test1="new_value", test1_1="where_value"), ) @@ -267,74 +291,70 @@ def test_quote_table_name(): def test_insert(mocker, test_pgdb): values = dict(test1=1, test2=2) mocker.patch( - 'mylib.pgsql.PgDB.doSQL', + "mylib.pgsql.PgDB.doSQL", generate_mock_doSQL( - 'INSERT INTO "mytable" ("test1", "test2") VALUES (%(test1)s, %(test2)s)', - values - ) + 'INSERT INTO "mytable" ("test1", "test2") VALUES (%(test1)s, %(test2)s)', values + ), ) - assert test_pgdb.insert('mytable', values) + assert test_pgdb.insert("mytable", values) def test_insert_just_try(mocker, test_pgdb): - mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSQL_just_try) - assert test_pgdb.insert('mytable', dict(test1=1, test2=2), just_try=True) + mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSQL_just_try) + assert test_pgdb.insert("mytable", dict(test1=1, test2=2), just_try=True) def test_update(mocker, test_pgdb): values = dict(test1=1, test2=2) where_clauses = dict(test3=3, test4=4) mocker.patch( - 'mylib.pgsql.PgDB.doSQL', + "mylib.pgsql.PgDB.doSQL", generate_mock_doSQL( - 'UPDATE "mytable" SET "test1" = %(test1)s, "test2" = %(test2)s WHERE "test3" = %(test3)s AND "test4" = %(test4)s', - dict(**values, **where_clauses) - ) + 'UPDATE "mytable" SET "test1" = %(test1)s, "test2" = %(test2)s WHERE "test3" =' + ' %(test3)s AND "test4" = %(test4)s', + dict(**values, **where_clauses), + ), ) - assert test_pgdb.update('mytable', values, where_clauses) + assert test_pgdb.update("mytable", values, where_clauses) def test_update_just_try(mocker, test_pgdb): - mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSQL_just_try) - assert test_pgdb.update('mytable', dict(test1=1, test2=2), None, just_try=True) + mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSQL_just_try) + assert test_pgdb.update("mytable", dict(test1=1, test2=2), None, just_try=True) def test_delete(mocker, test_pgdb): where_clauses = dict(test1=1, test2=2) mocker.patch( - 'mylib.pgsql.PgDB.doSQL', + "mylib.pgsql.PgDB.doSQL", generate_mock_doSQL( - 'DELETE FROM "mytable" WHERE "test1" = %(test1)s AND "test2" = %(test2)s', - where_clauses - ) + 'DELETE FROM "mytable" WHERE "test1" = %(test1)s AND "test2" = %(test2)s', where_clauses + ), ) - assert test_pgdb.delete('mytable', where_clauses) + assert test_pgdb.delete("mytable", where_clauses) def test_delete_just_try(mocker, test_pgdb): - mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSQL_just_try) - assert test_pgdb.delete('mytable', None, just_try=True) + mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSQL_just_try) + assert test_pgdb.delete("mytable", None, just_try=True) def test_truncate(mocker, test_pgdb): - mocker.patch( - 'mylib.pgsql.PgDB.doSQL', - generate_mock_doSQL('TRUNCATE TABLE "mytable"', None) - ) + mocker.patch("mylib.pgsql.PgDB.doSQL", generate_mock_doSQL('TRUNCATE TABLE "mytable"', None)) - assert test_pgdb.truncate('mytable') + assert test_pgdb.truncate("mytable") def test_truncate_just_try(mocker, test_pgdb): - mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSelect_just_try) - assert test_pgdb.truncate('mytable', just_try=True) + mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSelect_just_try) + assert test_pgdb.truncate("mytable", just_try=True) def test_select(mocker, test_pgdb): - fields = ('field1', 'field2') + fields = ("field1", "field2") where_clauses = dict(test3=3, test4=4) expected_return = [ dict(field1=1, field2=2), @@ -342,30 +362,28 @@ def test_select(mocker, test_pgdb): ] order_by = "field1, DESC" mocker.patch( - 'mylib.pgsql.PgDB.doSelect', + "mylib.pgsql.PgDB.doSelect", generate_mock_doSQL( - 'SELECT "field1", "field2" FROM "mytable" WHERE "test3" = %(test3)s AND "test4" = %(test4)s ORDER BY ' + order_by, - where_clauses, expected_return - ) + 'SELECT "field1", "field2" FROM "mytable" WHERE "test3" = %(test3)s AND "test4" =' + " %(test4)s ORDER BY " + order_by, + where_clauses, + expected_return, + ), ) - assert test_pgdb.select('mytable', where_clauses, fields, order_by=order_by) == expected_return + assert test_pgdb.select("mytable", where_clauses, fields, order_by=order_by) == expected_return def test_select_without_field_and_order_by(mocker, test_pgdb): - mocker.patch( - 'mylib.pgsql.PgDB.doSelect', - generate_mock_doSQL( - 'SELECT * FROM "mytable"' - ) - ) + mocker.patch("mylib.pgsql.PgDB.doSelect", generate_mock_doSQL('SELECT * FROM "mytable"')) - assert test_pgdb.select('mytable') + assert test_pgdb.select("mytable") def test_select_just_try(mocker, test_pgdb): - mocker.patch('mylib.pgsql.PgDB.doSQL', mock_doSelect_just_try) - assert test_pgdb.select('mytable', None, None, just_try=True) + mocker.patch("mylib.pgsql.PgDB.doSQL", mock_doSelect_just_try) + assert test_pgdb.select("mytable", None, None, just_try=True) + # # Tests on main methods @@ -374,18 +392,10 @@ def test_select_just_try(mocker, test_pgdb): def test_connect(mocker, test_pgdb): expected_kwargs = dict( - dbname=test_pgdb._db, - user=test_pgdb._user, - host=test_pgdb._host, - password=test_pgdb._pwd + dbname=test_pgdb._db, user=test_pgdb._user, host=test_pgdb._host, password=test_pgdb._pwd ) - mocker.patch( - 'psycopg2.connect', - generate_mock_args( - expected_kwargs=expected_kwargs - ) - ) + mocker.patch("psycopg2.connect", generate_mock_args(expected_kwargs=expected_kwargs)) assert test_pgdb.connect() @@ -399,61 +409,74 @@ def test_close_connected(fake_connected_pgdb): def test_setEncoding(fake_connected_pgdb): - assert fake_connected_pgdb.setEncoding('utf8') + assert fake_connected_pgdb.setEncoding("utf8") def test_setEncoding_not_connected(fake_pgdb): - assert fake_pgdb.setEncoding('utf8') is False + assert fake_pgdb.setEncoding("utf8") is False def test_setEncoding_on_exception(fake_connected_pgdb): fake_connected_pgdb._conn.expected_exception = True - assert fake_connected_pgdb.setEncoding('utf8') is False + assert fake_connected_pgdb.setEncoding("utf8") is False def test_doSQL(fake_connected_pgdb): - fake_connected_pgdb._conn.expected_sql = 'DELETE FROM table WHERE test1 = %(test1)s' + fake_connected_pgdb._conn.expected_sql = "DELETE FROM table WHERE test1 = %(test1)s" fake_connected_pgdb._conn.expected_params = dict(test1=1) - fake_connected_pgdb.doSQL(fake_connected_pgdb._conn.expected_sql, fake_connected_pgdb._conn.expected_params) + fake_connected_pgdb.doSQL( + fake_connected_pgdb._conn.expected_sql, fake_connected_pgdb._conn.expected_params + ) def test_doSQL_without_params(fake_connected_pgdb): - fake_connected_pgdb._conn.expected_sql = 'DELETE FROM table' + fake_connected_pgdb._conn.expected_sql = "DELETE FROM table" fake_connected_pgdb.doSQL(fake_connected_pgdb._conn.expected_sql) def test_doSQL_just_try(fake_connected_just_try_pgdb): - assert fake_connected_just_try_pgdb.doSQL('DELETE FROM table') + assert fake_connected_just_try_pgdb.doSQL("DELETE FROM table") def test_doSQL_on_exception(fake_connected_pgdb): fake_connected_pgdb._conn.expected_exception = True - assert fake_connected_pgdb.doSQL('DELETE FROM table') is False + assert fake_connected_pgdb.doSQL("DELETE FROM table") is False def test_doSelect(fake_connected_pgdb): - fake_connected_pgdb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = %(test1)s' + fake_connected_pgdb._conn.expected_sql = "SELECT * FROM table WHERE test1 = %(test1)s" fake_connected_pgdb._conn.expected_params = dict(test1=1) fake_connected_pgdb._conn.expected_return = [dict(test1=1)] - assert fake_connected_pgdb.doSelect(fake_connected_pgdb._conn.expected_sql, fake_connected_pgdb._conn.expected_params) == fake_connected_pgdb._conn.expected_return + assert ( + fake_connected_pgdb.doSelect( + fake_connected_pgdb._conn.expected_sql, fake_connected_pgdb._conn.expected_params + ) + == fake_connected_pgdb._conn.expected_return + ) def test_doSelect_without_params(fake_connected_pgdb): - fake_connected_pgdb._conn.expected_sql = 'SELECT * FROM table' + fake_connected_pgdb._conn.expected_sql = "SELECT * FROM table" fake_connected_pgdb._conn.expected_return = [dict(test1=1)] - assert fake_connected_pgdb.doSelect(fake_connected_pgdb._conn.expected_sql) == fake_connected_pgdb._conn.expected_return + assert ( + fake_connected_pgdb.doSelect(fake_connected_pgdb._conn.expected_sql) + == fake_connected_pgdb._conn.expected_return + ) def test_doSelect_on_exception(fake_connected_pgdb): fake_connected_pgdb._conn.expected_exception = True - assert fake_connected_pgdb.doSelect('SELECT * FROM table') is False + assert fake_connected_pgdb.doSelect("SELECT * FROM table") is False def test_doSelect_just_try(fake_connected_just_try_pgdb): - fake_connected_just_try_pgdb._conn.expected_sql = 'SELECT * FROM table WHERE test1 = %(test1)s' + fake_connected_just_try_pgdb._conn.expected_sql = "SELECT * FROM table WHERE test1 = %(test1)s" fake_connected_just_try_pgdb._conn.expected_params = dict(test1=1) fake_connected_just_try_pgdb._conn.expected_return = [dict(test1=1)] - assert fake_connected_just_try_pgdb.doSelect( - fake_connected_just_try_pgdb._conn.expected_sql, - fake_connected_just_try_pgdb._conn.expected_params - ) == fake_connected_just_try_pgdb._conn.expected_return + assert ( + fake_connected_just_try_pgdb.doSelect( + fake_connected_just_try_pgdb._conn.expected_sql, + fake_connected_just_try_pgdb._conn.expected_params, + ) + == fake_connected_just_try_pgdb._conn.expected_return + ) diff --git a/tests/test_telltale.py b/tests/test_telltale.py index 945f2f9..007dfff 100644 --- a/tests/test_telltale.py +++ b/tests/test_telltale.py @@ -3,13 +3,14 @@ import datetime import os + import pytest from mylib.telltale import TelltaleFile def test_create_telltale_file(tmp_path): - filename = 'test' + filename = "test" file = TelltaleFile(filename=filename, dirpath=tmp_path) assert file.filename == filename assert file.dirpath == tmp_path @@ -24,15 +25,15 @@ def test_create_telltale_file(tmp_path): def test_create_telltale_file_with_filepath_and_invalid_dirpath(): with pytest.raises(AssertionError): - TelltaleFile(filepath='/tmp/test', dirpath='/var/tmp') + TelltaleFile(filepath="/tmp/test", dirpath="/var/tmp") def test_create_telltale_file_with_filepath_and_invalid_filename(): with pytest.raises(AssertionError): - TelltaleFile(filepath='/tmp/test', filename='other') + TelltaleFile(filepath="/tmp/test", filename="other") def test_remove_telltale_file(tmp_path): - file = TelltaleFile(filename='test', dirpath=tmp_path) + file = TelltaleFile(filename="test", dirpath=tmp_path) file.update() assert file.remove()