-
Notifications
You must be signed in to change notification settings - Fork 503
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
Refactor GatherNode to support scalar outputs. #2828
Conversation
Updated `GatherNode` and related code to handle scalar outputs alongside tensors. Adjusted the dimension inference logic and `codegen` functionality to accommodate this change. Added unit tests to ensure correctness of scalar output handling.
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2828 +/- ##
========================================
Coverage 82.22% 82.22%
========================================
Files 853 853
Lines 113444 113829 +385
========================================
+ Hits 93275 93595 +320
- Misses 20169 20234 +65 ☔ View full report in Codecov by Sentry. |
how to fix ci ? |
You can check the failed job: https://github.com/tracel-ai/burn/actions/runs/13404320429/job/37449962928?pr=2828
Should be an easy fix 😄 |
The `TensorType` import was not being used in the code. Removing it enhances clarity and reduces unnecessary dependencies. This is a minor cleanup for better code quality.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great start! But I think I found a couple of issues.
Also, we should add a test for the gather with a scalar 🙂 this would have caught some of the problems I mentioned below.
This is described in this section of the book. But basically, add an ONNX model onnx-tests/tests/gather/gather_scalar_out.onnx
that produces a scalar output with the gather node, add the file to the onnx-tests/build.rs
, include the model in onnx-tests/tests/test_onnx.rs
and add a test in the same file. You can check the other gather nodes as reference (e.g., gather_scalar
for scalar index).
Refactored the `Gather` node handler to ensure correct scalar handling by removing unnecessary squeezing and adding checks. Introduced the `gather_scalar_out` ONNX test, model generator, and validation for scalar outputs. Additionally, updated tensor API to enforce dimension constraints in `squeeze` operations.
Updated the formatting of an assertion error message to improve readability and maintain consistent style. This change does not affect functionality but enhances code clarity.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good job! The ONNX import code LGTM 👍
Just some comments regarding additional checks to move and a leftover print.
Will be good to go after that 🙂
This commit introduces a check to ensure the specified dimension index is within bounds for tensor operations. It removes redundant assertions and debug logs to streamline code and centralize error handling. These changes improve robustness and maintain consistency in tensor operations.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for improving the ONNX support 🙏
Updated
GatherNode
and related code to handle scalar outputs alongside tensors. Adjusted the dimension inference logic andcodegen
functionality to accommodate this change. Added unit tests to ensure correctness of scalar output handling.Pull Request Template
Checklist
run-checks all
script has been executed.Related Issues/PRs
#2781
Changes
Summarize the problem being addressed and your solution.
Testing
Describe how these changes have been tested.