From c34c2778c2d15725e0ae3961b53a4fa67fd4d823 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 25 Jul 2024 09:27:06 -0700 Subject: [PATCH] Handle `[]` inputs (#116) Fixes https://github.com/justinchuby/torch-onnx/issues/108 --- src/torch_onnx/_building.py | 7 ++- tests/conditional_test.py | 59 -------------------- tests/conversion_test.py | 70 ++++++++++++++++++++++++ tests/{ => models}/longformer.py | 0 tests/{ => models}/longformer_export.py | 0 tests/{ => models}/resnet.py | 0 tests/{ => models}/resnet_export_test.py | 0 tests/{ => models}/rotary_embedding.py | 0 8 files changed, 76 insertions(+), 60 deletions(-) delete mode 100644 tests/conditional_test.py create mode 100644 tests/conversion_test.py rename tests/{ => models}/longformer.py (100%) rename tests/{ => models}/longformer_export.py (100%) rename tests/{ => models}/resnet.py (100%) rename tests/{ => models}/resnet_export_test.py (100%) rename tests/{ => models}/rotary_embedding.py (100%) diff --git a/src/torch_onnx/_building.py b/src/torch_onnx/_building.py index 6a97c2b7..923351d5 100644 --- a/src/torch_onnx/_building.py +++ b/src/torch_onnx/_building.py @@ -219,11 +219,16 @@ def _process_python_constants_and_sequences( if isinstance(arg, ir.Value): # TODO(justinchuby): Cast the ir.Value here if needed continue - if isinstance(arg, Sequence) and all(isinstance(val, ir.Value) for val in arg): + if ( + isinstance(arg, Sequence) + and len(arg) > 0 + and all(isinstance(val, ir.Value) for val in arg) + ): # Skip the sequence of ir.Value. This is a variadic input or a Sequence input # NOTE: Variadic operators like Max can be called with mixed ir.Value and Python constants # like `Max(0, ir.Value())` # We need to convert the Python constants to Constant nodes + # NOTE: Important to check that arg is not empty because we need to treat it as list[int] or list[float] continue # if param.variadic: # # FXIME: Handle variadic inputs and sequence inputs differently diff --git a/tests/conditional_test.py b/tests/conditional_test.py deleted file mode 100644 index a2a3ba2b..00000000 --- a/tests/conditional_test.py +++ /dev/null @@ -1,59 +0,0 @@ -import unittest -import torch -import torch_onnx -from functorch.experimental.control_flow import cond - -IS_MAIN = __name__ == "__main__" - -torch_onnx.patch_torch( - error_report=IS_MAIN, profile=IS_MAIN, dump_exported_program=IS_MAIN -) - - -class MySubModule(torch.nn.Module): - def foo(self, x): - return x.cos() - - def forward(self, x): - return self.foo(x) - - -class CondBranchClassMethod(torch.nn.Module): - """ - The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: - - both branches must take the same args, which must also match the branch args passed to cond. - - both branches must return a single tensor - - returned tensor must have the same tensor metadata, e.g. shape and dtype - - branch function can be free function, nested function, lambda, class methods - - branch function can not have closure variables - - no inplace mutations on inputs or global variables - - - This example demonstrates using class method in cond(). - - NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. - """ - - def __init__(self): - super().__init__() - self.subm = MySubModule() - - def bar(self, x): - return x.sin() - - def forward(self, x): - return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x]) - - -class ConditionalTest(unittest.TestCase): - @unittest.expectedFailure # Conditionals are not supported yet - def test_conditional(self): - model = CondBranchClassMethod() - input = torch.randn(5) - onnx_program = torch.onnx.dynamo_export(model, input) - if IS_MAIN: - onnx_program.save("conditional.onnx") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/conversion_test.py b/tests/conversion_test.py new file mode 100644 index 00000000..c5670ef6 --- /dev/null +++ b/tests/conversion_test.py @@ -0,0 +1,70 @@ +"""Unit test for converting into ONNX format.""" + +import unittest +import torch +import torch_onnx +from functorch.experimental.control_flow import cond + +IS_MAIN = __name__ == "__main__" + +torch_onnx.patch_torch( + error_report=IS_MAIN, profile=IS_MAIN, dump_exported_program=IS_MAIN +) + + +class ConversionTest(unittest.TestCase): + @unittest.expectedFailure # Conditionals are not supported yet + def test_conditional(self): + class MySubModule(torch.nn.Module): + def foo(self, x): + return x.cos() + + def forward(self, x): + return self.foo(x) + + class CondBranchClassMethod(torch.nn.Module): + """ + The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: + - both branches must take the same args, which must also match the branch args passed to cond. + - both branches must return a single tensor + - returned tensor must have the same tensor metadata, e.g. shape and dtype + - branch function can be free function, nested function, lambda, class methods + - branch function can not have closure variables + - no inplace mutations on inputs or global variables + + This example demonstrates using class method in cond(). + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def __init__(self): + super().__init__() + self.subm = MySubModule() + + def bar(self, x): + return x.sin() + + def forward(self, x): + return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x]) + + model = CondBranchClassMethod() + input = torch.randn(5) + onnx_program = torch.onnx.dynamo_export(model, input) + if IS_MAIN: + onnx_program.save("conditional.onnx") + + def test_list_as_empty_input(self): + class GraphModule(torch.nn.Module): + def forward(self, arg0_1): + view = torch.ops.aten.view.default(arg0_1, []) + return (view,) + + model = GraphModule() + input = torch.randn(1) + onnx_program = torch.onnx.dynamo_export(model, input) + if IS_MAIN: + onnx_program.save("list_as_empty_input.onnx") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/longformer.py b/tests/models/longformer.py similarity index 100% rename from tests/longformer.py rename to tests/models/longformer.py diff --git a/tests/longformer_export.py b/tests/models/longformer_export.py similarity index 100% rename from tests/longformer_export.py rename to tests/models/longformer_export.py diff --git a/tests/resnet.py b/tests/models/resnet.py similarity index 100% rename from tests/resnet.py rename to tests/models/resnet.py diff --git a/tests/resnet_export_test.py b/tests/models/resnet_export_test.py similarity index 100% rename from tests/resnet_export_test.py rename to tests/models/resnet_export_test.py diff --git a/tests/rotary_embedding.py b/tests/models/rotary_embedding.py similarity index 100% rename from tests/rotary_embedding.py rename to tests/models/rotary_embedding.py