-
Notifications
You must be signed in to change notification settings - Fork 364
Enabling var_mean decomposition #2273
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change looks good! Consider adding aten.rsub
, aten.rsqrt
, aten.sqrt
, or any other decompositions which might be preferable/have recent PRs.
For instance, the following can just be replaced with aten.rsub
:
TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
Lines 128 to 129 in 1ad2aeb
aten.rsub.Scalar, | |
aten.rsub.Tensor, |
@@ -175,6 +175,7 @@ | |||
aten.linalg_vector_norm, | |||
aten.full, | |||
aten.repeat, | |||
aten.var_mean, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gs-olive you mentioned the rsub
. I will add that.
Regarding sqrt
and rsqrt
, sqrt is not of lowering type so I don't think it should be included here.
For rsqrt
, should I include it here? Since rsqrt
is already present in py/torch_tensorrt/dynamo/lowering/_decompositions.py
and that would take precedence over the enabled ones, right? And we would not need to add it to the disabled ops then.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is true, I agree with your points. I think the rsub
would be the main one to change; we can leave the others out, and the rsqrt
implementation in _decompositions.py
will take precedence, yes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me!
This PR enables the var_mean and rsub decomposition in
torch_enabled_decompositions