Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Refactor GatherNode to support scalar outputs. #2828

Merged
merged 5 commits into from
Feb 24, 2025

Conversation

loloxwg
Copy link
Contributor

@loloxwg loloxwg commented Feb 19, 2025

Updated GatherNode and related code to handle scalar outputs alongside tensors. Adjusted the dimension inference logic and codegen functionality to accommodate this change. Added unit tests to ensure correctness of scalar output handling.

Pull Request Template

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

#2781

Changes

Summarize the problem being addressed and your solution.

Testing

Describe how these changes have been tested.

Updated `GatherNode` and related code to handle scalar outputs alongside tensors. Adjusted the dimension inference logic and `codegen` functionality to accommodate this change. Added unit tests to ensure correctness of scalar output handling.
Copy link

codecov bot commented Feb 19, 2025

Codecov Report

Attention: Patch coverage is 84.32432% with 29 lines in your changes missing coverage. Please review.

Project coverage is 82.22%. Comparing base (e51eace) to head (f88e0fa).
Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
crates/burn-import/src/burn/node/gather.rs 85.41% 21 Missing ⚠️
crates/burn-tensor/src/tensor/api/check.rs 22.22% 7 Missing ⚠️
crates/onnx-ir/src/dim_inference.rs 94.44% 1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff            @@
##             main    #2828    +/-   ##
========================================
  Coverage   82.22%   82.22%            
========================================
  Files         853      853            
  Lines      113444   113829   +385     
========================================
+ Hits        93275    93595   +320     
- Misses      20169    20234    +65     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@loloxwg
Copy link
Contributor Author

loloxwg commented Feb 19, 2025

how to fix ci ?

@laggui
Copy link
Member

laggui commented Feb 19, 2025

how to fix ci ?

You can check the failed job: https://github.com/tracel-ai/burn/actions/runs/13404320429/job/37449962928?pr=2828

     Compiling burn-import v0.17.0 (/home/runner/work/burn/burn/crates/burn-import)
  error: unused import: `TensorType`
   --> crates/burn-import/src/burn/node/gather.rs:2:31
    |
  2 | use crate::burn::{ScalarKind, TensorType, ToTokens, Type};
    |                               ^^^^^^^^^^
    |
    = note: `-D unused-imports` implied by `-D warnings`
    = help: to override `-D warnings` add `#[allow(unused_imports)]`
  
      Checking burn-wgpu v0.17.0 (/home/runner/work/burn/burn/crates/burn-wgpu)
      Checking burn-cuda v0.17.0 (/home/runner/work/burn/burn/crates/burn-cuda)
      Checking burn-hip v0.17.0 (/home/runner/work/burn/burn/crates/burn-hip)
      Checking burn-vision v0.17.0 (/home/runner/work/burn/burn/crates/burn-vision)
      Checking guide v0.17.0 (/home/runner/work/burn/burn/examples/guide)
  error: could not compile `burn-import` (lib) due to 1 previous error
  warning: build failed, waiting for other jobs to finish...
  error: could not compile `burn-import` (lib) due to 1 previous error
  Error: Workspace lint failed
  Error: Process completed with exit code 1.

Should be an easy fix 😄

@laggui laggui self-requested a review February 19, 2025 13:54
The `TensorType` import was not being used in the code. Removing it enhances clarity and reduces unnecessary dependencies. This is a minor cleanup for better code quality.
Copy link
Member

@laggui laggui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great start! But I think I found a couple of issues.

Also, we should add a test for the gather with a scalar 🙂 this would have caught some of the problems I mentioned below.

This is described in this section of the book. But basically, add an ONNX model onnx-tests/tests/gather/gather_scalar_out.onnx that produces a scalar output with the gather node, add the file to the onnx-tests/build.rs, include the model in onnx-tests/tests/test_onnx.rs and add a test in the same file. You can check the other gather nodes as reference (e.g., gather_scalar for scalar index).

Refactored the `Gather` node handler to ensure correct scalar handling by removing unnecessary squeezing and adding checks. Introduced the `gather_scalar_out` ONNX test, model generator, and validation for scalar outputs. Additionally, updated tensor API to enforce dimension constraints in `squeeze` operations.
Updated the formatting of an assertion error message to improve readability and maintain consistent style. This change does not affect functionality but enhances code clarity.
@loloxwg loloxwg requested a review from laggui February 21, 2025 04:35
Copy link
Member

@laggui laggui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good job! The ONNX import code LGTM 👍

Just some comments regarding additional checks to move and a leftover print.

Will be good to go after that 🙂

This commit introduces a check to ensure the specified dimension index is within bounds for tensor operations. It removes redundant assertions and debug logs to streamline code and centralize error handling. These changes improve robustness and maintain consistency in tensor operations.
Copy link
Member

@laggui laggui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for improving the ONNX support 🙏

@laggui laggui merged commit bdbfa65 into tracel-ai:main Feb 24, 2025
11 checks passed
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants