Skip to content

support argmax converter #2291

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 4 commits into from
Oct 10, 2023
Merged

support argmax converter #2291

merged 4 commits into from
Oct 10, 2023

Conversation

bowang007
Copy link
Collaborator

Description

Support argmax converter

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@github-actions github-actions bot added component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests labels Sep 5, 2023
@github-actions github-actions bot requested a review from gs-olive September 5, 2023 22:31
@bowang007 bowang007 changed the title support argmax converter support argmax converter [Draft] Sep 5, 2023
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/argmax.py	2023-09-05 22:31:02.244529+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/argmax.py	2023-09-05 22:33:23.441716+00:00
@@ -23,18 +23,15 @@
    dim: int = 0,
    keep_dim: bool = False,
) -> TRTTensor:
    if not isinstance(input, TRTTensor):
        raise RuntimeError(
-            f"argmax received input {input} that is not part "
-            "of the TensorRT region!"
+            f"argmax received input {input} that is not part " "of the TensorRT region!"
        )
    if dim < 0:
        dim = len(tuple(input.shape)) + dim
    reduce_mask = 1 << dim
    topk_layer = network.add_topk(input, trt.TopKOperation.MAX, 1, reduce_mask)

    set_layer_name(topk_layer, target, name)

    return topk_layer.get_output(1)
-    
-    
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/converters/test_argmax.py	2023-09-05 22:31:02.264529+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/converters/test_argmax.py	2023-09-05 22:33:26.764451+00:00
@@ -2,33 +2,23 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from harness import DispatchTestCase

+
class TestArgmaxConverter(DispatchTestCase):
-    @parameterized.expand(
-            [
-                ("dim_0_keep_dim_false", (3, 4), 0, False)
-            ]
-    )
-
+    @parameterized.expand([("dim_0_keep_dim_false", (3, 4), 0, False)])
    def test_argmax(self, _, input_shape, dim, keep_dim):
        class ArgMax(nn.Module):
            def __init__(self):
                super().__init__()

-            def forward(self, input): 
+            def forward(self, input):
                return torch.argmax(input, dim, keep_dim)
-            

        input = [torch.randn(*input_shape)]

-        self.run_test(
-            ArgMax(),
-            input, 
-            expected_ops={torch.ops.aten.argmax.default}
-        )
+        self.run_test(ArgMax(), input, expected_ops={torch.ops.aten.argmax.default})
+

if __name__ == "__main__":
-    run_tests()  
-
-
+    run_tests()

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

@bowang007 bowang007 changed the title support argmax converter [Draft] support argmax converter Sep 22, 2023
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

@bowang007 bowang007 force-pushed the argmax_converter_dynamo branch from 9ca9577 to 0047b3d Compare September 22, 2023 04:23
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dim: int = 0,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switch to dim: Optional[int] = None since this is the default dim, as per the documentation. Alternatively, if this converter cannot support reducing over all dimensions, you can add a capability_validator to the converter to disallow inputs where the dim is not specified or non-integral.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used dim: Union[int, None], is that ok?

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

@bowang007
Copy link
Collaborator Author

bowang007 commented Oct 7, 2023

Hey @gs-olive I will be OOO next week.
I think this update covers all edge cases.
Please feel free to merge if this is good to go. Thanks!

@bowang007 bowang007 requested a review from gs-olive October 7, 2023 05:30
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

@gs-olive gs-olive force-pushed the argmax_converter_dynamo branch from 1f76a5c to ffe53e0 Compare October 10, 2023 01:23
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

@gs-olive gs-olive requested a review from apbose October 10, 2023 01:31
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link
Collaborator

@gs-olive gs-olive left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me

Copy link
Collaborator

@zewenli98 zewenli98 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good overall, but I left some comments about using our new APIs and small fixes.

@gs-olive gs-olive force-pushed the argmax_converter_dynamo branch from 0bf93c6 to 60c576d Compare October 10, 2023 19:44
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link
Collaborator

@zewenli98 zewenli98 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix! I found a small bug here. Other looks good to me!

- Added regression test
@gs-olive gs-olive force-pushed the argmax_converter_dynamo branch from 60c576d to 668f897 Compare October 10, 2023 21:12
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link
Collaborator

@zewenli98 zewenli98 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@gs-olive gs-olive merged commit f3f475b into main Oct 10, 2023
@gs-olive gs-olive deleted the argmax_converter_dynamo branch October 10, 2023 22:20
gs-olive added a commit that referenced this pull request Oct 10, 2023
Signed-off-by: Bo Wang <bowa@nvidia.com>
Co-authored-by: gs-olive <113141689+gs-olive@users.noreply.github.com>
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests priority: high
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants