
Commit 1696cd2

narendasan (Naren Dasan) authored and Naren Dasan committed
refactor: Reorganizing to reduce code duplication and separating the TRT implementation, with example changes to ReLU
Signed-off-by: Naren Dasan <naren@narendasan.com>
1 parent c5cc6e3 commit 1696cd2

File tree

7 files changed (+217 -104 lines changed)

py/torch_tensorrt/fx/converters/acc_ops_converters.py  (+59 -17)

@@ -26,6 +26,7 @@
     trt_transposed_matmul,
 )
 from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
+from torch_tensorrt.fx.converters.impl import activation

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -1004,9 +1005,14 @@ def acc_ops_relu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    operation_type = trt.ActivationType.RELU
-    return add_activation_layer(network, input_val, operation_type, target, name)
+
+    return activation.relu(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        kwargs["input"],
+    )


 @tensorrt_converter(acc_ops.leaky_relu)
@@ -1020,8 +1026,14 @@ def acc_ops_leaky_relu(
     input_val = kwargs["input"]
     negative_slope = kwargs["negative_slope"]
     operation_type = trt.ActivationType.LEAKY_RELU
-    return add_activation_layer(
-        network, input_val, operation_type, target, name, negative_slope
+    return activation.convert_activation(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        operation_type,
+        input_val,
+        alpha=negative_slope,
     )


@@ -1036,7 +1048,9 @@ def acc_ops_elu(
     input_val = kwargs["input"]
     alpha = kwargs["alpha"]
     operation_type = trt.ActivationType.ELU
-    return add_activation_layer(network, input_val, operation_type, target, name, alpha)
+    return activation.convert_activation(
+        network, target, SourceIR.ACC, name, operation_type, input_val, alpha=alpha
+    )


 @tensorrt_converter(acc_ops.selu)
@@ -1049,7 +1063,14 @@ def acc_ops_selu(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     input_val = kwargs["input"]
     operation_type = trt.ActivationType.SELU
-    return add_activation_layer(network, input_val, operation_type, target, name)
+    return activation.convert_activation(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        operation_type,
+        input_val,
+    )


 @tensorrt_converter(acc_ops.softsign)
@@ -1062,7 +1083,14 @@ def acc_ops_softsign(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     input_val = kwargs["input"]
     operation_type = trt.ActivationType.SOFTSIGN
-    return add_activation_layer(network, input_val, operation_type, target, name)
+    return activation.convert_activation(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        operation_type,
+        input_val,
+    )


 @tensorrt_converter(acc_ops.sin)
@@ -1140,7 +1168,14 @@ def acc_ops_tanh(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     input_val = kwargs["input"]
     operation_type = trt.ActivationType.TANH
-    return add_activation_layer(network, input_val, operation_type, target, name)
+    return activation.convert_activation(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        operation_type,
+        input_val,
+    )


 @tensorrt_converter(acc_ops.asin)
@@ -3137,12 +3172,13 @@ def acc_ops_hard_sigmoid(
             "of the TensorRT region!"
         )

-    return add_activation_layer(
+    return activation.convert_activation(
         network,
-        input_val,
-        trt.ActivationType.HARD_SIGMOID,
         target,
+        SourceIR.ACC,
         name,
+        trt.ActivationType.HARD_SIGMOID,
+        input_val,
         alpha=1 / 6,
         beta=0.5,
     )
@@ -3164,8 +3200,13 @@ def acc_ops_sigmoid(
             "of the TensorRT region!"
         )

-    return add_activation_layer(
-        network, input_val, trt.ActivationType.SIGMOID, target, name
+    return activation.convert_activation(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        trt.ActivationType.SIGMOID,
+        input_val,
     )


@@ -3557,12 +3598,13 @@ def acc_ops_hardtanh(
             "of the TensorRT region!"
         )

-    return add_activation_layer(
+    return activation.convert_activation(
         network,
-        input_val,
-        trt.ActivationType.CLIP,
         target,
+        SourceIR.ACC,
         name,
+        trt.ActivationType.CLIP,
+        input_val,
         alpha=kwargs["min_val"],
         beta=kwargs["max_val"],
     )
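All of the activation converters in this file now delegate to the shared helpers in torch_tensorrt.fx.converters.impl.activation instead of the removed add_activation_layer. As a hedged illustration of the resulting call pattern only (acc_ops_softplus and its registration below are hypothetical and not part of this commit; they reuse the imports already present in this file), a new activation converter would reduce to:

# Hypothetical sketch, not part of this commit: shows the shared call
# pattern that the ACC activation converters above now follow.
@tensorrt_converter(acc_ops.softplus)  # assumes such an acc op exists
def acc_ops_softplus(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # Delegate validation, alpha/beta handling, and layer naming to impl.
    return activation.convert_activation(
        network,
        target,
        SourceIR.ACC,
        name,
        trt.ActivationType.SOFTPLUS,
        kwargs["input"],
    )

The converter itself only names the TensorRT activation type and the source IR; everything else lives in the impl module.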

py/torch_tensorrt/fx/converters/activation.py  (-39)

@@ -9,45 +9,6 @@
 from .converter_utils import mark_as_int8_layer


-def common_activation(
-    network, mod, input_val, activation_type, activation_dyn_range_fn, layer_name
-):
-    layer = network.add_activation(input=input_val, type=activation_type)
-    layer.name = layer_name
-
-    if input_val.dynamic_range:
-        dyn_range = activation_dyn_range_fn(input_val.dynamic_range)
-        mark_as_int8_layer(layer, dyn_range)
-
-    return layer.get_output(0)
-
-
-@tensorrt_converter(torch.nn.functional.relu)
-@tensorrt_converter(torch.nn.modules.activation.ReLU)
-def relu(network, submod, args, kwargs, layer_name):
-    # args/kwargs should have already been normalized to kwargs
-    assert len(args) == 0
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"ReLU received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    def activation_dyn_range_fn(dyn_range):
-        return max(0, dyn_range[0]), max(0, dyn_range[1])
-
-    return common_activation(
-        network,
-        submod,
-        input_val,
-        trt.ActivationType.RELU,
-        activation_dyn_range_fn,
-        layer_name,
-    )
-
-
 @tensorrt_converter(torch.nn.modules.activation.Sigmoid)
 def sigmoid(network, submod, args, kwargs, layer_name):
     # args/kwargs should have already been normalized to kwargs

py/torch_tensorrt/fx/converters/aten_ops_converters.py  (+9 -4)

@@ -22,6 +22,7 @@

 from .converter_utils import *  # noqa: F403
 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
+from torch_tensorrt.fx.converters.impl import activation

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -290,10 +291,14 @@ def aten_ops_relu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-    }
-    return acc_ops_converters.acc_ops_relu(network, target, None, kwargs_new, name)
+
+    return activation.relu(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )


 @tensorrt_converter(torch.ops.aten.sub.Tensor)

py/torch_tensorrt/fx/converters/converter_utils.py  (+33 -44)

@@ -2,6 +2,7 @@
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

+from enum import Enum, auto
 import numpy as np

 # @manual=//deeplearning/trt/python:py_tensorrt
@@ -22,6 +23,26 @@
 from ..utils import torch_dtype_from_trt


+class SourceIR(Enum):
+    NN = auto()
+    ACC = auto()
+    ATEN = auto()
+    PRIM = auto()
+    UNKNOWN = auto()
+
+    def __str__(self):
+        if self == SourceIR.NN:
+            return "nn"
+        elif self == SourceIR.ACC:
+            return "acc"
+        elif self == SourceIR.ATEN:
+            return "aten"
+        elif self == SourceIR.PRIM:
+            return "prim"
+        else:
+            return "unknown_ir"
+
+
 def get_trt_plugin(
     plugin_name: str,
     field_collection: List[TRTPluginFieldCollection],
@@ -77,7 +98,9 @@ def get_positive_dim(dim: int, dim_size: int) -> int:
     return dim


-def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None:
+def set_layer_name(
+    layer: TRTLayer, target: Target, name: str, source_ir: Optional[SourceIR] = None
+) -> None:
     """
     Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]"

@@ -86,8 +109,16 @@ def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None:
         target (Target): A fx node.target. For call_function node, it's the function that
             the node represents.
         name (str): Consists of fx node.name with optional suffix.
+        source_ir: (Optional[SourceIR]): The IR producing the op.
     """
-    target_name = target if isinstance(target, str) else f"acc_ops.{target.__name__}"
+
+    source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
+
+    target_name = (
+        f"{source_ir}_ops.{target}"
+        if isinstance(target, str)
+        else f"{source_ir}_ops.{target.__name__}"
+    )
     layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]"
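Because set_layer_name now takes a SourceIR, layer names encode which frontend produced the op instead of hard-coding acc_ops. A small illustrative sketch of the expected naming (the activation layer object and the node name "relu_1" are made up for the example):

# Illustrative sketch of the new naming behavior; `layer` is assumed to be
# a TensorRT activation layer created for an fx node named "relu_1".
set_layer_name(layer, acc_ops.relu, "relu_1", SourceIR.ACC)
assert layer.name == "[ACTIVATION]-[acc_ops.relu]-[relu_1]"

# Without a source_ir argument the prefix falls back to "unknown_ir":
set_layer_name(layer, acc_ops.relu, "relu_1")
assert layer.name == "[ACTIVATION]-[unknown_ir_ops.relu]-[relu_1]"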

@@ -560,48 +591,6 @@ def add_unary_layer(
     return layer.get_output(0)


-def add_activation_layer(
-    network: TRTNetwork,
-    input_val: TRTTensor,
-    operation_type: trt.ActivationType,
-    target: Target,
-    name: str,
-    alpha: Optional[Any] = None,
-    beta: Optional[Any] = None,
-) -> TRTTensor:
-    """
-    Add a TensorRT Activation layer to `network`.
-
-    Args:
-        network (TRTNetwork): TensorRT network object.
-        input_val (TRTTensor): Input to the activation op.
-            Must be a TensorRT tensor.
-        op_type (trt.ElementWiseOperation): Type of the TensorRT activation
-            operation.
-        target (Target): Target of fx node.
-        name (str): The name we want to assign to the created TensorRT layer.
-        alpha (Optional[Any]): If not None, we will use it to set the alpha
-            attribute of the created TensorRT activation layer.
-        beta (Optional[Any]): If not None, we will use it to set the beta
-            attribute of the created TensorRT activation layer.
-
-    Returns:
-        The output of TensorRT Activation layer.
-    """
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"{operation_type} received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-    layer = network.add_activation(input_val, operation_type)
-    if alpha is not None:
-        layer.alpha = alpha
-    if beta is not None:
-        layer.beta = beta
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)
-
-
 def add_reduce_layer(
     network: TRTNetwork,
     target: Target,

py/torch_tensorrt/fx/converters/impl/__init__.py

Whitespace-only changes.
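
The new impl package is where the TRT-facing activation implementation now lives; its activation module (presumably py/torch_tensorrt/fx/converters/impl/activation.py, one of the changed files not shown in this excerpt) is what the converters above call. The following is only a minimal sketch of what its entry points might look like, inferred from those call sites; the exact signatures, the relative import path, the dyn_range_fn parameter, and the int8 dynamic-range handling are assumptions:

# Minimal sketch, inferred from the call sites in this diff; not the
# actual contents of impl/activation.py.
from typing import Any, Callable, Optional, Tuple

import tensorrt as trt
from torch.fx.node import Target

from ..converter_utils import (  # assumed relative import path
    SourceIR,
    TRTNetwork,
    TRTTensor,
    mark_as_int8_layer,
    set_layer_name,
)


def convert_activation(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    operation_type: trt.ActivationType,
    input_val: TRTTensor,
    alpha: Optional[Any] = None,
    beta: Optional[Any] = None,
    dyn_range_fn: Optional[Callable[[Tuple[float, float]], Tuple[float, float]]] = None,
) -> TRTTensor:
    """Add a TensorRT Activation layer; shared by all source IRs."""
    if not isinstance(input_val, TRTTensor):
        raise RuntimeError(
            f"{operation_type} received input {input_val} that is not part "
            "of the TensorRT region!"
        )
    layer = network.add_activation(input_val, operation_type)
    if alpha is not None:
        layer.alpha = alpha
    if beta is not None:
        layer.beta = beta
    set_layer_name(layer, target, name, source_ir)

    # Assumed: preserve the old int8 dynamic-range propagation when available.
    if dyn_range_fn is not None and input_val.dynamic_range is not None:
        mark_as_int8_layer(layer, dyn_range_fn(input_val.dynamic_range))

    return layer.get_output(0)


def relu(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
) -> TRTTensor:
    """Thin wrapper used by the acc and aten ReLU converters."""
    return convert_activation(
        network,
        target,
        source_ir,
        name,
        trt.ActivationType.RELU,
        input_val,
        dyn_range_fn=lambda dr: (max(0, dr[0]), max(0, dr[1])),
    )

Under this assumption, acc_ops_relu and aten_ops_relu above collapse into the same implementation, which is the code-duplication reduction the commit message describes.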
