From 8d687d1c6ef32a2d916dcfd59e9f551eda793726 Mon Sep 17 00:00:00 2001 From: brentyi Date: Thu, 16 Jan 2025 11:54:53 -0800 Subject: [PATCH] torch import try/except --- tests/test_new_style_annotations_min_py39.py | 49 +++++++++++-------- ...ew_style_annotations_min_py39_generated.py | 2 +- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/tests/test_new_style_annotations_min_py39.py b/tests/test_new_style_annotations_min_py39.py index a61ce91f..999774af 100644 --- a/tests/test_new_style_annotations_min_py39.py +++ b/tests/test_new_style_annotations_min_py39.py @@ -2,10 +2,10 @@ from typing import Any, Literal, Optional, Type, Union import pytest -from helptext_utils import get_helptext_with_checks - import tyro +from helptext_utils import get_helptext_with_checks + def test_list() -> None: def main(x: list[bool]) -> Any: @@ -70,24 +70,31 @@ def test_tuple_direct() -> None: assert tyro.cli(tuple[int, int], args="1 2".split(" ")) == (1, 2) # type: ignore -def test_type_with_init_false() -> None: - """https://github.com/brentyi/tyro/issues/235""" +try: from torch.optim.lr_scheduler import LinearLR, LRScheduler - @dataclasses.dataclass(frozen=True) - class LinearLRConfig: - _target: type[LRScheduler] = dataclasses.field( - init=False, default_factory=lambda: LinearLR - ) - _target2: Type[LRScheduler] = dataclasses.field( - init=False, default_factory=lambda: LinearLR - ) - start_factor: float = 1.0 / 3 - end_factor: float = 1.0 - total_iters: Optional[int] = None - - def main(config: LinearLRConfig) -> LinearLRConfig: - return config - - assert tyro.cli(main, args=[]) == LinearLRConfig() - assert "_target" not in get_helptext_with_checks(LinearLRConfig) + def test_type_with_init_false() -> None: + """https://github.com/brentyi/tyro/issues/235""" + + @dataclasses.dataclass(frozen=True) + class LinearLRConfig: + _target: type[LRScheduler] = dataclasses.field( + init=False, default_factory=lambda: LinearLR + ) + _target2: Type[LRScheduler] = dataclasses.field( + init=False, default_factory=lambda: LinearLR + ) + start_factor: float = 1.0 / 3 + end_factor: float = 1.0 + total_iters: Optional[int] = None + + def main(config: LinearLRConfig) -> LinearLRConfig: + return config + + assert tyro.cli(main, args=[]) == LinearLRConfig() + assert "_target" not in get_helptext_with_checks(LinearLRConfig) +except ImportError: + # We can't install PyTorch in Python 3.13. + import sys + + assert sys.version_info >= (3, 13) diff --git a/tests/test_py311_generated/test_new_style_annotations_min_py39_generated.py b/tests/test_py311_generated/test_new_style_annotations_min_py39_generated.py index f17e9371..0b51242e 100644 --- a/tests/test_py311_generated/test_new_style_annotations_min_py39_generated.py +++ b/tests/test_py311_generated/test_new_style_annotations_min_py39_generated.py @@ -3,6 +3,7 @@ import pytest from helptext_utils import get_helptext_with_checks +from torch.optim.lr_scheduler import LinearLR, LRScheduler import tyro @@ -72,7 +73,6 @@ def test_tuple_direct() -> None: def test_type_with_init_false() -> None: """https://github.com/brentyi/tyro/issues/235""" - from torch.optim.lr_scheduler import LinearLR, LRScheduler @dataclasses.dataclass(frozen=True) class LinearLRConfig: