
Commit 6c9832a

refactor: Centralizing sigmoid implementation (FX Converter Refactor [2/N]) <Target: converter_reorg_proto> (#1868)
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
Parent: 847ebff

7 files changed: +128 -48 lines changed


Diff for: py/torch_tensorrt/fx/converters/__init__.py (-1)

@@ -2,7 +2,6 @@
 import tensorrt as trt
 
 if hasattr(trt, "__version__"):
-    from .activation import *  # noqa: F401 F403
     from .adaptive_avgpool import *  # noqa: F401 F403
     from .add import *  # noqa: F401 F403
     from .batchnorm import *  # noqa: F401 F403

Diff for: py/torch_tensorrt/fx/converters/acc_ops_converters.py (+2 -10)

@@ -3192,21 +3192,13 @@ def acc_ops_sigmoid(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
 
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Sigmoid received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    return activation.convert_activation(
+    return activation.sigmoid(
         network,
         target,
         SourceIR.ACC,
         name,
-        trt.ActivationType.SIGMOID,
-        input_val,
+        kwargs["input"],
     )
Diff for: py/torch_tensorrt/fx/converters/activation.py (-37)

This file was deleted.

Diff for: py/torch_tensorrt/fx/converters/aten_ops_converters.py (+18)

@@ -484,3 +484,21 @@ def aten_ops_sym_size(
     )
     set_layer_name(slice_layer, target, "_slice_layer")
     return slice_layer.get_output(0)
+
+
+@tensorrt_converter(torch.ops.aten.sigmoid.default)
+def aten_ops_sigmoid(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    return activation.sigmoid(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
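
The @tensorrt_converter decorator registers aten_ops_sigmoid in the FX converter registry keyed on torch.ops.aten.sigmoid.default, so graphs lowered through the aten path dispatch sigmoid nodes to the shared implementation. As a rough illustration (the registry name CONVERTERS and its import path are assumptions based on torch_tensorrt.fx.converter_registry, which this diff does not show):

    # Hypothetical sanity check; CONVERTERS is assumed to be the dict-like
    # registry that @tensorrt_converter populates.
    import torch
    from torch_tensorrt.fx.converter_registry import CONVERTERS

    assert torch.ops.aten.sigmoid.default in CONVERTERS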

Diff for: py/torch_tensorrt/fx/converters/impl/activation.py (+27)

@@ -90,3 +90,30 @@ def relu_dyn_range_fn(dyn_range):
         input_val,
         dyn_range_fn=relu_dyn_range_fn,
     )
+
+
+def sigmoid(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+):
+    operation_type = trt.ActivationType.SIGMOID
+
+    def sigmoid_dyn_range_fn(dyn_range):
+        def sigmoid_fn(x):
+            # TODO: Can this just call torch.nn.functional.sigmoid?
+            return 1 / (1 + np.exp(-x))
+
+        return sigmoid_fn(dyn_range[0]), sigmoid_fn(dyn_range[1])
+
+    return convert_activation(
+        network,
+        target,
+        source_ir,
+        name,
+        operation_type,
+        input_val,
+        dyn_range_fn=sigmoid_dyn_range_fn,
+    )
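
The shared convert_activation helper that sigmoid delegates to is not shown in this diff. A minimal sketch of what such a helper plausibly looks like, assuming the TensorRT Python API (network.add_activation, ITensor.dynamic_range) and the repo's converter_utils helpers (set_layer_name, mark_as_int8_layer); treat it as illustration, not the actual implementation:

    def convert_activation(
        network,
        target,
        source_ir,
        name,
        operation_type,
        input_val,
        dyn_range_fn=None,
    ):
        # Validation that used to live in each per-IR converter (e.g. the
        # TRTTensor check removed from acc_ops_sigmoid above) can sit here once.
        if not isinstance(input_val, TRTTensor):
            raise RuntimeError(
                f"{operation_type} received input {input_val} that is not part "
                "of the TensorRT region!"
            )

        # One TensorRT activation layer serves relu, sigmoid, etc.
        layer = network.add_activation(input_val, operation_type)

        # For quantized inputs, propagate the dynamic range through the
        # activation. Sigmoid is monotonic, so mapping the two endpoints is
        # enough: an input range of (-6.0, 6.0) maps to roughly
        # (0.0025, 0.9975).
        if input_val.dynamic_range is not None and dyn_range_fn is not None:
            dyn_range = dyn_range_fn(input_val.dynamic_range)
            mark_as_int8_layer(layer, dyn_range)

        set_layer_name(layer, target, name, source_ir)
        return layer.get_output(0)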

Diff for: py/torch_tensorrt/fx/converters/nn_ops_converters.py (+14)

@@ -22,3 +22,17 @@ def relu(network, submod, args, kwargs, layer_name):
         name=layer_name,
         input_val=kwargs["input"],
     )
+
+
+@tensorrt_converter(torch.nn.modules.activation.Sigmoid)
+def sigmoid(network, submod, args, kwargs, layer_name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+
+    activation.sigmoid(
+        network=network,
+        target="torch.nn.modules.activation.Sigmoid",
+        source_ir=SourceIR.NN,
+        name=layer_name,
+        input_val=kwargs["input"],
+    )
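
With this commit, all three front-ends (acc, aten, nn) resolve sigmoid through the single implementation in impl/activation.py; only the argument plumbing differs. Schematically (import path inferred from the file layout above):

    from torch_tensorrt.fx.converters.impl import activation

    # acc:  activation.sigmoid(network, target, SourceIR.ACC, name, kwargs["input"])
    # aten: activation.sigmoid(network, target, SourceIR.ATEN, name, args[0])
    # nn:   activation.sigmoid(network=network,
    #                          target="torch.nn.modules.activation.Sigmoid",
    #                          source_ir=SourceIR.NN,
    #                          name=layer_name,
    #                          input_val=kwargs["input"])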
Diff for: py/torch_tensorrt/fx/test/converters/aten_op/test_sigmoid_aten.py (+67)

@@ -0,0 +1,67 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.utils import LowerPrecision
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSigmoidConverter(DispatchTestCase):
+    def test_sigmoid(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.sigmoid(x)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten.sigmoid.default}
+        )
+
+    def test_sigmoid_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.sigmoid(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.sigmoid.default}
+        )
+
+    def test_sigmoid_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.sigmoid(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.sigmoid.default}
+        )
+
+    def test_sigmoid_fp16(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.sigmoid(x)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(
+            TestModule(),
+            inputs,
+            expected_ops={torch.ops.aten.sigmoid.default},
+            precision=torch.half,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
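
One note on the tests: nn.functional.sigmoid has long been deprecated in favor of torch.sigmoid, and both trace to the same torch.ops.aten.sigmoid.default op that the tests assert on. The test modules could equivalently be written as:

    import torch
    import torch.nn as nn

    class TestModule(nn.Module):
        def forward(self, x):
            # torch.sigmoid records the same aten.sigmoid.default node when traced
            return torch.sigmoid(x)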
