Handle [] inputs (#116)
Fixes #108
justinchuby authored Jul 25, 2024
1 parent 8d551b9 commit c34c277
Showing 8 changed files with 76 additions and 60 deletions.
7 changes: 6 additions & 1 deletion src/torch_onnx/_building.py
@@ -219,11 +219,16 @@ def _process_python_constants_and_sequences(
  if isinstance(arg, ir.Value):
      # TODO(justinchuby): Cast the ir.Value here if needed
      continue
- if isinstance(arg, Sequence) and all(isinstance(val, ir.Value) for val in arg):
+ if (
+     isinstance(arg, Sequence)
+     and len(arg) > 0
+     and all(isinstance(val, ir.Value) for val in arg)
+ ):
      # Skip the sequence of ir.Value. This is a variadic input or a Sequence input
      # NOTE: Variadic operators like Max can be called with mixed ir.Value and Python constants
      # like `Max(0, ir.Value())`
      # We need to convert the Python constants to Constant nodes
+     # NOTE: Important to check that arg is not empty because we need to treat it as list[int] or list[float]
      continue
  # if param.variadic:
  #     # FXIME: Handle variadic inputs and sequence inputs differently
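Why the new `len(arg) > 0` guard matters: `all()` over an empty sequence is vacuously True, so an empty list would previously be skipped as if it were a sequence of ir.Value instead of being converted to a constant. A minimal sketch of the pitfall, using a hypothetical stand-in class rather than the real ir.Value:

# Illustration only: FakeValue stands in for ir.Value.
class FakeValue:
    pass

arg = []
# Vacuous truth: True even though arg contains no values at all.
print(all(isinstance(val, FakeValue) for val in arg))  # True
# With the added guard, an empty list falls through and is handled
# as a list[int] / list[float] constant instead.
print(len(arg) > 0 and all(isinstance(val, FakeValue) for val in arg))  # False
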
59 changes: 0 additions & 59 deletions tests/conditional_test.py

This file was deleted.

70 changes: 70 additions & 0 deletions tests/conversion_test.py
@@ -0,0 +1,70 @@
"""Unit test for converting into ONNX format."""

import unittest
import torch
import torch_onnx
from functorch.experimental.control_flow import cond

IS_MAIN = __name__ == "__main__"

torch_onnx.patch_torch(
error_report=IS_MAIN, profile=IS_MAIN, dump_exported_program=IS_MAIN
)


class ConversionTest(unittest.TestCase):
@unittest.expectedFailure # Conditionals are not supported yet
def test_conditional(self):
class MySubModule(torch.nn.Module):
def foo(self, x):
return x.cos()

def forward(self, x):
return self.foo(x)

class CondBranchClassMethod(torch.nn.Module):
"""
The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
- both branches must take the same args, which must also match the branch args passed to cond.
- both branches must return a single tensor
- returned tensor must have the same tensor metadata, e.g. shape and dtype
- branch function can be free function, nested function, lambda, class methods
- branch function can not have closure variables
- no inplace mutations on inputs or global variables
This example demonstrates using class method in cond().
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
"""

def __init__(self):
super().__init__()
self.subm = MySubModule()

def bar(self, x):
return x.sin()

def forward(self, x):
return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])

model = CondBranchClassMethod()
input = torch.randn(5)
onnx_program = torch.onnx.dynamo_export(model, input)
if IS_MAIN:
onnx_program.save("conditional.onnx")

def test_list_as_empty_input(self):
class GraphModule(torch.nn.Module):
def forward(self, arg0_1):
view = torch.ops.aten.view.default(arg0_1, [])
return (view,)

model = GraphModule()
input = torch.randn(1)
onnx_program = torch.onnx.dynamo_export(model, input)
if IS_MAIN:
onnx_program.save("list_as_empty_input.onnx")


if __name__ == "__main__":
unittest.main()
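For context on what `test_list_as_empty_input` exercises: `aten.view` with an empty size list reshapes a one-element tensor to a 0-dim (scalar) tensor, so the exporter has to treat `[]` as a `list[int]` constant rather than an (empty) sequence of graph values. A minimal sketch in plain PyTorch, independent of the exporter:

import torch

x = torch.randn(1)
# Same call the test's GraphModule makes: reshape to a scalar tensor.
y = torch.ops.aten.view.default(x, [])
print(y.shape)  # torch.Size([])
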
5 files renamed without changes.
