Skip to content

Commit

Permalink
Added test for inherited bound method
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jan 16, 2025
1 parent d2c275f commit b119e50
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 2 deletions.
44 changes: 42 additions & 2 deletions tests/test_generics_and_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from typing import Generic, List, NewType, Tuple, Type, TypeVar, Union

import pytest
import tyro
import yaml
from helptext_utils import get_helptext_with_checks
from typing_extensions import Annotated

import tyro
from helptext_utils import get_helptext_with_checks

T = TypeVar("T")

Check failure on line 14 in tests/test_generics_and_serialization.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

tests/test_generics_and_serialization.py:1:1: I001 Import block is un-sorted or un-formatted

Expand Down Expand Up @@ -543,3 +543,43 @@ def method(self, a: T) -> T:
assert tyro.cli(Config[int](3).method, args="--a 5".split(" ")) == 5
with pytest.raises(tyro.constructors.UnsupportedTypeAnnotationError):
tyro.cli(Config(3).method, args="--a 5".split(" "))


def test_inherited_bound_method() -> None:
"""From @deeptoaster: https://github.com/brentyi/tyro/issues/233"""

@dataclasses.dataclass
class AConfig:
a: int

TContainsAConfig = TypeVar("TContainsAConfig", bound=AConfig)

class AModel(Generic[TContainsAConfig]):
def __init__(self, config: TContainsAConfig):
self.config = config

config: TContainsAConfig

TContainsAModel = TypeVar("TContainsAModel", bound=AModel)

@dataclasses.dataclass
class ABConfig(AConfig):
b: int

class ABModel(AModel[ABConfig]):
pass

class AHelper(Generic[TContainsAModel]):
def spam(self, model: TContainsAModel) -> TContainsAModel:
self.model = model
return model

model: TContainsAModel

class ABHelper(AHelper[ABModel]):
def print_model(self) -> None:
print(self.model.config)

assert tyro.cli(
ABHelper().spam, args="--model.config.a 5 --model.config.b 7".split(" ")
).config == ABConfig(5, 7)
Original file line number Diff line number Diff line change
Expand Up @@ -529,3 +529,55 @@ def c(model: ABModel[ABCConfig]):
assert "--model.config.a" in get_helptext_with_checks(c)
assert "--model.config.b" in get_helptext_with_checks(c)
assert "--model.config.c" in get_helptext_with_checks(c)


def test_simple_bound_method() -> None:
class Config[T]:
def __init__(self, a: T) -> None: ...
def method(self, a: T) -> T:
return a

assert tyro.cli(Config[int], args="--a 5".split(" ")).method(3) == 3
assert tyro.cli(Config[int](3).method, args="--a 5".split(" ")) == 5
with pytest.raises(tyro.constructors.UnsupportedTypeAnnotationError):
tyro.cli(Config(3).method, args="--a 5".split(" "))


def test_inherited_bound_method() -> None:
"""From @deeptoaster: https://github.com/brentyi/tyro/issues/233"""

@dataclasses.dataclass
class AConfig:
a: int

TContainsAConfig = TypeVar("TContainsAConfig", bound=AConfig)

class AModel(Generic[TContainsAConfig]):
def __init__(self, config: TContainsAConfig):
self.config = config

config: TContainsAConfig

TContainsAModel = TypeVar("TContainsAModel", bound=AModel)

@dataclasses.dataclass
class ABConfig(AConfig):
b: int

class ABModel(AModel[ABConfig]):
pass

class AHelper(Generic[TContainsAModel]):
def spam(self, model: TContainsAModel) -> TContainsAModel:
self.model = model
return model

model: TContainsAModel

class ABHelper(AHelper[ABModel]):
def print_model(self) -> None:
print(self.model.config)

assert tyro.cli(
ABHelper().spam, args="--model.config.a 5 --model.config.b 7".split(" ")
).config == ABConfig(5, 7)

0 comments on commit b119e50

Please # to comment.