Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add subtract tensor from scalar for ONNX sub op #1964

Conversation

johnhuichen
Copy link
Contributor

@johnhuichen johnhuichen commented Jul 4, 2024

Pull Request Template

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Help Wanted: Implementing ONNX Ops

Changes

Previous sub and sub_int do not handle a scalar subtracting a tensor. This is a problem when implementing ONNX ops like pad.

Added support for sub and sub_int to handle that scenario

Testing

cargo xtask run-checks all

crates/burn-import/onnx-tests

cargo test 

crates/burn-import

cargo test

@johnhuichen johnhuichen force-pushed the add-subtract-tensor-from-scalar-for-onnx-sub branch 2 times, most recently from 785c5ed to cb23e98 Compare July 4, 2024 14:36
@johnhuichen johnhuichen marked this pull request as ready for review July 4, 2024 14:37
@johnhuichen johnhuichen force-pushed the add-subtract-tensor-from-scalar-for-onnx-sub branch 2 times, most recently from b0a5ac6 to cb6c1d3 Compare July 4, 2024 15:12
Copy link

codecov bot commented Jul 4, 2024

Codecov Report

Attention: Patch coverage is 76.92308% with 3 lines in your changes missing coverage. Please review.

Project coverage is 85.30%. Comparing base (1ad2a63) to head (cede35c).

Files Patch % Lines
crates/onnx-ir/src/dim_inference.rs 70.00% 3 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1964   +/-   ##
=======================================
  Coverage   85.29%   85.30%           
=======================================
  Files         798      798           
  Lines       95512    95522   +10     
=======================================
+ Hits        81471    81481   +10     
  Misses      14041    14041           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@antimora antimora requested review from laggui and antimora July 4, 2024 18:25
Copy link
Collaborator

@antimora antimora left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much for fixing the bug.

It looks good overall. I have minor suggestion to imporove.

@@ -131,6 +131,9 @@ impl BinaryNode {
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) },
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) },
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs },
(Type::Scalar(_), Type::Tensor(_)) => {
move |lhs, rhs| quote! { #rhs.mul_scalar(-1).add_scalar(#lhs) }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be more efficient if we have one tensor op and rely on compiler to negate #lhs. We can rewrite as follows: #rhs.add_scalar(-(#lhs)). So generated code might look like this: #rhs.add_scalar(-(- 42)). And Rust compiler will precompute number literal correctly.

Copy link
Contributor Author

@johnhuichen johnhuichen Jul 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks or review @antimora!

Since this is lhs (scalar) - rhs (tensor), looks like -#rhs.sub_scalar(#lhs) produced the right result. I will make an update.

Copy link
Member

@laggui laggui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other than the already requested change, LGTM!

Thanks for fixing :)

Make sure sub op is more efficient by using one operator
@johnhuichen johnhuichen force-pushed the add-subtract-tensor-from-scalar-for-onnx-sub branch from cb6c1d3 to af2138d Compare July 4, 2024 19:31
Copy link
Collaborator

@antimora antimora left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your fix! LGTM

@@ -131,6 +131,7 @@ impl BinaryNode {
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) },
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) },
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs },
(Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { -#rhs.sub_scalar(#lhs) },
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

FYI @nathanielsimard , @laggui , @louisfd

Another reason to support native Scalar type in Burn. See our earlier design discussion: #1689 (comment)

{ -#rhs.sub_scalar(#lhs) }, solution to scalar - tensor will result in two operations instead of one.

@antimora
Copy link
Collaborator

antimora commented Jul 5, 2024

I'll merge to the main once the CI passes again.

@antimora antimora merged commit fe0544b into tracel-ai:main Jul 5, 2024
14 checks passed
@johnhuichen johnhuichen deleted the add-subtract-tensor-from-scalar-for-onnx-sub branch July 7, 2024 07:49
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants