From 7577a30355430084b63f93858682ac94a1e2b0c5 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Thu, 16 Jan 2025 13:41:18 -0800 Subject: [PATCH] Workaround for `__type_params__` bug in Python 3.12.0 (#236) * Add test from #235 * torch import try/except * test gen, ruff * Add 3.12.0 to pytest yml * Workaround for `type[T]` bug in Python 3.12.0 --- .github/workflows/pytest.yml | 2 +- src/tyro/_resolver.py | 14 ++++---- tests/test_new_style_annotations_min_py39.py | 33 ++++++++++++++++++- ...ew_style_annotations_min_py39_generated.py | 33 ++++++++++++++++++- 4 files changed, 73 insertions(+), 9 deletions(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 1ee6ff09..e9767170 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-22.04 strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.12.0", "3.13"] steps: - uses: actions/checkout@v2 diff --git a/src/tyro/_resolver.py b/src/tyro/_resolver.py index e776beea..b61a574e 100644 --- a/src/tyro/_resolver.py +++ b/src/tyro/_resolver.py @@ -381,12 +381,14 @@ def concretize_type_params( typ = resolve_newtype_and_aliases(typ) type_from_typevar = {} GenericAlias = getattr(types, "GenericAlias", None) - while ( - GenericAlias is not None - and isinstance(typ, GenericAlias) - and len(getattr(typ, "__type_params__", ())) > 0 - ): - for k, v in zip(typ.__type_params__, get_args(typ)): # type: ignore + while GenericAlias is not None and isinstance(typ, GenericAlias): + type_params = getattr(typ, "__type_params__", ()) + # The __len__ check is for a bug in Python 3.12.0: + # https://github.com/brentyi/tyro/issues/235 + if not hasattr(type_params, "__len__") or len(type_params) == 0: + break + + for k, v in zip(type_params, get_args(typ)): type_from_typevar[k] = TypeParamResolver.concretize_type_params( v, seen=seen ) diff --git a/tests/test_new_style_annotations_min_py39.py b/tests/test_new_style_annotations_min_py39.py index 9b8908a3..80ccfabc 100644 --- a/tests/test_new_style_annotations_min_py39.py +++ b/tests/test_new_style_annotations_min_py39.py @@ -1,7 +1,8 @@ import dataclasses -from typing import Any, Literal, Optional, Union +from typing import Any, Literal, Optional, Type, Union import pytest +from helptext_utils import get_helptext_with_checks import tyro @@ -67,3 +68,33 @@ def main( def test_tuple_direct() -> None: assert tyro.cli(tuple[int, ...], args="1 2".split(" ")) == (1, 2) # type: ignore assert tyro.cli(tuple[int, int], args="1 2".split(" ")) == (1, 2) # type: ignore + + +try: + from torch.optim.lr_scheduler import LinearLR, LRScheduler + + def test_type_with_init_false() -> None: + """https://github.com/brentyi/tyro/issues/235""" + + @dataclasses.dataclass(frozen=True) + class LinearLRConfig: + _target: type[LRScheduler] = dataclasses.field( + init=False, default_factory=lambda: LinearLR + ) + _target2: Type[LRScheduler] = dataclasses.field( + init=False, default_factory=lambda: LinearLR + ) + start_factor: float = 1.0 / 3 + end_factor: float = 1.0 + total_iters: Optional[int] = None + + def main(config: LinearLRConfig) -> LinearLRConfig: + return config + + assert tyro.cli(main, args=[]) == LinearLRConfig() + assert "_target" not in get_helptext_with_checks(LinearLRConfig) +except ImportError: + # We can't install PyTorch in Python 3.13. + import sys + + assert sys.version_info >= (3, 13) diff --git a/tests/test_py311_generated/test_new_style_annotations_min_py39_generated.py b/tests/test_py311_generated/test_new_style_annotations_min_py39_generated.py index 0fe423b5..6747eb8f 100644 --- a/tests/test_py311_generated/test_new_style_annotations_min_py39_generated.py +++ b/tests/test_py311_generated/test_new_style_annotations_min_py39_generated.py @@ -1,7 +1,8 @@ import dataclasses -from typing import Any, Literal, Optional +from typing import Any, Literal, Optional, Type import pytest +from helptext_utils import get_helptext_with_checks import tyro @@ -67,3 +68,33 @@ def main( def test_tuple_direct() -> None: assert tyro.cli(tuple[int, ...], args="1 2".split(" ")) == (1, 2) # type: ignore assert tyro.cli(tuple[int, int], args="1 2".split(" ")) == (1, 2) # type: ignore + + +try: + from torch.optim.lr_scheduler import LinearLR, LRScheduler + + def test_type_with_init_false() -> None: + """https://github.com/brentyi/tyro/issues/235""" + + @dataclasses.dataclass(frozen=True) + class LinearLRConfig: + _target: type[LRScheduler] = dataclasses.field( + init=False, default_factory=lambda: LinearLR + ) + _target2: Type[LRScheduler] = dataclasses.field( + init=False, default_factory=lambda: LinearLR + ) + start_factor: float = 1.0 / 3 + end_factor: float = 1.0 + total_iters: Optional[int] = None + + def main(config: LinearLRConfig) -> LinearLRConfig: + return config + + assert tyro.cli(main, args=[]) == LinearLRConfig() + assert "_target" not in get_helptext_with_checks(LinearLRConfig) +except ImportError: + # We can't install PyTorch in Python 3.13. + import sys + + assert sys.version_info >= (3, 13)