Skip to content

Commit

Permalink
Fix function signature generation when param is not annotated (#176)
Browse files Browse the repository at this point in the history
Previously the parameter was not added to the signature. Fixes
#175
  • Loading branch information
justinchuby authored Sep 2, 2024
1 parent 804d98e commit 98d6e72
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions src/torch_onnx/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ def has_default(self) -> bool:

@dataclasses.dataclass(frozen=True)
class AttributeParameter:
"""A parameter in the function signature that represents an ONNX attribute."""

name: str
type: ir.AttributeType
required: bool
Expand Down Expand Up @@ -441,7 +443,7 @@ def from_function(
# https://github.com/python/cpython/issues/102405
type_hints = typing.get_type_hints(func)

params = []
params: list[Parameter | AttributeParameter] = []
# Create a mapping from type to a unique name
type_constraints: dict[str, TypeConstraintParam] = {}

Expand All @@ -452,8 +454,18 @@ def from_function(
param.name,
py_signature,
)
type_constraints[param.name] = TypeConstraintParam.any_value(
f"T_{param.name}"
type_constraint = TypeConstraintParam.any_value(f"T_{param.name}")
params.append(
Parameter(
name=param.name,
type_constraint=type_constraint,
required=param.default is inspect.Parameter.empty,
# TODO: Handle variadic
variadic=False,
default=param.default
if param.default is not inspect.Parameter.empty
else _EMPTY_DEFAULT,
)
)
else:
type_ = type_hints[param.name]
Expand Down Expand Up @@ -496,7 +508,7 @@ def from_function(
type_constraints[type_constraint_name] = type_constraint
# 4. Create Parameter
params.append(
Parameter( # type: ignore[arg-type]
Parameter(
name=param.name,
type_constraint=type_constraint,
required=param.default is inspect.Parameter.empty,
Expand Down

0 comments on commit 98d6e72

Please # to comment.