
Commit 1a09463
Parent: 0527edd

[dynamo/converter] support neg converter

4 files changed: +89 −0 lines

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+23)

@@ -251,6 +251,29 @@ def aten_ops_rsqrt(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.neg.default)
+def aten_ops_neg(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    input_val = args[0]
+    if (isinstance(input_val, TRTTensor)) and (
+        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
+    ):
+        input_val = cast_trt_tensor(network, input_val, trt.float32, name)
+
+    return impl.unary.neg(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        input_val,
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim)
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims)
 def aten_ops_squeeze(
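
For context, a minimal sketch (not part of this commit) of how the new converter is reached end to end: tracing torch.neg through the dynamo frontend yields a torch.ops.aten.neg.default node, which dispatches to the aten_ops_neg converter registered above. The torch_tensorrt.compile entry point with ir="dynamo" and the CUDA device below are assumptions for illustration; the frontend name may differ in the version this commit targets.

import torch
import torch_tensorrt


class Negate(torch.nn.Module):
    def forward(self, x):
        # torch.neg lowers to torch.ops.aten.neg.default in the aten graph
        return torch.neg(x)


# Assumption: a CUDA GPU is available and the dynamo IR is selected via ir="dynamo".
model = Negate().eval().cuda()
inputs = [torch.randn(2, 2, device="cuda")]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs))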

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py (+12)

@@ -96,3 +96,15 @@ def sign(
         double_floor_div_output,
         1,
     )
+
+
+def neg(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+) -> TRTTensor:
+    return convert_unary(
+        network, target, source_ir, name, trt.UnaryOperation.NEG, input_val
+    )
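
For reference, a standalone sketch (illustrative only, not code from this repository) of what convert_unary ultimately builds for this helper: a single TensorRT unary layer configured with trt.UnaryOperation.NEG. The builder and input setup below are generic TensorRT boilerplate assumed for the example.

import tensorrt as trt

# Build a throwaway network containing just the NEG unary layer.
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (2, 2))
neg_layer = network.add_unary(x, trt.UnaryOperation.NEG)  # same op the converter emits
neg_layer.name = "neg_example"
network.mark_output(neg_layer.get_output(0))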

py/torch_tensorrt/dynamo/test_utils.py (+2)

@@ -261,6 +261,7 @@ def run_test(
         atol=1e-03,
         precision=torch.float,
         check_dtype=True,
+        output_dtypes=None,
     ):
         mod.eval()
         mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None)
@@ -272,6 +273,7 @@
         interp = TRTInterpreter(
             mod,
             Input.from_tensors(inputs),
+            output_dtypes=output_dtypes,
         )
         super().run_test(
             mod,

New file (+52)

@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.test_utils import DispatchTestCase
+from torch_tensorrt import Input
+
+
+class TestNegConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d_dim_dtype_float", (2, 2), torch.float),
+            ("3d_dim_dtype_float", (2, 2, 2), torch.float),
+            ("2d_dim_dtype_half", (2, 2), torch.half),
+            ("3d_dim_dtype_half", (2, 2, 2), torch.half),
+        ]
+    )
+    def test_neg_float(self, _, x, type):
+        class neg(nn.Module):
+            def forward(self, input):
+                return torch.neg(input)
+
+        inputs = [torch.randn(x, dtype=type)]
+        self.run_test(
+            neg(),
+            inputs,
+            precision=type,
+            expected_ops={torch.ops.aten.neg.default},
+        )
+
+    @parameterized.expand(
+        [
+            ("2d_dim_dtype_int32", (2, 2), torch.int32, 0, 5),
+            ("3d_dim_dtype_int32", (2, 2, 2), torch.int32, 0, 5),
+        ]
+    )
+    def test_neg_int(self, _, x, type, min, max):
+        class neg(nn.Module):
+            def forward(self, input):
+                return torch.neg(input)
+
+        inputs = [torch.randint(min, max, x, dtype=type)]
+        self.run_test(
+            neg(),
+            inputs,
+            output_dtypes=[torch.int32],
+            expected_ops={torch.ops.aten.neg.default},
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
