diff --git a/onnxruntime/core/providers/cuda/tensor/where.cc b/onnxruntime/core/providers/cuda/tensor/where.cc index cfe4ce0696efa..142dc9294414e 100644 --- a/onnxruntime/core/providers/cuda/tensor/where.cc +++ b/onnxruntime/core/providers/cuda/tensor/where.cc @@ -9,16 +9,16 @@ namespace onnxruntime { namespace cuda { // kernel builder functions -#define WHERE_TYPED_KERNEL_WITH_TYPE_NAME(T, TName) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Where, \ - kOnnxDomain, \ - 9, \ - TName, \ - kCudaExecutionProvider, \ - KernelDefBuilder() \ - .TypeConstraint("B", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +#define WHERE_TYPED_KERNEL_WITH_TYPE_NAME(T, TName) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Where, \ + kOnnxDomain, \ + 9, \ + TName, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("B", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ Where); // Compute where operator output shape based upon three way broad-casting. @@ -44,6 +44,10 @@ Status ComputeOutputShape(const std::string& node_name, const TensorShape& cond_ y_dim = y_shape[y_rank - 1 - i]; int64_t out_dim = std::max(std::max(cond_dim, x_dim), y_dim); + // special case to handle a dim of 0 which can be broadcast with a 1 + if (out_dim == 1) + out_dim = std::min(std::min(cond_dim, x_dim), y_dim); + if (cond_dim != out_dim && cond_dim != 1) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": condition operand cannot broadcast on dim ", cond_rank - 1 - i, " Condition Shape: ", cond_shape.ToString(), ", X Shape: ", x_shape.ToString(), ", Y Shape: ", y_shape.ToString()); @@ -150,6 +154,9 @@ Status Where::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), condition_shape, X_shape, Y_shape, output_shape)); auto output_tensor = context->Output(0, output_shape); + if (output_shape.Size() == 0) + return Status::OK(); + TernaryElementwisePreparation prepare(this, condition, X, Y); ORT_RETURN_IF_ERROR(prepare.TernaryElementwiseBroadcastPrepareHelper(condition_shape, X_shape, Y_shape, output_shape)); ORT_RETURN_IF_ERROR(prepare.CopyToGpu()); @@ -169,11 +176,11 @@ Status Where::ComputeInternal(OpKernelContext* context) const { return Status::OK(); } -#define SPECIALIZED_COMPUTE_WITH_NAME(T, TName) \ - WHERE_TYPED_KERNEL_WITH_TYPE_NAME(T, TName) \ +#define SPECIALIZED_COMPUTE_WITH_NAME(T, TName) \ + WHERE_TYPED_KERNEL_WITH_TYPE_NAME(T, TName) \ template Status Where::ComputeInternal(OpKernelContext* context) const; -#define SPECIALIZED_COMPUTE(T) \ +#define SPECIALIZED_COMPUTE(T) \ SPECIALIZED_COMPUTE_WITH_NAME(T, T) SPECIALIZED_COMPUTE(int32_t) diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 4764e8fa0a8d1..f344b6e114bc2 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -11,12 +11,13 @@ namespace onnxruntime { namespace test { -TEST(BroadcastingTest, DimWithZeroHandling) { +TEST(MathOpTest, DimWithZeroHandling) { auto run = [](OpTester& tester) { // exclude NGraph and TensorRT as this isn't handled by those EPs tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kNGraphExecutionProvider}); }; + // test binary element-wise op broadcasting when there's a dim with value of zero // equal rank OpTester test("Add"); test.AddInput("A", {3, 1}, {1, 2, 3}); @@ -51,6 +52,12 @@ TEST(BroadcastingTest, DimWithZeroHandling) { test4.AddInput("B", {2, 1}, {1, 2}); test4.AddOutput("C", {2, 2, 0}, {}); run(test4); + + // test unary op handles it as well + OpTester test5("Floor"); + test5.AddInput("A", {0, 3}, {}); + test5.AddOutput("B", {0, 3}, {}); + run(test5); } TEST(MathOpTest, Add_int32) { diff --git a/onnxruntime/test/providers/cpu/tensor/where_op_test.cc b/onnxruntime/test/providers/cpu/tensor/where_op_test.cc index 5da7b3db325c5..e953761df0a9c 100644 --- a/onnxruntime/test/providers/cpu/tensor/where_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/where_op_test.cc @@ -109,5 +109,18 @@ TEST(WhereOpTest, Broadcast) { WhereBroadcastTest("true", "false"); } +TEST(WhereOpTest, BroadcastDimWithZero) { + // test where broadcast is possible, and dim of 0 should be selected + OpTester test{kOpName, kOpVersion}; + + test.AddInput("condition", {3}, {true, false, true}); + test.AddInput("X", {1, 3}, {1, 2, 3}); + test.AddInput("Y", {0, 1}, {}); + + test.AddOutput("output", {0, 3}, {}); + + // exclude NGraph as this isn't handled by that EP + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kNGraphExecutionProvider}); +} } // namespace test } // namespace onnxruntime