Commit d6a07bb

feat: support activation dynamo converters (#2254)
Parent: d90494a
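
With these converters registered, graphs captured through the Dynamo frontend can lower the listed activations directly to TensorRT IActivationLayer ops instead of falling back to PyTorch execution. A minimal smoke test, assuming the Dynamo path is exposed through torch_tensorrt.compile(..., ir="dynamo") as in current releases (exact flags may differ at this commit):

import torch
import torch_tensorrt

# Small module exercising two of the newly supported activations.
model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.ELU(alpha=1.0),
    torch.nn.Hardsigmoid(),
).eval().cuda()

inputs = [torch.randn(8, 16).cuda()]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)  # torch.Size([8, 16])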

File tree

11 files changed: +720 −90 lines

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+128 −18
@@ -152,15 +152,15 @@ def aten_ops_fmod(
     return impl.elementwise.fmod(network, target, SourceIR.ATEN, name, args[0], args[1])


-@dynamo_tensorrt_converter(torch.ops.aten.gelu.default)  # type: ignore[misc]
-def aten_ops_gelu(
+@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
+def aten_ops_relu(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.activation.gelu(
+    return impl.activation.relu(
         network,
         target,
         SourceIR.ATEN,
@@ -169,61 +169,171 @@ def aten_ops_gelu(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.matmul)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.mm.default)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.mv.default)  # type: ignore[misc]
-def aten_ops_matmul(
+@dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default)
+def aten_ops_sigmoid(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.matmul.matrix_multiply(
+    return impl.activation.sigmoid(
         network,
         target,
         SourceIR.ATEN,
         name,
         args[0],
-        args[1],
     )


-@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default)  # type: ignore[misc]
-def aten_ops_layernorm(
+@dynamo_tensorrt_converter(torch.ops.aten.tanh.default)
+def aten_ops_tanh(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.normalization.layer_norm(
+    return impl.activation.tanh(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.leaky_relu.default)
+def aten_ops_leaky_relu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.activation.leaky_relu(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args_bounds_check(args, 1, 0.01),
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.elu.default)
+def aten_ops_elu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.activation.elu(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        alpha=args_bounds_check(args, 1, 1.0),
+        beta=args_bounds_check(args, 2, None),
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.softplus.default)
+def aten_ops_softplus(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.activation.softplus(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        beta=args_bounds_check(args, 1, 1),
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
+def aten_ops_clip(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.activation.clip(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        alpha=args_bounds_check(args, 1),
+        beta=args_bounds_check(args, 2),
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default)
+def aten_ops_hard_sigmoid(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.activation.hard_sigmoid(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        alpha=args_bounds_check(args, 1, 1 / 6),
+        beta=args_bounds_check(args, 2, 1 / 2),
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.matmul)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.mm.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.mv.default)  # type: ignore[misc]
+def aten_ops_matmul(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.matmul.matrix_multiply(
         network,
         target,
         SourceIR.ATEN,
         name,
         args[0],
         args[1],
-        args[2],
-        args[3],
-        args[4],
     )


-@dynamo_tensorrt_converter(torch.ops.aten.relu.default)  # type: ignore[misc]
-def aten_ops_relu(
+@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default)  # type: ignore[misc]
+def aten_ops_layernorm(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.activation.relu(
+    return impl.normalization.layer_norm(
         network,
         target,
         SourceIR.ATEN,
         name,
         args[0],
+        args[1],
+        args[2],
+        args[3],
+        args[4],
    )
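
Each converter above registers against a specific ATen overload via @dynamo_tensorrt_converter and reads optional positional arguments through args_bounds_check, so schema defaults still apply when the traced graph omits them (0.01 for leaky_relu's negative slope, 1/6 and 1/2 for hardsigmoid). A rough sketch of that helper's assumed behavior (the actual implementation lives in the dynamo converter utilities and may differ in detail):

from typing import Any, Sequence

def args_bounds_check(args: Sequence[Any], i: int, replacement: Any = None) -> Any:
    # Return the i-th positional argument if the traced node supplied it,
    # otherwise fall back to the given default (the ATen schema default).
    return args[i] if len(args) > i and args[i] is not None else replacement

# e.g. for torch.ops.aten.leaky_relu.default(x) traced without an explicit slope:
# args_bounds_check(args, 1, 0.01) -> 0.01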

py/torch_tensorrt/dynamo/conversion/impl/activation.py

-63
This file was deleted.
New file (+1):
@@ -0,0 +1 @@
+from .ops import *
New file (+42):
@@ -0,0 +1,42 @@
+from typing import Any, Callable, Optional
+
+import tensorrt as trt
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.fx.converters.converter_utils import (
+    mark_as_int8_layer,
+    set_layer_name,
+)
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+
+def convert_activation(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    operation_type: trt.ActivationType,
+    input_val: TRTTensor,
+    alpha: Optional[Any] = None,
+    beta: Optional[Any] = None,
+    dyn_range_fn: Optional[Callable[[float, float], Any]] = None,
+) -> TRTTensor:
+    """
+    Add a TensorRT Activation layer to `network`.
+    """
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"{operation_type} received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+    layer = network.add_activation(input_val, operation_type)
+    if alpha is not None:
+        layer.alpha = alpha
+    if beta is not None:
+        layer.beta = beta
+    set_layer_name(layer, target, name, source_ir)
+
+    if input_val.dynamic_range is not None:
+        dyn_range = dyn_range_fn(input_val.dynamic_range)
+        mark_as_int8_layer(layer, dyn_range)
+    return layer.get_output(0)
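
convert_activation is the shared base that every per-op function in the new activation package funnels through: it validates the input, adds the IActivationLayer, sets alpha/beta, names the layer, and propagates INT8 dynamic ranges via dyn_range_fn. The per-op module itself is not shown in this excerpt; a plausible wrapper for leaky_relu, under that assumption and reusing convert_activation defined above, might look like:

from typing import Optional

import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

def leaky_relu(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
    alpha: float = 0.01,
) -> TRTTensor:
    def leaky_relu_dyn_range_fn(dyn_range):
        # Apply the activation to both ends of the input's dynamic range
        # so the INT8 range of the output stays consistent.
        def f(x):
            return x if x >= 0 else x * alpha
        return f(dyn_range[0]), f(dyn_range[1])

    return convert_activation(
        network,
        target,
        source_ir,
        name,
        trt.ActivationType.LEAKY_RELU,
        input_val,
        alpha=alpha,
        dyn_range_fn=leaky_relu_dyn_range_fn,
    )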
