burn-import: add some tests for ConstantNode #2623
base: main
Conversation
```rust
let device = Default::default();
let model = constant_f64::Model::<Backend>::new(&device);
let input = TensorData::zeros::<f64, _>(Shape::from([2, 3, 4]));
let expected_output = TensorData::full(Shape::from([2, 3, 4]), 2f32);
```
I'm not sure if the addition is coercing f64 -> f32 somewhere (and i32 -> i64 below). I wasn't sure how to get PyTorch to just forward the constant by itself, so these tests add the constant to the input.
Maybe by having the output return the constant only? But a simple constant addition works too.
In case you're curious, you could also manually define the ONNX graph like the ConstOfShape script. PyTorch doesn't always have a 1-to-1 correspondence for ops, so in such cases it can be easier to define the graph manually.
> I'm not sure if the addition is coercing f64 -> f32 somewhere (and i32 -> i64 below)

The floating point and integer data types are defined by the backend used. A model is not as statically defined as an ONNX graph. If you look at the other tests, the input(s) and output(s) are created using the `Tensor` methods, not from `TensorData`.
Thanks for the extra context, sorry it's taken me a bit to get back to this PR! I've rebased on latest `main` and updated these tests to use `Tensor` instead of `TensorData`. If I understand correctly now, the Burn backend used in these tests (`type Backend = burn_ndarray::NdArray<f32>`) uses fixed element types f32 and i64, so these types will be expected in the final output data even if the ONNX model is using e.g. f64 or i32.
> sorry it's taken me a bit to get back to this PR
No worries 🙂
> the Burn backend used in these tests (`type Backend = burn_ndarray::NdArray<f32>`) uses fixed element types f32 and i64 - so these types will be expected in the final output data even if the ONNX model is using e.g. f64 or i32.
That's correct.
We actually support different floating point types now with `tensor.cast(dtype)` (so even if a backend uses `f32` as the default floating point type, tensors can be cast to another precision). But for ONNX import, we currently don't take the precision dtype into account.
Codecov Report

```
@@            Coverage Diff             @@
##             main     #2623    +/-   ##
==========================================
+ Coverage   83.67%    83.75%   +0.08%
==========================================
  Files         832       832
  Lines      109524    109919     +395
==========================================
+ Hits        91643     92066     +423
+ Misses      17881     17853      -28
```
This PR has been marked as stale because it has not been updated for over a month.
Force-pushed 756fd3c to a1f87b8
This reverts commit b6631ba.
Force-pushed a1f87b8 to 7291b0a
Pull Request Template
Checklist
- The `run-checks all` script has been executed.
Related Issues/PRs
I've been trying to implement the OneHot ONNX op (#1714) in a WIP draft branch. The ONNX model ends up containing a constant integer vector `values=[0, 1]` used by the OneHot op; this vector was causing issues when trying to test the model. I looked at `ConstantNode`, and these are the tests I could get working so far while investigating.

Issues for `ConstantNode`:
- #2624 - constant tensors aren't populated with values
- #2625 - generated code for const int tensors doesn't compile
Changes
- `ConstantNode::tensor_ty_tokens` for tests, but this PR otherwise shouldn't change how `ConstantNode` currently works.
Testing
- Ran the added tests.
- Checked that the .onnx models contain the expected scalar constants using Netron.
Screenshots