-
Notifications
You must be signed in to change notification settings - Fork 364
support argmax converter #2291
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
support argmax converter #2291
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/argmax.py 2023-09-05 22:31:02.244529+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/argmax.py 2023-09-05 22:33:23.441716+00:00
@@ -23,18 +23,15 @@
dim: int = 0,
keep_dim: bool = False,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
raise RuntimeError(
- f"argmax received input {input} that is not part "
- "of the TensorRT region!"
+ f"argmax received input {input} that is not part " "of the TensorRT region!"
)
if dim < 0:
dim = len(tuple(input.shape)) + dim
reduce_mask = 1 << dim
topk_layer = network.add_topk(input, trt.TopKOperation.MAX, 1, reduce_mask)
set_layer_name(topk_layer, target, name)
return topk_layer.get_output(1)
-
-
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/converters/test_argmax.py 2023-09-05 22:31:02.264529+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/converters/test_argmax.py 2023-09-05 22:33:26.764451+00:00
@@ -2,33 +2,23 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from harness import DispatchTestCase
+
class TestArgmaxConverter(DispatchTestCase):
- @parameterized.expand(
- [
- ("dim_0_keep_dim_false", (3, 4), 0, False)
- ]
- )
-
+ @parameterized.expand([("dim_0_keep_dim_false", (3, 4), 0, False)])
def test_argmax(self, _, input_shape, dim, keep_dim):
class ArgMax(nn.Module):
def __init__(self):
super().__init__()
- def forward(self, input):
+ def forward(self, input):
return torch.argmax(input, dim, keep_dim)
-
input = [torch.randn(*input_shape)]
- self.run_test(
- ArgMax(),
- input,
- expected_ops={torch.ops.aten.argmax.default}
- )
+ self.run_test(ArgMax(), input, expected_ops={torch.ops.aten.argmax.default})
+
if __name__ == "__main__":
- run_tests()
-
-
+ run_tests()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
9ca9577
to
0047b3d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
source_ir: Optional[SourceIR], | ||
name: str, | ||
input: TRTTensor, | ||
dim: int = 0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Switch to dim: Optional[int] = None
since this is the default dim
, as per the documentation. Alternatively, if this converter cannot support reducing over all dimensions, you can add a capability_validator
to the converter to disallow inputs where the dim
is not specified or non-integral.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I used dim: Union[int, None]
, is that ok?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
9e62066
to
1f76a5c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
Hey @gs-olive I will be OOO next week. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
1f76a5c
to
ffe53e0
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good overall, but I left some comments about using our new APIs and small fixes.
Signed-off-by: Bo Wang <bowa@nvidia.com>
0bf93c6
to
60c576d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix! I found a small bug here. Other looks good to me!
- Added regression test
60c576d
to
668f897
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Signed-off-by: Bo Wang <bowa@nvidia.com> Co-authored-by: gs-olive <113141689+gs-olive@users.noreply.github.com>
Description
Support argmax converter
Checklist: