diff --git a/my/core/cfg.py b/my/core/cfg.py index 4b5cbede..d69f356b 100644 --- a/my/core/cfg.py +++ b/my/core/cfg.py @@ -44,12 +44,53 @@ def override_config(config: F) -> Iterator[F]: delattr(config, k) -# helper for tests? not sure if could be useful elsewhere +import importlib +import sys +from typing import Optional, Set +ModuleRegex = str @contextmanager -def tmp_config(): - import my.config as C - with override_config(C): - yield C # todo not sure? +def _reload_modules(modules: ModuleRegex) -> Iterator[None]: + def loaded_modules() -> Set[str]: + return {name for name in sys.modules if re.fullmatch(modules, name)} + + modules_before = loaded_modules() + + for m in modules_before: + importlib.reload(sys.modules[m]) + + try: + yield + finally: + modules_after = loaded_modules() + for m in modules_after: + if m in modules_before: + # was previously loaded, so need to reload to pick up old config + importlib.reload(sys.modules[m]) + else: + # wasn't previously loaded, so need to unload it + # otherwise it might fail due to missing config etc + sys.modules.pop(m, None) + + +from contextlib import ExitStack +import re +@contextmanager +def tmp_config(*, modules: Optional[ModuleRegex]=None, config=None): + if modules is None: + assert config is None + if modules is not None: + assert config is not None + + import my.config + with ExitStack() as module_reload_stack, override_config(my.config) as new_config: + if config is not None: + overrides = {k: v for k, v in vars(config).items() if not k.startswith('__')} + for k, v in overrides.items(): + setattr(new_config, k, v) + + if modules is not None: + module_reload_stack.enter_context(_reload_modules(modules)) + yield new_config def test_tmp_config() -> None: diff --git a/my/simple.py b/my/simple.py new file mode 100644 index 00000000..7462291c --- /dev/null +++ b/my/simple.py @@ -0,0 +1,21 @@ +''' +Just a demo module for testing and documentation purposes +''' +from dataclasses import dataclass +from typing import Iterator + +from my.core import make_config + +from my.config import simple as user_config + + +@dataclass +class simple(user_config): + count: int + + +config = make_config(simple) + + +def items() -> Iterator[int]: + yield from range(config.count) diff --git a/tests/test_tmp_config.py b/tests/test_tmp_config.py new file mode 100644 index 00000000..55a21f01 --- /dev/null +++ b/tests/test_tmp_config.py @@ -0,0 +1,30 @@ +from pathlib import Path +import tempfile + +from my.core.cfg import tmp_config + +import pytest + + +def _init_default_config(): + import my.config + class default_config: + count = 5 + my.config.simple = default_config # type: ignore[attr-defined] +_init_default_config() + + +from my.simple import items + + +def test_tmp_config() -> None: + assert len(list(items())) == 5 + + class config: + class simple: + count = 3 + + with tmp_config(modules='my.simple', config=config): + assert len(list(items())) == 3 + + assert len(list(items())) == 5 diff --git a/tox.ini b/tox.ini index ed2a0845..5ae76f3a 100644 --- a/tox.ini +++ b/tox.ini @@ -20,9 +20,10 @@ passenv = commands = pip install -e .[testing] {envpython} -m pytest \ - tests/core.py \ - tests/sqlite.py \ - tests/get_files.py \ + tests/core.py \ + tests/sqlite.py \ + tests/get_files.py \ + tests/test_tmp_config.py \ {posargs}