Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

burn-import: add some tests for ConstantNode #2623

Open
wants to merge 17 commits into
base: main
Choose a base branch
from

Conversation

jameshiew
Copy link

@jameshiew jameshiew commented Dec 17, 2024

Pull Request Template

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

I've been trying to implement OneHot ONNX op (#1714) in a WIP draft branch. The ONNX model ends up containing a constant integer vector values=[0, 1] used by the OneHot op, this vector was causing issues when trying to test the model. I looked at ConstantNode and these are the tests so far I could get working while investigating.

Issues for ConstantNode
#2624 - constant tensors aren't populated with values
#2625 - generated code for const int tensors doesn't compile

Changes

  • added a helper method ConstantNode::tensor_ty_tokens for tests, but this PR otherwise shouldn't be changing how ConstantNode currently works
  • add codegen tests for ConstantNode (i32/64 + f32/64 scalar, tensors)
  • add ONNX model tests for i32/64 + f32/64 scalars - implicitly testing by adding the constant

Testing

Ran added tests

cargo xtask check all
cargo nextest run --manifest-path crates/burn-import/Cargo.toml
cargo nextest run --manifest-path crates/burn-import/onnx-tests/Cargo.toml

I checked the .onnx models contain the expected scalar constants using Netron

Screenshots f32 f64 i32 i64

let device = Default::default();
let model = constant_f64::Model::<Backend>::new(&device);
let input = TensorData::zeros::<f64, _>(Shape::from([2, 3, 4]));
let expected_output = TensorData::full(Shape::from([2, 3, 4]), 2f32);
Copy link
Author

@jameshiew jameshiew Dec 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if the addition is coercing f64 -> f32 somewhere (and i32 -> i64 below). I wasn't sure how to get PyTorch to just forward the constant by itself so these tests are adding the constant to the input

@jameshiew jameshiew marked this pull request as ready for review December 17, 2024 19:16
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant