Skip to content

Commit

Permalink
Fix custom constructor edge case for variable-length positional argum…
Browse files Browse the repository at this point in the history
…ents (#258)

* Fix custom constructor edge case for variable-length positional arguments

* Python 3.8
  • Loading branch information
brentyi authored Feb 19, 2025
1 parent 4ad01de commit 992909d
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/tyro/_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def get_value_from_arg(
consumed_keywords.add(name_maybe_prefixed)
if not arg.lowered.is_fixed():
value, value_found = get_value_from_arg(name_maybe_prefixed, arg)
should_cast = False

if value in _fields.MISSING_AND_MISSING_NONPROP:
value = arg.field.default
Expand All @@ -114,14 +115,18 @@ def get_value_from_arg(
and arg.lowered.nargs in ("?", "*")
):
value = []
should_cast = True
elif value_found:
# Value was found from the CLI, so we need to cast it with instance_from_str.
should_cast = True
any_arguments_provided = True
if arg.lowered.nargs == "?":
# Special case for optional positional arguments: this is the
# only time that arguments don't come back as a list.
value = [value]

# Attempt to cast the value to the correct type.
if should_cast:
try:
assert arg.lowered.instance_from_str is not None
value = arg.lowered.instance_from_str(value)
Expand Down
34 changes: 34 additions & 0 deletions tests/test_custom_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Dict, List, Union

import numpy as np
import pytest
from typing_extensions import Annotated, Literal, get_args

import tyro
Expand Down Expand Up @@ -102,3 +103,36 @@ def main(
).dtype
== np.float32
)


def make_list_of_strings_with_minimum_length(args: List[str]) -> List[str]:
if len(args) == 0:
raise ValueError("Expected at least one string")
return args


ListOfStringsWithMinimumLength = Annotated[
List[str],
tyro.constructors.PrimitiveConstructorSpec(
nargs="*",
metavar="STR [STR ...]",
is_instance=lambda x: isinstance(x, list)
and all(isinstance(i, str) for i in x),
instance_from_str=make_list_of_strings_with_minimum_length,
str_from_instance=lambda args: args,
),
]


def test_min_length_custom_constructor() -> None:
def main(
field1: ListOfStringsWithMinimumLength, field2: int = 3
) -> ListOfStringsWithMinimumLength:
del field2
return field1

with pytest.raises(SystemExit):
tyro.cli(main, args=[])
with pytest.raises(SystemExit):
tyro.cli(main, args=["--field1"])
assert tyro.cli(main, args=["--field1", "a", "b"]) == ["a", "b"]
34 changes: 34 additions & 0 deletions tests/test_py311_generated/test_custom_constructors_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Annotated, Any, Dict, List, Literal, get_args

import numpy as np
import pytest

import tyro

Expand Down Expand Up @@ -101,3 +102,36 @@ def main(
).dtype
== np.float32
)


def make_list_of_strings_with_minimum_length(args: List[str]) -> List[str]:
if len(args) == 0:
raise ValueError("Expected at least one string")
return args


ListOfStringsWithMinimumLength = Annotated[
List[str],
tyro.constructors.PrimitiveConstructorSpec(
nargs="*",
metavar="STR [STR ...]",
is_instance=lambda x: isinstance(x, list)
and all(isinstance(i, str) for i in x),
instance_from_str=make_list_of_strings_with_minimum_length,
str_from_instance=lambda args: args,
),
]


def test_min_length_custom_constructor() -> None:
def main(
field1: ListOfStringsWithMinimumLength, field2: int = 3
) -> ListOfStringsWithMinimumLength:
del field2
return field1

with pytest.raises(SystemExit):
tyro.cli(main, args=[])
with pytest.raises(SystemExit):
tyro.cli(main, args=["--field1"])
assert tyro.cli(main, args=["--field1", "a", "b"]) == ["a", "b"]

0 comments on commit 992909d

Please # to comment.