From b0a5ac6e88897a80efdbe3dc965c5ff21b14ef99 Mon Sep 17 00:00:00 2001 From: JC Date: Thu, 4 Jul 2024 11:06:33 -0400 Subject: [PATCH] Add case to match lhs as tensor and rhs as scalar --- crates/onnx-ir/src/dim_inference.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index 3a0ef2cb84..d901257791 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -484,6 +484,7 @@ fn slice_update_outputs(node: &mut Node) { fn sub_update_outputs(node: &mut Node) { node.outputs[0].ty = match (node.inputs[0].ty.clone(), node.inputs[1].ty.clone()) { (ArgType::Scalar(_lhs), ArgType::Tensor(rhs)) => ArgType::Tensor(rhs), + (ArgType::Tensor(lhs), ArgType::Scalar(_rhs)) => ArgType::Tensor(lhs), (ArgType::Tensor(lhs), ArgType::Tensor(rhs)) => { // Support broadcasting for lhs/rhs if lhs.dim > rhs.dim {