Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Let stub-defined __all__ override imports #133

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion lazy_loader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@
def __init__(self):
self._submodules = set()
self._submod_attrs = {}
self._all = None

def visit_ImportFrom(self, node: ast.ImportFrom):
if node.level != 1:
Expand All @@ -300,6 +301,38 @@
else:
self._submodules.update(alias.name for alias in node.names)

def visit_Assign(self, node: ast.Assign):
assigned_list = None
for name in node.targets:
if name.id == "__all__":
assigned_list = node.value

if assigned_list is None:
return # early

Check warning on line 311 in lazy_loader/__init__.py

View check run for this annotation

Codecov / codecov/patch

lazy_loader/__init__.py#L311

Added line #L311 was not covered by tests
elif not isinstance(assigned_list, ast.List):
msg = (

Check warning on line 313 in lazy_loader/__init__.py

View check run for this annotation

Codecov / codecov/patch

lazy_loader/__init__.py#L313

Added line #L313 was not covered by tests
f"expected a list assigned to `__all__`, found {type(assigned_list)!r}"
)
raise ValueError(msg)

Check warning on line 316 in lazy_loader/__init__.py

View check run for this annotation

Codecov / codecov/patch

lazy_loader/__init__.py#L316

Added line #L316 was not covered by tests

if self._all is not None:
msg = "expected only one definition of `__all__` in stub"
raise ValueError(msg)

Check warning on line 320 in lazy_loader/__init__.py

View check run for this annotation

Codecov / codecov/patch

lazy_loader/__init__.py#L319-L320

Added lines #L319 - L320 were not covered by tests
self._all = set()

for constant in assigned_list.elts:
if (
not isinstance(constant, ast.Constant)
or not isinstance(constant.value, str)
or assigned_list == ""
):
msg = (

Check warning on line 329 in lazy_loader/__init__.py

View check run for this annotation

Codecov / codecov/patch

lazy_loader/__init__.py#L329

Added line #L329 was not covered by tests
"expected `__all__` to contain only non-empty strings, "
f"got {constant!r}"
)
raise ValueError(msg)

Check warning on line 333 in lazy_loader/__init__.py

View check run for this annotation

Codecov / codecov/patch

lazy_loader/__init__.py#L333

Added line #L333 was not covered by tests
self._all.add(constant.value)


def attach_stub(package_name: str, filename: str):
"""Attach lazily loaded submodules, functions from a type stub.
Expand All @@ -308,6 +341,10 @@
infer ``submodules`` and ``submod_attrs``. This allows static type checkers
to find imports, while still providing lazy loading at runtime.

If the stub file defines `__all__`, it must contain a simple list of
non-empty strings. In this case, the content of `__dir__()` may be
intentionally different from `__all__`.

Parameters
----------
package_name : str
Expand Down Expand Up @@ -339,4 +376,10 @@

visitor = _StubVisitor()
visitor.visit(stub_node)
return attach(package_name, visitor._submodules, visitor._submod_attrs)

__getattr__, __dir__, __all__ = attach(
package_name, visitor._submodules, visitor._submod_attrs
)
if visitor._all is not None:
__all__ = visitor._all
return __getattr__, __dir__, __all__
29 changes: 29 additions & 0 deletions lazy_loader/tests/test_lazy_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,35 @@ def test_stub_loading_parity():
assert stub_getter("some_func") == fake_pkg.some_func


FAKE_STUB_OVERRIDE_ALL = """
__all__ = [
"rank",
"gaussian",
"sobel",
"scharr",
"roberts",
# `prewitt` not included!
"__version__", # included but not imported in stub
]

from . import rank
from ._gaussian import gaussian
from .edges import sobel, scharr, prewitt, roberts
"""


def test_stub_override_all(tmp_path):
stub = tmp_path / "stub.pyi"
stub.write_text(FAKE_STUB_OVERRIDE_ALL)
_get, _dir, _all = lazy.attach_stub("my_module", str(stub))

expect_dir = {"gaussian", "sobel", "scharr", "prewitt", "roberts", "rank"}
assert set(_dir()) == expect_dir

expect_all = {"rank", "gaussian", "sobel", "scharr", "roberts", "__version__"}
assert set(_all) == expect_all


def test_stub_loading_errors(tmp_path):
stub = tmp_path / "stub.pyi"
stub.write_text("from ..mod import func\n")
Expand Down
Loading