Commit 3ec137b

narendasan authored and gs-olive committed
refactor: Centralizing sigmoid implementation
Signed-off-by: Naren Dasan <naren@narendasan.com>
1 parent 12f545c commit 3ec137b

7 files changed: +114 -48 lines changed

Diff for: py/torch_tensorrt/fx/converters/__init__.py

-1
@@ -2,7 +2,6 @@
 import tensorrt as trt
 
 if hasattr(trt, "__version__"):
-    from .activation import *  # noqa: F401 F403
     from .adaptive_avgpool import *  # noqa: F401 F403
     from .add import *  # noqa: F401 F403
     from .batchnorm import *  # noqa: F401 F403

Diff for: py/torch_tensorrt/fx/converters/acc_ops_converters.py

+2-10
@@ -3192,21 +3192,13 @@ def acc_ops_sigmoid(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
 
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Sigmoid received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    return activation.convert_activation(
+    return activation.sigmoid(
         network,
         target,
         SourceIR.ACC,
         name,
-        trt.ActivationType.SIGMOID,
-        input_val,
+        kwargs["input"],
     )
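
The isinstance check and RuntimeError that each converter used to carry now sit behind the shared helper. For orientation, a minimal sketch of what convert_activation in impl/activation.py plausibly looks like, reconstructed from the deleted per-converter code and the call sites in this commit; treat the body as an assumption, not the repo's exact code:

# Hedged reconstruction, not the repo's exact code: the shape of
# impl/activation.py's convert_activation, inferred from its call sites.
def convert_activation(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    operation_type: trt.ActivationType,
    input_val: TRTTensor,
    dyn_range_fn=None,
) -> TRTTensor:
    # Validation that previously lived in every converter (e.g. acc_ops_sigmoid).
    if not isinstance(input_val, TRTTensor):
        raise RuntimeError(
            f"{operation_type} received input {input_val} that is not part "
            "of the TensorRT region!"
        )
    layer = network.add_activation(input_val, operation_type)
    set_layer_name(layer, target, name)
    # For int8, map the input's (min, max) range through the activation
    # so the output tensor also carries a dynamic range.
    if dyn_range_fn is not None and input_val.dynamic_range is not None:
        layer.get_output(0).dynamic_range = dyn_range_fn(input_val.dynamic_range)
    return layer.get_output(0)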

Diff for: py/torch_tensorrt/fx/converters/activation.py

-37
This file was deleted.

Diff for: py/torch_tensorrt/fx/converters/aten_ops_converters.py

+18
@@ -484,3 +484,21 @@ def aten_ops_sym_size(
     )
     set_layer_name(slice_layer, target, "_slice_layer")
     return slice_layer.get_output(0)
+
+
+@tensorrt_converter(torch.ops.aten.sigmoid.default)
+def aten_ops_sigmoid(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    return activation.sigmoid(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
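
Note the frontend difference visible in these two diffs: acc converters pull the tensor from kwargs["input"], while aten converters take it positionally as args[0]; both tag the call with a SourceIR value so the shared impl knows which frontend produced it. A minimal sketch of such an enum, assuming only the members this commit exercises (the real one in the converter utilities may carry more):

# Hedged sketch of the SourceIR tag; the repo's enum may define more members.
from enum import Enum, auto

class SourceIR(Enum):
    NN = auto()    # torch.nn module converters (nn_ops_converters.py)
    ACC = auto()   # acc-tracer converters (acc_ops_converters.py)
    ATEN = auto()  # aten converters (aten_ops_converters.py)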

Diff for: py/torch_tensorrt/fx/converters/impl/activation.py

+27
@@ -90,3 +90,30 @@ def relu_dyn_range_fn(dyn_range):
         input_val,
         dyn_range_fn=relu_dyn_range_fn,
     )
+
+
+def sigmoid(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+):
+    operation_type = trt.ActivationType.SIGMOID
+
+    def sigmoid_dyn_range_fn(dyn_range):
+        def sigmoid_fn(x):
+            # TODO: Can this just call torch.nn.functional.sigmoid?
+            return 1 / (1 + np.exp(-x))
+
+        return sigmoid_fn(dyn_range[0]), sigmoid_fn(dyn_range[1])
+
+    return convert_activation(
+        network,
+        target,
+        source_ir,
+        name,
+        operation_type,
+        input_val,
+        dyn_range_fn=sigmoid_dyn_range_fn,
+    )
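
Because sigmoid is monotonically increasing, mapping the two endpoints of the input's (min, max) dynamic range through it yields the output's range, which is all sigmoid_dyn_range_fn does. On the inline TODO: the numpy closed form does agree with PyTorch's sigmoid; a quick standalone check, not part of the commit:

# Standalone sanity check (not part of the commit): the numpy closed form
# used for the dynamic range agrees with PyTorch's sigmoid.
import numpy as np
import torch

def sigmoid_fn(x):
    return 1 / (1 + np.exp(-x))

x = np.linspace(-6.0, 6.0, num=13)
assert np.allclose(sigmoid_fn(x), torch.sigmoid(torch.from_numpy(x)).numpy())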

Diff for: py/torch_tensorrt/fx/converters/nn_ops_converters.py

+14
@@ -22,3 +22,17 @@ def relu(network, submod, args, kwargs, layer_name):
         name=layer_name,
         input_val=kwargs["input"],
     )
+
+
+@tensorrt_converter(torch.nn.modules.activation.Sigmoid)
+def sigmoid(network, submod, args, kwargs, layer_name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+
+    activation.sigmoid(
+        network=network,
+        target="torch.nn.modules.activation.Sigmoid",
+        source_ir=SourceIR.NN,
+        name=layer_name,
+        input_val=kwargs["input"],
+    )
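
With this change, all three frontends (acc, aten, and nn-module) register thin wrappers that dispatch to the single activation.sigmoid implementation. A minimal sketch of the registration pattern behind the tensorrt_converter decorator, assuming a plain dict keyed by the op or module type; the real registry lives in torch_tensorrt.fx and may differ:

# Hedged sketch of the registration pattern behind tensorrt_converter;
# the real registry in torch_tensorrt.fx may differ.
from typing import Any, Callable, Dict

CONVERTERS: Dict[Any, Callable] = {}

def tensorrt_converter(key: Any) -> Callable:
    def register(converter: Callable) -> Callable:
        # Conversion looks the traced call target up in this table.
        CONVERTERS[key] = converter
        return converter
    return register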
Diff for: py/torch_tensorrt/fx/test/converters/aten_op/test_sigmoid_aten.py

+53
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSigmoidConverter(DispatchTestCase):
+    def test_sigmoid(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.sigmoid(x)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten.sigmoid.default}
+        )
+
+    def test_sigmoid_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.sigmoid(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.sigmoid.default}
+        )
+
+    def test_sigmoid_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.sigmoid(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.sigmoid.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
