-
Notifications
You must be signed in to change notification settings - Fork 462
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
Add subtract tensor from scalar for ONNX sub op #1964
Add subtract tensor from scalar for ONNX sub op #1964
Conversation
785c5ed
to
cb23e98
Compare
b0a5ac6
to
cb6c1d3
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1964 +/- ##
=======================================
Coverage 85.29% 85.30%
=======================================
Files 798 798
Lines 95512 95522 +10
=======================================
+ Hits 81471 81481 +10
Misses 14041 14041 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you so much for fixing the bug.
It looks good overall. I have minor suggestion to imporove.
@@ -131,6 +131,9 @@ impl BinaryNode { | |||
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) }, | |||
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) }, | |||
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs }, | |||
(Type::Scalar(_), Type::Tensor(_)) => { | |||
move |lhs, rhs| quote! { #rhs.mul_scalar(-1).add_scalar(#lhs) } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be more efficient if we have one tensor op and rely on compiler to negate #lhs
. We can rewrite as follows: #rhs.add_scalar(-(#lhs)). So generated code might look like this:
#rhs.add_scalar(-(- 42)). And Rust compiler will precompute number literal correctly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks or review @antimora!
Since this is lhs (scalar) - rhs (tensor), looks like -#rhs.sub_scalar(#lhs)
produced the right result. I will make an update.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Other than the already requested change, LGTM!
Thanks for fixing :)
Make sure sub op is more efficient by using one operator
cb6c1d3
to
af2138d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for your fix! LGTM
@@ -131,6 +131,7 @@ impl BinaryNode { | |||
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) }, | |||
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) }, | |||
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs }, | |||
(Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { -#rhs.sub_scalar(#lhs) }, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!
FYI @nathanielsimard , @laggui , @louisfd
Another reason to support native Scalar type in Burn. See our earlier design discussion: #1689 (comment)
{ -#rhs.sub_scalar(#lhs) },
solution to scalar - tensor
will result in two operations instead of one.
I'll merge to the main once the CI passes again. |
Pull Request Template
Checklist
run-checks all
script has been executed.Related Issues/PRs
Help Wanted: Implementing ONNX Ops
Changes
Previous sub and sub_int do not handle a scalar subtracting a tensor. This is a problem when implementing ONNX ops like pad.
Added support for sub and sub_int to handle that scenario
Testing
crates/burn-import/onnx-tests
cargo test
crates/burn-import
cargo test