Skip to content

feat: support aten.any related converters in dynamo #2578

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 1 commit into from
Jan 23, 2024

Conversation

bowang007
Copy link
Collaborator

Description

support aten.any converter

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@bowang007 bowang007 requested a review from gs-olive January 5, 2024 04:33
@bowang007 bowang007 requested a review from zewenli98 January 5, 2024 04:33
@github-actions github-actions bot added component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests labels Jan 5, 2024
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2024-01-05 04:37:44.368463+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2024-01-05 04:39:32.661484+00:00
@@ -2571,10 +2571,11 @@
        src,
        src.dtype,
        force_layer=True,
    )

+
@dynamo_tensorrt_converter(torch.ops.aten.any.default)
@dynamo_tensorrt_converter(torch.ops.aten.any.dim)
@dynamo_tensorrt_converter(torch.ops.aten.any.dims)
def aten_ops_any(
    ctx: ConversionContext,
@@ -2589,6 +2590,6 @@
        SourceIR.ATEN,
        name,
        args[0],
        args_bounds_check(args, 1, replacement=[]),
        args_bounds_check(args, 2, replacement=False),
-    )
\ No newline at end of file
+    )

Copy link
Collaborator

@narendasan narendasan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have a bool test case? Its probably a common variant of any

Copy link
Collaborator

@gs-olive gs-olive left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a few suggestions on the implementation. Also needs a rebase to resolve merge conflict.

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2024-01-19 18:47:29.428220+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2024-01-19 18:49:20.020601+00:00
@@ -2623,10 +2623,11 @@
        args[0],
        args_bounds_check(args, 1, replacement=None),
        args_bounds_check(args, 2, replacement=False),
    )

+
@dynamo_tensorrt_converter(torch.ops.aten._pdist_forward.default)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }

@bowang007 bowang007 requested a review from gs-olive January 20, 2024 00:58
Copy link
Collaborator

@gs-olive gs-olive left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me; it seems that the test failures are unrelated to the PR.

@bowang007 bowang007 merged commit 6848571 into main Jan 23, 2024
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants