
Commit a65c95c

feat: support amax dynamo converter (#2241)
1 parent b774440 commit a65c95c
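
Not part of the commit, but for orientation: a minimal sketch of how the new converter path would be exercised end to end. The ir="dynamo" compile call, shapes, and device placement are illustrative assumptions, not taken from this change.

import torch
import torch_tensorrt  # assumption: a Torch-TensorRT build that ships the dynamo frontend

class AmaxModule(torch.nn.Module):
    def forward(self, x):
        # torch.amax lowers to aten.amax.default, the op this commit adds a converter for.
        return torch.amax(x, dim=1, keepdim=True)

model = AmaxModule().eval().cuda()
inputs = [torch.randn(2, 3, 4).cuda()]
# Illustrative compile call; exact options depend on the installed version.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(inputs[0]))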

File tree

5 files changed: +167 -0 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+31
@@ -440,6 +440,37 @@ def aten_ops_expand(
     )
 
 
+def amax_param_validator(amax_node: Node) -> bool:
+    if len(amax_node.args) < 2:
+        _LOGGER.debug(
+            f"At least two args input and dim should be provided, but only got {len(amax_node.args)} args."
+        )
+        return False
+
+    return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.amax.default, capability_validator=amax_param_validator
+)
+def aten_ops_amax(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.reduce.amax(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args_bounds_check(args, 2, replacement=False),
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.exp.default)  # type: ignore[misc]
 def aten_ops_exp(
     network: TRTNetwork,
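
The converter reads the optional keepdim argument through args_bounds_check. A minimal sketch of what such a helper is assumed to do (the real helper lives in the dynamo conversion utilities, not in this diff):

def args_bounds_check(args, i, replacement=None):
    # Return args[i] if that position exists and is not None, otherwise the default.
    return args[i] if len(args) > i and args[i] is not None else replacement

# aten.amax called as (input, dim) has no third arg, so keepdim falls back to False.
assert args_bounds_check(("input", [1]), 2, replacement=False) is False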

py/torch_tensorrt/dynamo/conversion/converter_utils.py

+7
@@ -1,3 +1,4 @@
+import functools
 import logging
 import re
 from typing import List, Optional
@@ -7,6 +8,7 @@
 from torch.fx.node import Target
 from torch_tensorrt.fx.converters.converter_utils import (
     Frameworks,
+    get_axes_for_reduce_op,
     unified_dtype_converter,
 )
 from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
@@ -157,3 +159,8 @@ def broadcastable(
         if not (a_shape[i] == b_shape[i] or a_shape[i] == 1 or b_shape[i] == 1):
             return False
     return True
+
+
+get_axes_for_reduce_op = functools.partial(
+    get_axes_for_reduce_op, has_implicit_batch_dimension=False
+)
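
Re-binding get_axes_for_reduce_op with has_implicit_batch_dimension=False fits the dynamo path, which builds explicit-batch TensorRT networks; the helper turns the reduced dim(s) into the axes bitmask that a reduce layer expects. A rough sketch of the assumed mapping (not the fx implementation itself):

def axes_bitmask(dims):
    # One bit per reduced dimension: dim=1 -> 0b010, dims=[0, 2] -> 0b101.
    if isinstance(dims, int):
        dims = (dims,)
    axes = 0
    for d in dims:
        axes |= 1 << d
    return axes

assert axes_bitmask(1) == 0b010
assert axes_bitmask([0, 2]) == 0b101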

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

+1
@@ -9,6 +9,7 @@
     matmul,
     normalization,
     permutation,
+    reduce,
     select,
     shape,
     slice,
py/torch_tensorrt/dynamo/conversion/impl/reduce.py

+35
@@ -0,0 +1,35 @@
+from typing import Optional, Tuple, Union
+
+import tensorrt as trt
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    cast_trt_tensor,
+    get_axes_for_reduce_op,
+)
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+
+def amax(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    dim: Union[int, Tuple[int]],
+    keepdim: bool = False,
+) -> TRTTensor:
+    if (isinstance(input_val, TRTTensor)) and (
+        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
+    ):
+        input_val = cast_trt_tensor(network, input_val, trt.float32, name)
+
+    layer = network.add_reduce(
+        input_val,
+        trt.ReduceOperation.MAX,
+        axes=get_axes_for_reduce_op(dim),
+        keep_dims=keepdim,
+    )
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
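
For reference, the values the reduce layer is expected to reproduce are plain torch.amax semantics: the elementwise maximum over the given dimension(s), with keepdim controlling whether reduced axes are retained. A quick eager-mode check (the claim that the TensorRT output matches these values is an assumption here, verified by the tests below):

import torch

x = torch.randn(2, 3, 4)
# Single-dim reduction with keepdim, as in the int-dim tests below.
assert torch.equal(torch.amax(x, dim=1, keepdim=True), x.max(dim=1, keepdim=True).values)
# Multi-dim reduction, as in the tuple-dim tests below.
assert torch.equal(torch.amax(x, dim=(0, 2)), x.amax(dim=0).amax(dim=-1))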
(new file: amax converter tests)

+93
@@ -0,0 +1,93 @@
+import torch
+import torch.nn as nn
+from harness import DispatchTestCase
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+
+class TestAmaxConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((3, 2, 4), 1, True),
+            ((2, 3, 4, 5), 3, True),
+            ((2, 3, 4, 5), 2, False),
+            ((6, 7, 5, 4, 5), 4, False),
+        ]
+    )
+    def test_amax_dim_int_default(self, input_shape, dim, keep_dims):
+        class Amax(nn.Module):
+            def forward(self, x):
+                return torch.amax(x, dim=dim, keepdim=keep_dims)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            Amax(),
+            inputs,
+            expected_ops={torch.ops.aten.amax.default},
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4), [1], True),
+            ((2, 1, 4, 5), [0, 3], True),
+            ((2, 3, 4, 5), [0, 1, 2, 3], False),
+            ((6, 7, 5, 4, 5), [1, 3, 4], False),
+        ]
+    )
+    def test_amax_dim_tuple_default(self, input_shape, dim, keep_dims):
+        class Amax(nn.Module):
+            def forward(self, x):
+                return torch.amax(x, dim=dim, keepdim=keep_dims)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            Amax(),
+            inputs,
+            expected_ops={torch.ops.aten.amax.default},
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4), 1, True, torch.int, 0, 5),
+            ((2, 3, 4, 5), 3, True, torch.int, -10, 10),
+            ((2, 3, 4, 5), 2, False, torch.int32, -5, 0),
+            ((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5),
+        ]
+    )
+    def test_amax_dim_int_int(self, input_shape, dim, keep_dims, dtype, low, high):
+        class Amax(nn.Module):
+            def forward(self, x):
+                return torch.amax(x, dim=dim, keepdim=keep_dims)
+
+        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
+        self.run_test(
+            Amax(),
+            inputs,
+            expected_ops={torch.ops.aten.amax.default},
+            check_dtype=False,
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4), [1], True, torch.int, 0, 5),
+            ((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
+            ((2, 3, 4, 5), [0, 1, 2, 3], False, torch.int32, -5, 0),
+            ((6, 7, 5, 4, 5), [1, 3, 4], False, torch.int32, -5, 5),
+        ]
+    )
+    def test_amax_dim_tuple_int(self, input_shape, dim, keep_dims, dtype, low, high):
+        class Amax(nn.Module):
+            def forward(self, x):
+                return torch.amax(x, dim=dim, keepdim=keep_dims)
+
+        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
+        self.run_test(
+            Amax(),
+            inputs,
+            expected_ops={torch.ops.aten.amax.default},
+            check_dtype=False,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
