diff --git a/lazy_loader/__init__.py b/lazy_loader/__init__.py index deea54f..784daa3 100644 --- a/lazy_loader/__init__.py +++ b/lazy_loader/__init__.py @@ -282,6 +282,7 @@ class _StubVisitor(ast.NodeVisitor): def __init__(self): self._submodules = set() self._submod_attrs = {} + self._all = None def visit_ImportFrom(self, node: ast.ImportFrom): if node.level != 1: @@ -300,6 +301,38 @@ def visit_ImportFrom(self, node: ast.ImportFrom): else: self._submodules.update(alias.name for alias in node.names) + def visit_Assign(self, node: ast.Assign): + assigned_list = None + for name in node.targets: + if name.id == "__all__": + assigned_list = node.value + + if assigned_list is None: + return # early + elif not isinstance(assigned_list, ast.List): + msg = ( + f"expected a list assigned to `__all__`, found {type(assigned_list)!r}" + ) + raise ValueError(msg) + + if self._all is not None: + msg = "expected only one definition of `__all__` in stub" + raise ValueError(msg) + self._all = set() + + for constant in assigned_list.elts: + if ( + not isinstance(constant, ast.Constant) + or not isinstance(constant.value, str) + or assigned_list == "" + ): + msg = ( + "expected `__all__` to contain only non-empty strings, " + f"got {constant!r}" + ) + raise ValueError(msg) + self._all.add(constant.value) + def attach_stub(package_name: str, filename: str): """Attach lazily loaded submodules, functions from a type stub. @@ -308,6 +341,10 @@ def attach_stub(package_name: str, filename: str): infer ``submodules`` and ``submod_attrs``. This allows static type checkers to find imports, while still providing lazy loading at runtime. + If the stub file defines `__all__`, it must contain a simple list of + non-empty strings. In this case, the content of `__dir__()` may be + intentionally different from `__all__`. + Parameters ---------- package_name : str @@ -339,4 +376,10 @@ def attach_stub(package_name: str, filename: str): visitor = _StubVisitor() visitor.visit(stub_node) - return attach(package_name, visitor._submodules, visitor._submod_attrs) + + __getattr__, __dir__, __all__ = attach( + package_name, visitor._submodules, visitor._submod_attrs + ) + if visitor._all is not None: + __all__ = visitor._all + return __getattr__, __dir__, __all__ diff --git a/lazy_loader/tests/test_lazy_loader.py b/lazy_loader/tests/test_lazy_loader.py index 42d97f8..dbe41b1 100644 --- a/lazy_loader/tests/test_lazy_loader.py +++ b/lazy_loader/tests/test_lazy_loader.py @@ -143,6 +143,35 @@ def test_stub_loading_parity(): assert stub_getter("some_func") == fake_pkg.some_func +FAKE_STUB_OVERRIDE_ALL = """ +__all__ = [ + "rank", + "gaussian", + "sobel", + "scharr", + "roberts", + # `prewitt` not included! + "__version__", # included but not imported in stub +] + +from . import rank +from ._gaussian import gaussian +from .edges import sobel, scharr, prewitt, roberts +""" + + +def test_stub_override_all(tmp_path): + stub = tmp_path / "stub.pyi" + stub.write_text(FAKE_STUB_OVERRIDE_ALL) + _get, _dir, _all = lazy.attach_stub("my_module", str(stub)) + + expect_dir = {"gaussian", "sobel", "scharr", "prewitt", "roberts", "rank"} + assert set(_dir()) == expect_dir + + expect_all = {"rank", "gaussian", "sobel", "scharr", "roberts", "__version__"} + assert set(_all) == expect_all + + def test_stub_loading_errors(tmp_path): stub = tmp_path / "stub.pyi" stub.write_text("from ..mod import func\n")