Skip to content

Commit

Permalink
torch import try/except
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jan 16, 2025
1 parent b7dc931 commit 8d687d1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 22 deletions.
49 changes: 28 additions & 21 deletions tests/test_new_style_annotations_min_py39.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from typing import Any, Literal, Optional, Type, Union

import pytest
from helptext_utils import get_helptext_with_checks

import tyro

from helptext_utils import get_helptext_with_checks


def test_list() -> None:

Check failure on line 10 in tests/test_new_style_annotations_min_py39.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

tests/test_new_style_annotations_min_py39.py:1:1: I001 Import block is un-sorted or un-formatted
def main(x: list[bool]) -> Any:
Expand Down Expand Up @@ -70,24 +70,31 @@ def test_tuple_direct() -> None:
assert tyro.cli(tuple[int, int], args="1 2".split(" ")) == (1, 2) # type: ignore


def test_type_with_init_false() -> None:
"""https://github.com/brentyi/tyro/issues/235"""
try:
from torch.optim.lr_scheduler import LinearLR, LRScheduler

@dataclasses.dataclass(frozen=True)
class LinearLRConfig:
_target: type[LRScheduler] = dataclasses.field(
init=False, default_factory=lambda: LinearLR
)
_target2: Type[LRScheduler] = dataclasses.field(
init=False, default_factory=lambda: LinearLR
)
start_factor: float = 1.0 / 3
end_factor: float = 1.0
total_iters: Optional[int] = None

def main(config: LinearLRConfig) -> LinearLRConfig:
return config

assert tyro.cli(main, args=[]) == LinearLRConfig()
assert "_target" not in get_helptext_with_checks(LinearLRConfig)
def test_type_with_init_false() -> None:
"""https://github.com/brentyi/tyro/issues/235"""

@dataclasses.dataclass(frozen=True)
class LinearLRConfig:
_target: type[LRScheduler] = dataclasses.field(
init=False, default_factory=lambda: LinearLR
)
_target2: Type[LRScheduler] = dataclasses.field(
init=False, default_factory=lambda: LinearLR
)
start_factor: float = 1.0 / 3
end_factor: float = 1.0
total_iters: Optional[int] = None

def main(config: LinearLRConfig) -> LinearLRConfig:
return config

assert tyro.cli(main, args=[]) == LinearLRConfig()
assert "_target" not in get_helptext_with_checks(LinearLRConfig)
except ImportError:
# We can't install PyTorch in Python 3.13.
import sys

assert sys.version_info >= (3, 13)
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
from helptext_utils import get_helptext_with_checks
from torch.optim.lr_scheduler import LinearLR, LRScheduler

import tyro

Expand Down Expand Up @@ -72,7 +73,6 @@ def test_tuple_direct() -> None:

def test_type_with_init_false() -> None:
"""https://github.com/brentyi/tyro/issues/235"""
from torch.optim.lr_scheduler import LinearLR, LRScheduler

@dataclasses.dataclass(frozen=True)
class LinearLRConfig:
Expand Down

0 comments on commit 8d687d1

Please # to comment.