burn-import: add some tests for ConstantNode #2623

Open: jameshiew wants to merge 22 commits into main from onnx-constant-tests

Conversation


@jameshiew commented on Dec 17, 2024

Pull Request Template

Checklist

  • Confirmed that the `run-checks all` script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

I've been trying to implement the OneHot ONNX op (#1714) in a WIP draft branch. The ONNX model ends up containing a constant integer vector `values=[0, 1]` used by the OneHot op, and this vector was causing issues when testing the model. I looked at ConstantNode, and these are the tests I could get working while investigating.

Issues for ConstantNode
  • #2624 - constant tensors aren't populated with values
  • #2625 - generated code for const int tensors doesn't compile

Changes

  • added a helper method `ConstantNode::tensor_ty_tokens` for tests; otherwise this PR shouldn't change how ConstantNode currently works
  • added codegen tests for ConstantNode (i32/i64 and f32/f64 scalars, plus tensors)
  • added ONNX model tests for i32/i64 and f32/f64 scalars, tested implicitly by adding the constant to the input (see the sketch after this list)
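
For illustration, a minimal sketch of what one of these scalar model tests boils down to in its final, Tensor-based form (the `constant_f64` module name and the `Backend` alias come from this thread; the test name and assertion style are assumptions based on burn's usual test conventions):

```rust
use burn::tensor::Tensor;

// Backend used by the onnx-tests, per the discussion below.
type Backend = burn_ndarray::NdArray<f32>;

#[test]
fn add_constant_f64() {
    let device = Default::default();
    // `constant_f64` is the model module generated by burn-import
    // from the test .onnx file.
    let model = constant_f64::Model::<Backend>::new(&device);
    // The ONNX graph adds a scalar constant (2.0) to the input, so a
    // zero input should come back filled with the constant.
    let input = Tensor::<Backend, 3>::zeros([2, 3, 4], &device);
    let output = model.forward(input);
    let expected = Tensor::<Backend, 3>::full([2, 3, 4], 2.0, &device);
    output.to_data().assert_eq(&expected.to_data(), true);
}
```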

Testing

Ran the added tests:

```sh
cargo xtask check all
cargo nextest run --manifest-path crates/burn-import/Cargo.toml
cargo nextest run --manifest-path crates/burn-import/onnx-tests/Cargo.toml
```

I checked that the .onnx models contain the expected scalar constants using Netron.

Screenshots (Netron views of the f32, f64, i32, and i64 models; images omitted)

Review thread on this test snippet:

```rust
let device = Default::default();
let model = constant_f64::Model::<Backend>::new(&device);
let input = TensorData::zeros::<f64, _>(Shape::from([2, 3, 4]));
let expected_output = TensorData::full(Shape::from([2, 3, 4]), 2f32);
```
@jameshiew (Author) commented on Dec 17, 2024:

I'm not sure if the addition is coercing f64 -> f32 somewhere (and i32 -> i64 below). I wasn't sure how to get PyTorch to just forward the constant by itself, so these tests add the constant to the input.

Member replied:

Maybe by having the output return the constant only? But a simple constant addition works too.

In case you're curious, you could also manually define the ONNX graph like the ConstOfShape script does. PyTorch doesn't always have a 1-to-1 correspondence for ops, so in such cases it can be easier to define the graph manually.

> I'm not sure if the addition is coercing f64 -> f32 somewhere (and i32 -> i64 below)

The floating point and integer data types are defined by the backend used; a model is not statically defined the way an ONNX graph is. If you look at the other tests, the input(s) and output(s) are created using the Tensor methods, not from TensorData.
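
As a minimal sketch of that distinction (assuming burn's current APIs), the TensorData constructor pins the element type at the call site, while the Tensor constructor defers it to the backend:

```rust
use burn::tensor::{Shape, Tensor, TensorData};
use burn_ndarray::NdArray;

type Backend = NdArray<f32>;

fn build_inputs() {
    let device = Default::default();
    // TensorData pins the element type (always f64 here), regardless of
    // which backend the model runs on.
    let _data = TensorData::zeros::<f64, _>(Shape::from([2, 3, 4]));
    // The Tensor API uses the backend's element types (f32 for
    // NdArray<f32>), matching what the model will actually produce.
    let _tensor = Tensor::<Backend, 3>::zeros([2, 3, 4], &device);
}
```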

@jameshiew (Author) replied:

Thanks for the extra context, and sorry it's taken me a bit to get back to this PR! I've rebased on latest main and updated these tests to use Tensor instead of TensorData. If I understand correctly now: the Burn backend used in these tests (`type Backend = burn_ndarray::NdArray<f32>`) uses fixed element types f32 and i64, so these types will be expected in the final output data even if the ONNX model uses e.g. f64 or i32.

Member replied:

> sorry it's taken me a bit to get back to this PR

No worries 🙂

> the Burn backend used in these tests (`type Backend = burn_ndarray::NdArray<f32>`) uses fixed element types f32 and i64 - so these types will be expected in the final output data even if the ONNX model is using e.g. f64 or i32.

That's correct.

We actually support different floating point types now with `tensor.cast(dtype)` (so even if a backend uses f32 as the default floating point type, tensors can be cast to another precision). But for ONNX import, we currently don't take the precision dtype into account.
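
A minimal sketch of that cast (the `cast` call itself is from the comment above; the `FloatDType` enum name and variants are assumptions about burn's API):

```rust
use burn::tensor::{FloatDType, Tensor};
use burn_ndarray::{NdArray, NdArrayDevice};

type Backend = NdArray<f32>;

fn cast_example(device: &NdArrayDevice) {
    // Created with the backend's default float type (f32)...
    let x = Tensor::<Backend, 2>::ones([2, 2], device);
    // ...then cast to another floating point precision.
    let _x_f64 = x.cast(FloatDType::F64);
}
```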

@jameshiew marked this pull request as ready for review on December 17, 2024 at 19:16

codecov bot commented Jan 2, 2025

Codecov Report

Attention: Patch coverage is 99.74684% with 1 line in your changes missing coverage. Please review.

Project coverage is 83.75%. Comparing base (6015823) to head (7291b0a).
Report is 2 commits behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| crates/burn-import/src/burn/node/constant.rs | 99.71% | 1 Missing ⚠️ |
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main    #2623      +/-   ##
==========================================
+ Coverage   83.67%   83.75%   +0.08%
==========================================
  Files         832      832
  Lines      109524   109919     +395
==========================================
+ Hits        91643    92066     +423
+ Misses      17881    17853      -28
```



github-actions bot commented Feb 2, 2025

This PR has been marked as stale because it has not been updated for over a month.

github-actions bot added the `stale` label on Feb 2, 2025
@jameshiew force-pushed the onnx-constant-tests branch from 756fd3c to a1f87b8 on February 9, 2025
@jameshiew force-pushed the onnx-constant-tests branch from a1f87b8 to 7291b0a on February 9, 2025
github-actions bot removed the `stale` label on Feb 9, 2025