diff --git a/crates/burn-import/DEVELOPMENT.md b/crates/burn-import/DEVELOPMENT.md index 687428fe1bf..2d122a042bc 100644 --- a/crates/burn-import/DEVELOPMENT.md +++ b/crates/burn-import/DEVELOPMENT.md @@ -62,7 +62,7 @@ To extend `burn-import` with support for new ONNX operators, follow these steps: the Burn model in Rust code, and `my-model.json` includes the model data. 7. **Add End-to-End Test**: Include the test in `./burn-import/onnx-tests/tests/onnx_tests.rs`. - Further details can be found in the [onnx-tests README](./burn-import/onnx-tests/README.md). + Further details can be found in the [onnx-tests README](./onnx-tests/README.md). ## Testing diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index e072bef10a1..7aa250ce4e4 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -149,7 +149,7 @@ mod tests { let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]], &device); let scalar = 3.0f64; let output = model.forward(input, scalar); - let expected = TensorData::from([[[[6f32, 7., 8., 9.]]]]); + let expected = TensorData::from([[[[-12f32, -13., -14., -15.]]]]); output.to_data().assert_eq(&expected, true); } @@ -164,7 +164,7 @@ mod tests { let input = Tensor::::from_ints([[[[1, 2, 3, 4]]]], &device); let scalar = 3; let output = model.forward(input, scalar); - let expected = TensorData::from([[[[6i64, 6, 6, 6]]]]); + let expected = TensorData::from([[[[-12i64, -12, -12, -12]]]]); output.to_data().assert_eq(&expected, true); } diff --git a/crates/burn-import/onnx-tests/tests/sub/sub.onnx b/crates/burn-import/onnx-tests/tests/sub/sub.onnx index 7ffdfc8083c..60d76ea1c55 100644 Binary files a/crates/burn-import/onnx-tests/tests/sub/sub.onnx and b/crates/burn-import/onnx-tests/tests/sub/sub.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/sub/sub.py b/crates/burn-import/onnx-tests/tests/sub/sub.py index f71cf4a018d..2169592b903 100755 --- a/crates/burn-import/onnx-tests/tests/sub/sub.py +++ b/crates/burn-import/onnx-tests/tests/sub/sub.py @@ -26,6 +26,9 @@ def forward(self, x, k): # Sutract a scalar from a tensor x = x - d + # Sutract a tensor from a scalar + x = d - x + return x @@ -40,8 +43,9 @@ def main(): scalar = 3.0 - torch.onnx.export(model, (dummy_input, scalar), onnx_name, - verbose=False, opset_version=16) + torch.onnx.export( + model, (dummy_input, scalar), onnx_name, verbose=False, opset_version=16 + ) print("Finished exporting model to {}".format(onnx_name)) @@ -53,5 +57,5 @@ def main(): print("Test output data: {}".format(output)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/crates/burn-import/onnx-tests/tests/sub/sub_int.onnx b/crates/burn-import/onnx-tests/tests/sub/sub_int.onnx index 55309cea189..73a4ace795e 100644 Binary files a/crates/burn-import/onnx-tests/tests/sub/sub_int.onnx and b/crates/burn-import/onnx-tests/tests/sub/sub_int.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/sub/sub_int.py b/crates/burn-import/onnx-tests/tests/sub/sub_int.py index 17ace09cad6..487c66b19b4 100755 --- a/crates/burn-import/onnx-tests/tests/sub/sub_int.py +++ b/crates/burn-import/onnx-tests/tests/sub/sub_int.py @@ -27,6 +27,9 @@ def forward(self, x, k): # Sutract a scalar from a tensor x = x - d + # Sutract a tensor from a scalar + x = d - x + return x @@ -41,8 +44,9 @@ def main(): test_input = torch.tensor([[[[1, 2, 3, 4]]]], device=device) scalar = 3 - torch.onnx.export(model, (test_input, scalar), onnx_name, - verbose=False, opset_version=16) + torch.onnx.export( + model, (test_input, scalar), onnx_name, verbose=False, opset_version=16 + ) print("Finished exporting model to {}".format(onnx_name)) @@ -51,5 +55,5 @@ def main(): print("Test output data: {}".format(output)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/crates/burn-import/src/burn/node/binary.rs b/crates/burn-import/src/burn/node/binary.rs index da6b7b9303e..4aed4c37df9 100644 --- a/crates/burn-import/src/burn/node/binary.rs +++ b/crates/burn-import/src/burn/node/binary.rs @@ -131,6 +131,9 @@ impl BinaryNode { (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) }, (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) }, (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs }, + (Type::Scalar(_), Type::Tensor(_)) => { + move |lhs, rhs| quote! { #rhs.mul_scalar(-1).add_scalar(#lhs) } + } _ => panic!("Subtraction is supported for tensor and scalar only"), }; diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index e39ef73bdf1..23a1619b475 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -69,7 +69,7 @@ pub fn dim_inference(node: &mut Node) { NodeType::Slice => slice_update_outputs(node), NodeType::Softmax => same_as_input(node), NodeType::Sqrt => same_as_input(node), - NodeType::Sub => same_as_input(node), + NodeType::Sub => sub_update_outputs(node), NodeType::Sum => same_as_input(node), NodeType::Tanh => same_as_input(node), NodeType::Transpose => same_as_input(node), @@ -481,6 +481,24 @@ fn slice_update_outputs(node: &mut Node) { } } +fn sub_update_outputs(node: &mut Node) { + node.outputs[0].ty = match (node.inputs[0].ty.clone(), node.inputs[1].ty.clone()) { + (ArgType::Scalar(_lhs), ArgType::Tensor(rhs)) => ArgType::Tensor(rhs), + (ArgType::Tensor(lhs), ArgType::Tensor(rhs)) => { + // Support broadcasting for lhs/rhs + if lhs.dim > rhs.dim { + ArgType::Tensor(lhs) + } else { + ArgType::Tensor(rhs) + } + } + _ => { + println!("{0:?},{1:?}", node.inputs, node.inputs); + panic!("Only tensor-scalar inputs are valid."); + } + }; +} + /// Update the output tensor dimension based on the "axes" attribute or the second input fn unsqueeze_update_output(node: &mut Node) { let axes = if node.inputs.len() == 2 {