Skip to content

Commit 1d732a5

Browse files
committed
support neg converter
1 parent aa1c843 commit 1d732a5

File tree

3 files changed

+80
-0
lines changed

3 files changed

+80
-0
lines changed

Diff for: py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+18
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,24 @@ def aten_ops_rsqrt(
251251
)
252252

253253

254+
@dynamo_tensorrt_converter(torch.ops.aten.neg.default)
255+
def aten_ops_neg(
256+
network: TRTNetwork,
257+
target: Target,
258+
args: Tuple[Argument, ...],
259+
kwargs: Dict[str, Argument],
260+
name: str,
261+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
262+
263+
return impl.unary.neg(
264+
network,
265+
target,
266+
SourceIR.ATEN,
267+
name,
268+
args[0],
269+
)
270+
271+
254272
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim)
255273
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims)
256274
def aten_ops_squeeze(

Diff for: py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

+12
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,15 @@ def sign(
9696
double_floor_div_output,
9797
1,
9898
)
99+
100+
101+
def neg(
102+
network: TRTNetwork,
103+
target: Target,
104+
source_ir: Optional[SourceIR],
105+
name: str,
106+
input_val: TRTTensor,
107+
) -> TRTTensor:
108+
return convert_unary(
109+
network, target, source_ir, name, trt.UnaryOperation.NEG, input_val
110+
)

Diff for: tests/py/dynamo/converters/test_neg_aten.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt.dynamo.test_utils import DispatchTestCase
6+
from torch_tensorrt import Input
7+
8+
9+
class TestNegConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
("2d_dim_dtype_float", (2, 2), torch.float),
13+
("3d_dim_dtype_float", (2, 2, 2), torch.float),
14+
15+
]
16+
)
17+
def test_neg_float(self, _, x, type):
18+
class neg(nn.Module):
19+
def forward(self, input):
20+
return torch.neg(input)
21+
22+
inputs = [torch.randn(x, dtype=type)]
23+
self.run_test(
24+
neg(),
25+
inputs,
26+
expected_ops={torch.ops.aten.neg.default},
27+
)
28+
29+
@parameterized.expand(
30+
[
31+
("2d_dim_dtype_int", (2, 2), torch.int32, 0, 5),
32+
("3d_dim_dtype_int", (2, 2, 2), torch.int32, 0, 5),
33+
]
34+
)
35+
36+
def test_neg_int(self, _, x, type, min, max):
37+
class neg(nn.Module):
38+
def forward(self, input):
39+
return torch.neg(input)
40+
41+
inputs = [torch.randint(min, max, (x), dtype=type)]
42+
43+
self.run_test(
44+
neg(),
45+
inputs,
46+
expected_ops={torch.ops.aten.neg.default},
47+
)
48+
49+
if __name__ == "__main__":
50+
run_tests()

0 commit comments

Comments
 (0)