From 98d6e723d6f549be64d226957201c6cf29a833da Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 2 Sep 2024 16:51:41 -0700 Subject: [PATCH] Fix function signature generation when param is not annotated (#176) Previously the parameter was not added to the signature. Fixes https://github.com/justinchuby/torch-onnx/issues/175 --- src/torch_onnx/_schemas.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/torch_onnx/_schemas.py b/src/torch_onnx/_schemas.py index f49b2ee..32c1c77 100644 --- a/src/torch_onnx/_schemas.py +++ b/src/torch_onnx/_schemas.py @@ -127,6 +127,8 @@ def has_default(self) -> bool: @dataclasses.dataclass(frozen=True) class AttributeParameter: + """A parameter in the function signature that represents an ONNX attribute.""" + name: str type: ir.AttributeType required: bool @@ -441,7 +443,7 @@ def from_function( # https://github.com/python/cpython/issues/102405 type_hints = typing.get_type_hints(func) - params = [] + params: list[Parameter | AttributeParameter] = [] # Create a mapping from type to a unique name type_constraints: dict[str, TypeConstraintParam] = {} @@ -452,8 +454,18 @@ def from_function( param.name, py_signature, ) - type_constraints[param.name] = TypeConstraintParam.any_value( - f"T_{param.name}" + type_constraint = TypeConstraintParam.any_value(f"T_{param.name}") + params.append( + Parameter( + name=param.name, + type_constraint=type_constraint, + required=param.default is inspect.Parameter.empty, + # TODO: Handle variadic + variadic=False, + default=param.default + if param.default is not inspect.Parameter.empty + else _EMPTY_DEFAULT, + ) ) else: type_ = type_hints[param.name] @@ -496,7 +508,7 @@ def from_function( type_constraints[type_constraint_name] = type_constraint # 4. Create Parameter params.append( - Parameter( # type: ignore[arg-type] + Parameter( name=param.name, type_constraint=type_constraint, required=param.default is inspect.Parameter.empty,