
Reorg for converters elu and selu (FX Converter Refactor [7/N]) <Target: converter_reorg_proto> #1903

Merged: 1 commit, Jun 12, 2023
21 changes: 11 additions & 10 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -1040,11 +1040,14 @@ def acc_ops_elu(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
alpha = kwargs["alpha"]
operation_type = trt.ActivationType.ELU
return activation.convert_activation(
network, target, SourceIR.ACC, name, operation_type, input_val, alpha=alpha

return activation.elu(
network,
target,
SourceIR.ACC,
name,
kwargs["input"],
kwargs["alpha"],
)


@@ -1056,15 +1059,13 @@ def acc_ops_selu(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
operation_type = trt.ActivationType.SELU
return activation.convert_activation(

return activation.selu(
network,
target,
SourceIR.ACC,
name,
operation_type,
input_val,
kwargs["input"],
)


27 changes: 27 additions & 0 deletions py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -170,6 +170,33 @@ def aten_ops_div(
)


@tensorrt_converter(torch.ops.aten.elu.default)
def aten_ops_elu(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:

if len(args) > 2:
return activation.selu(
network,
target,
SourceIR.ATEN,
name,
args[0],
)
Comment on lines +182 to +189

Reviewer (Contributor): What is the purpose of this if logic?

Author (Collaborator): The aten trace for both selu and elu comes in as torch.ops.aten.elu.default, but with different args: elu has args[0] and args[1] for alpha, while selu has only args[0]. The check is just a way to differentiate between the two.

return activation.elu(
network,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)
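As context for the review exchange above, here is a minimal sketch (not part of this PR; the constants are the standard SELU values assumed from the PyTorch docs) checking the identity that is consistent with selu and elu sharing a single aten op in the trace: SELU is a scaled ELU with fixed alpha and scale.

import torch
import torch.nn.functional as F

# Standard SELU constants (assumed values, matching torch.nn.SELU docs).
SELU_ALPHA = 1.6732632423543772
SELU_SCALE = 1.0507009873554805

x = torch.randn(4, 8)
# SELU(x) == scale * ELU(x, alpha), which is consistent with selu() showing up
# in the aten trace as torch.ops.aten.elu.default with extra arguments.
assert torch.allclose(F.selu(x), SELU_SCALE * F.elu(x, alpha=SELU_ALPHA), atol=1e-6)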


@tensorrt_converter(torch.ops.aten.floor_divide.default)
def aten_ops_floor_div(
network: TRTNetwork,
48 changes: 48 additions & 0 deletions py/torch_tensorrt/fx/converters/impl/activation.py
@@ -202,3 +202,51 @@ def leaky_relu_dyn_range_fn(dyn_range):
alpha,
dyn_range_fn=leaky_relu_dyn_range_fn,
)


def elu(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
alpha: Optional[Any],
):
operation_type = trt.ActivationType.ELU

def elu_dyn_range_fn(dyn_range):
return (torch.nn.ELU(dyn_range[0]), torch.nn.ELU(dyn_range[1]))

return convert_activation(
network,
target,
source_ir,
name,
operation_type,
input_val,
alpha,
dyn_range_fn=elu_dyn_range_fn,
)


def selu(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
):
operation_type = trt.ActivationType.SELU

def selu_dyn_range_fn(dyn_range):
return (torch.nn.SELU(dyn_range[0]), torch.nn.SELU(dyn_range[1]))

return convert_activation(
network,
target,
source_ir,
name,
operation_type,
input_val,
dyn_range_fn=selu_dyn_range_fn,
)
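For intuition only, a rough sketch of what a dynamic-range mapping for a monotonic activation such as SELU could compute (an assumption about the intent of dyn_range_fn, not the helper used above): apply the activation to the endpoints of the input range.

import torch
import torch.nn.functional as F

def selu_output_range(dyn_range):
    # SELU is monotonically increasing, so the output dynamic range is the
    # activation applied to the endpoints of the input dynamic range.
    low, high = dyn_range
    return (
        F.selu(torch.tensor(low)).item(),
        F.selu(torch.tensor(high)).item(),
    )

print(selu_output_range((-3.0, 3.0)))  # maps (-3.0, 3.0) to roughly (-1.67, 3.15)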
31 changes: 31 additions & 0 deletions py/torch_tensorrt/fx/converters/nn_ops_converters.py
Expand Up @@ -82,3 +82,34 @@ def leaky_relu(network, submod, args, kwargs, layer_name):
input_val=kwargs["input"],
alpha=kwargs["negative_slope"],
)


@tensorrt_converter(torch.nn.functional.elu)
@tensorrt_converter(torch.nn.modules.activation.ELU)
def elu(network, submod, args, kwargs, layer_name):
# args/kwargs should have already been normalized to kwargs
assert len(args) == 0

return activation.elu(
network=network,
target="torch.nn.functional.elu",
source_ir=SourceIR.NN,
name=layer_name,
input_val=kwargs["input"],
)


@tensorrt_converter(torch.nn.functional.selu)
@tensorrt_converter(torch.nn.modules.activation.SELU)
def selu(network, submod, args, kwargs, layer_name):
# args/kwargs should have already been normalized to kwargs
assert len(args) == 0

return activation.selu(
network=network,
target="torch.nn.functional.selu",
source_ir=SourceIR.NN,
name=layer_name,
input_val=kwargs["input"],
alpha=kwargs["alpha"],
)
51 changes: 51 additions & 0 deletions py/torch_tensorrt/fx/test/converters/aten_op/test_elu_aten.py
@@ -0,0 +1,51 @@
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec


class TestELUConverter(DispatchTestCase):
def test_elu(self):
class TestModule(nn.Module):
def forward(self, x):
return nn.functional.elu(x)

inputs = [torch.randn(1, 10)]
self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default})

def test_elu_with_dynamic_shape(self):
class TestModule(nn.Module):
def forward(self, x):
return nn.functional.elu(x)

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
),
]
self.run_test_with_dynamic_shape(
TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
)

def test_elu_with_dynamic_shape_four_dimensions(self):
class TestModule(nn.Module):
def forward(self, x):
return nn.functional.elu(x)

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
),
]

self.run_test_with_dynamic_shape(
TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
)


if __name__ == "__main__":
run_tests()
51 changes: 51 additions & 0 deletions py/torch_tensorrt/fx/test/converters/aten_op/test_selu_aten.py
@@ -0,0 +1,51 @@
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec


class TestSeLUConverter(DispatchTestCase):
def test_selu(self):
class TestModule(nn.Module):
def forward(self, x):
return nn.functional.selu(x)

inputs = [torch.randn(1, 10)]
self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default})

def test_selu_with_dynamic_shape(self):
class TestModule(nn.Module):
def forward(self, x):
return nn.functional.selu(x)

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
),
]
self.run_test_with_dynamic_shape(
TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
)

def test_selu_with_dynamic_shape_four_dimensions(self):
class TestModule(nn.Module):
def forward(self, x):
return nn.functional.selu(x)

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
),
]

self.run_test_with_dynamic_shape(
TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
)


if __name__ == "__main__":
run_tests()