Allow 0 size dimensions (dimensions containing a 0 in the list of sizes, not a rank of 0 which is valid) #391

Open
huningxin opened this issue May 23, 2023 · 21 comments


@huningxin
Contributor

In a Chromium CL review, @fdwr mentioned (thanks Dwayne!):

0 size dimensions really should just be treated as nops. e.g. memcpy(dest, src, 0) is valid and does nothing, and adding two empty tensors just returns an empty tensor. There are legitimate cases within a graph where a tensor may be temporarily sliced down to emptiness and then reconcatenated later with other data (I've come across at least one ONNX model that does this).
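As a minimal illustration of this point, here is a sketch that uses plain Python lists as stand-ins for 1-D tensors (an assumption; no ML library is involved, and `slice_1d`/`add_1d` are hypothetical helpers). Slicing down to emptiness is legal, adding two empty tensors yields an empty tensor, and concatenating the empty piece back changes nothing:

```python
def slice_1d(data, start, size):
    """Return `size` elements starting at `start`; a size of 0 yields []."""
    if start < 0 or size < 0 or start + size > len(data):
        raise ValueError("slice window out of bounds")
    return data[start:start + size]


def add_1d(a, b):
    """Elementwise add; two empty inputs simply produce an empty output."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    return [x + y for x, y in zip(a, b)]


data = [1.0, 2.0, 3.0]
empty = slice_1d(data, 1, 0)   # temporarily sliced down to emptiness
print(add_1d(empty, empty))    # [] (a no-op, like memcpy(dest, src, 0))
print(data + empty)            # [1.0, 2.0, 3.0]; concatenating empty data changes nothing
```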

From the native ML API perspective, Dwayne also mentioned:

Now, there are certain backends that are not prepared for empty dimensions (e.g. DML doesn't accept them and rejects the operator creation call), and those operators will need to be skipped inside the backend, but the front-end model builder imo should treat them validly and return 0 elements.
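A front end following this advice can treat 0 as an ordinary dimension during shape inference and leave the skip decision to the backend. The sketch below assumes NumPy-style broadcasting; `broadcast_shapes`, `element_count`, and `should_skip_backend_op` are hypothetical names, not WebNN or Chromium APIs:

```python
from math import prod


def broadcast_shapes(a, b):
    """NumPy-style broadcasting; a dimension of 0 is valid and broadcasts with 1."""
    rank = max(len(a), len(b))
    a = [1] * (rank - len(a)) + list(a)
    b = [1] * (rank - len(b)) + list(b)
    out = []
    for x, y in zip(a, b):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise ValueError(f"incompatible dimensions {x} and {y}")
    return out


def element_count(shape):
    # prod([]) == 1, so a rank-0 scalar has one element; any 0 dimension gives 0.
    return prod(shape)


def should_skip_backend_op(*shapes):
    """Hypothetical backend check: bypass operator creation if any tensor is empty."""
    return any(element_count(s) == 0 for s in shapes)
```

Note the distinction the issue title draws: a rank-0 shape `[]` has one element, while a shape containing a 0 has none.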

We may want to investigate more native ML APIs to understand the status of their support.

I am opening this issue to start tracking this, e.g. by adding TODOs in the implementation. @fdwr, feel free to share more details. Thanks!

@fdwr
Collaborator

fdwr commented May 23, 2023

References

All the major Python ML APIs handle them robustly, and they are not considered degenerate:

NumPy

import numpy

x = numpy.ones(shape=(2,0,2), dtype=numpy.float32)
y = numpy.add(x, x)
print("NumPy:")
print("value:", y)
print("shape:", y.shape)

# Prints:
# value: []
# shape: (2, 0, 2)

TensorFlow

import tensorflow as tf

x = tf.ones(shape=(2,0,2), dtype=tf.float32)
y = tf.add(x, x)
print("TensorFlow:")
print("value:", y)
print("shape:", y.shape)

# Prints:
# value: tf.Tensor([], shape=(2, 0, 2), dtype=float32)
# shape: (2, 0, 2)

PyTorch

import torch

x = torch.ones(size=(2,0,2), dtype=torch.float)
y = torch.add(x, x)
print("PyTorch:")
print("value:", y)
print("shape:", y.shape)

# Prints:
# value: tensor([], size=(2, 0, 2))
# shape: torch.Size([2, 0, 2])

ONNX / ONNX Runtime

import onnx

# Empty tensor: dims [2, 0, 2] with no values (not a scalar).
x = onnx.helper.make_tensor(
    name="value", data_type=onnx.TensorProto.FLOAT, dims=[2,0,2], vals=[]
)
print(x)

# Prints:
# dims: 2
# dims: 0
# dims: 2
# data_type: 1
# name: "value"

In ONNX Runtime, these cases are handled as nops, either directly by the EP backend (if it handles them gracefully) or by lower-level code just before the backend API call. DirectML, for example, currently rejects 0s in the dimensions, so the EP skips operator creation while still leaving the overall graph connectivity intact.

XNNPack

Allows them (see Bin Miao's test code below).

SafeTensors

The SafeTensors file format (commonly used with Stable Diffusion models for custom weights) explicitly allows 0D scalars and 0-size tensors - "Empty tensors (tensors with 1 dimension being 0) are allowed" and "0-rank Tensors (tensors with shape []) are allowed, they are merely a scalar".
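To make this concrete, here is a simplified writer/reader (not the official `safetensors` package) following the published layout: an 8-byte little-endian header size, a JSON header with `dtype`, `shape`, and `data_offsets` per tensor, then the raw data buffer. The helper names and the small `DTYPE_SIZES` table are assumptions for illustration. An empty tensor keeps its shape in the header but occupies zero bytes:

```python
import json
import struct

DTYPE_SIZES = {"F32": 4, "F16": 2}  # small subset of SafeTensors dtype names


def write_safetensors(tensors):
    """Serialize {name: (dtype, shape, raw_bytes)}: an 8-byte little-endian
    header length, a JSON header, then the concatenated tensor data."""
    header, data = {}, b""
    for name, (dtype, shape, raw) in tensors.items():
        header[name] = {"dtype": dtype, "shape": shape,
                        "data_offsets": [len(data), len(data) + len(raw)]}
        data += raw
    header_bytes = json.dumps(header).encode("utf-8")
    return struct.pack("<Q", len(header_bytes)) + header_bytes + data


def read_header(blob):
    (length,) = struct.unpack("<Q", blob[:8])
    return json.loads(blob[8:8 + length].decode("utf-8"))


def offsets_consistent(header):
    """Each tensor's byte range must equal element count * dtype size."""
    for info in header.values():
        begin, end = info["data_offsets"]
        count = 1
        for d in info["shape"]:  # shape [] leaves count at 1 (a scalar)
            count *= d
        if end - begin != count * DTYPE_SIZES[info["dtype"]]:
            return False
    return True


blob = write_safetensors({
    "scalar": ("F32", [], struct.pack("<f", 1.0)),  # rank 0: one element, 4 bytes
    "empty": ("F32", [2, 0, 2], b""),               # 0-size: no bytes in the buffer
})
header = read_header(blob)
print(header["empty"])  # {'dtype': 'F32', 'shape': [2, 0, 2], 'data_offsets': [4, 4]}
```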

CoreML / MPS / BNNS

Unknown; not evident from the documentation.

DirectML

Disallows 0 for DML_BUFFER_TENSOR_DESC::Sizes. The backend must skip the operation.

@huningxin
Contributor Author

I'm unsure what XNNPack would do if you tried to add two empty tensors (needs research).

@miaobin would volunteer to help investigate XNNPACK's support. Thanks!

@miaobin

miaobin commented Jul 1, 2023

I'm unsure what XNNPack would do if you tried to add two empty tensors (needs research).

@miaobin would volunteer to help investigate XNNPACK's support. Thanks!

After removing and modifying the errant validation statements in ml_graph_builder.cc and graph_validation_utils.cc, I verified that XNNPack supports both 0D scalars and 0-size tensors with the following two test cases:

Test for 0D scalars:

  {
    auto* input1 =
        BuildInput(builder, "input1", {}, V8MLOperandType::Enum::kFloat32,
                   scope.GetExceptionState());
    EXPECT_NE(input1, nullptr);
    EXPECT_EQ(scope.GetExceptionState().CodeAs<DOMExceptionCode>(),
              DOMExceptionCode::kNoError);
    EXPECT_EQ(input1->Kind(), MLOperand::OperandKind::kInput);
    EXPECT_EQ(input1->Type(), V8MLOperandType::Enum::kFloat32);
    EXPECT_EQ(input1->Dimensions(), Vector<uint32_t>({}));
    EXPECT_EQ(input1->Name(), "input1");

    auto* input2 =
        BuildInput(builder, "input2", {}, V8MLOperandType::Enum::kFloat32,
                   scope.GetExceptionState());
    EXPECT_NE(input2, nullptr);
    EXPECT_EQ(scope.GetExceptionState().CodeAs<DOMExceptionCode>(),
              DOMExceptionCode::kNoError);
    EXPECT_EQ(input2->Kind(), MLOperand::OperandKind::kInput);
    EXPECT_EQ(input2->Type(), V8MLOperandType::Enum::kFloat32);
    EXPECT_EQ(input2->Dimensions(), Vector<uint32_t>({}));
    EXPECT_EQ(input2->Name(), "input2");

    auto* output_operand = builder->add(input1, input2, scope.GetExceptionState());
    EXPECT_NE(output_operand, nullptr);
    EXPECT_EQ(scope.GetExceptionState().CodeAs<DOMExceptionCode>(),
              DOMExceptionCode::kNoError);
    EXPECT_EQ(output_operand->Kind(), MLOperand::OperandKind::kOutput);
    EXPECT_EQ(output_operand->Type(), V8MLOperandType::Enum::kFloat32);
    EXPECT_EQ(output_operand->Dimensions(), Vector<uint32_t>({}));

    auto [graph, build_exception] =
        BuildGraph(scope, builder, {{"output", output_operand}});
    EXPECT_NE(graph, nullptr);

    // Compute the graph.
    MLNamedArrayBufferViews inputs(
        {{"input1", CreateArrayBufferViewForOperand<float>(input1, {42.0})},
         {"input2", CreateArrayBufferViewForOperand<float>(input2, {42.0})}});
    MLNamedArrayBufferViews outputs(
        {{"output", CreateArrayBufferViewForOperand(output_operand)}});
    auto* compute_exception = ComputeGraph(scope, graph, inputs, outputs);
    EXPECT_EQ(compute_exception, nullptr);
    auto results = GetArrayBufferViewValues<float>(outputs[0].second);
    Vector<float> r{84.0};
    EXPECT_EQ(results, r);
  }

Test for 0-size tensors:

  {
    auto* input1 =
        BuildInput(builder, "input1", {2, 0, 2}, V8MLOperandType::Enum::kFloat32,
                   scope.GetExceptionState());
    EXPECT_NE(input1, nullptr);
    EXPECT_EQ(scope.GetExceptionState().CodeAs<DOMExceptionCode>(),
              DOMExceptionCode::kNoError);
    EXPECT_EQ(input1->Kind(), MLOperand::OperandKind::kInput);
    EXPECT_EQ(input1->Type(), V8MLOperandType::Enum::kFloat32);
    EXPECT_EQ(input1->Dimensions(), Vector<uint32_t>({2, 0, 2}));
    EXPECT_EQ(input1->Name(), "input1");

    auto* input2 =
        BuildInput(builder, "input2", {2, 0, 2}, V8MLOperandType::Enum::kFloat32,
                   scope.GetExceptionState());
    EXPECT_NE(input2, nullptr);
    EXPECT_EQ(scope.GetExceptionState().CodeAs<DOMExceptionCode>(),
              DOMExceptionCode::kNoError);
    EXPECT_EQ(input2->Kind(), MLOperand::OperandKind::kInput);
    EXPECT_EQ(input2->Type(), V8MLOperandType::Enum::kFloat32);
    EXPECT_EQ(input2->Dimensions(), Vector<uint32_t>({2, 0, 2}));
    EXPECT_EQ(input2->Name(), "input2");

    auto* output_operand = builder->add(input1, input2, scope.GetExceptionState());
    EXPECT_NE(output_operand, nullptr);
    EXPECT_EQ(scope.GetExceptionState().CodeAs<DOMExceptionCode>(),
              DOMExceptionCode::kNoError);
    EXPECT_EQ(output_operand->Kind(), MLOperand::OperandKind::kOutput);
    EXPECT_EQ(output_operand->Type(), V8MLOperandType::Enum::kFloat32);
    EXPECT_EQ(output_operand->Dimensions(), Vector<uint32_t>({2, 0, 2}));

    auto [graph, build_exception] =
        BuildGraph(scope, builder, {{"output", output_operand}});
    EXPECT_NE(graph, nullptr);

    // Compute the graph.
    MLNamedArrayBufferViews inputs(
        {{"input1", CreateArrayBufferViewForOperand<float>(input1, {})},
         {"input2", CreateArrayBufferViewForOperand<float>(input2, {})}});
    MLNamedArrayBufferViews outputs(
        {{"output", CreateArrayBufferViewForOperand(output_operand)}});
    auto* compute_exception = ComputeGraph(scope, graph, inputs, outputs);
    EXPECT_EQ(compute_exception, nullptr);
    auto results = GetArrayBufferViewValues<float>(outputs[0].second);
    Vector<float> r{};
    EXPECT_EQ(results, r);
  }

Both tests passed.

@fdwr
Collaborator

fdwr commented Jul 1, 2023

miaobin: Great, thanks for investigating and adding the test cases. The DirectML backend will be more interesting because the current API rejects zero-size tensors, and even if we were to update the API to accept them (and add test cases for all 100+ operators...), the older version would still be on the operating system. So we'll have to do the same thing as was done in ONNX Runtime, where operator creation is bypassed for such operators (the node is left null as a placeholder and is not added to the graph later).
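The placeholder approach described above might look roughly like this. It is a hypothetical sketch, not ONNX Runtime's actual code: nodes whose output is empty get a `None` slot so node indices stay aligned, and those slots are dropped when the final list of operators is assembled. Nodes with an empty input but a non-empty output (such as concat) are not skipped here; those need their empty inputs filtered instead.

```python
from math import prod


def build_backend_graph(nodes):
    """nodes: list of (op_name, input_shapes, output_shape).
    Returns one entry per node: a stand-in for a compiled operator, or None
    when the output is empty and operator creation is bypassed."""
    compiled = []
    for op, _input_shapes, output_shape in nodes:
        if prod(output_shape) == 0:
            compiled.append(None)            # placeholder keeps node indices aligned
        else:
            compiled.append(f"compiled:{op}")
    return compiled


def ops_to_dispatch(compiled):
    # Null placeholders are never added to the final graph.
    return [op for op in compiled if op is not None]
```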

@fdwr
Collaborator

fdwr commented Sep 14, 2023

Evidently LLaMA is another model that can encounter legal 0 size tensors during concat.

@huningxin
Contributor Author

@fdwr

There are legitimate cases within a graph where a tensor may be temporarily sliced down to emptiness and then reconcatenated later with other data (I've come across at least one ONNX model that does this).

WebNN's slice requires "the size must not be 0". Would this prevent the ONNX model you mentioned from slicing a tensor down to emptiness?
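For comparison, a relaxed rule could keep the bounds check while permitting a size of 0. The following is a hypothetical sketch (`validate_slice` is not the spec's algorithm). A start equal to the dimension is accepted when the size is 0, mirroring Python's `data[6:6]`:

```python
def validate_slice(input_shape, starts, sizes):
    """Hypothetical slice validation that permits size 0 as long as the
    window stays in bounds; returns the output shape."""
    if not (len(input_shape) == len(starts) == len(sizes)):
        raise ValueError("starts and sizes must match the input rank")
    for dim, start, size in zip(input_shape, starts, sizes):
        if start < 0 or size < 0:
            raise ValueError("starts and sizes must be non-negative")
        if start + size > dim:
            raise ValueError("slice window exceeds the input dimension")
    return list(sizes)
```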

@fdwr
Collaborator

fdwr commented Nov 28, 2023

@fdwr

There are legitimate cases within a graph where a tensor may be temporarily sliced down to emptiness and then reconcatenated later with other data (I've come across at least one ONNX model that does this).

WebNN's slice requires "the size must not be 0". Would this prevent the ONNX model you mentioned from slicing a tensor down to emptiness?

@huningxin: 🤔 It could, as TF and ONNX support 0-size slices (see below). Granted, a TF or ONNX model is unlikely to contain a 0-size slice window (ends - starts = 0) directly, but one could arise indirectly from the model generation process manipulating some other variable:

TF

import tensorflow as tf

values = tf.constant([0, 1, 2, 3, 4, 5], dtype=tf.uint8)
result = tf.slice(values, begin=[1], size=[0])  # 0-size window starting at index 1
print("value:", result)
print("shape:", result.shape)

ONNX


ir_version: 4
producer_name: "OnnxConformanceTest"
graph {
  node {
    input: "data"
    output: "output"
    op_type: "Slice"
    attribute {
      name: "axes"
      ints: 0
      type: INTS
    }
    attribute {
      name: "starts"
      ints: 1
      type: INTS
    }
    attribute {
      name: "ends"
      ints: 1
      type: INTS
    }
    domain: ""
  }
  name: "Slice_1d_zero_size"
  input {
    name: "data"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 6
          }
        }
      }
    }
  }
  output {
    name: "output"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 0
          }
        }
      }
    }
  }
}
opset_import {
  domain: ""
  version: 1
}
opset_import {
  domain: ""
  version: 7
}
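The model above slices with starts=1 and ends=1 on a dimension of 6, which yields an empty output. A simplified sketch of how an ONNX-style (starts, ends) pair maps to an output length, assuming step 1: negative indices count from the end, and both bounds are clamped to [0, dim]. `onnx_slice_size` is a hypothetical helper, not part of the onnx package:

```python
def onnx_slice_size(dim, start, end):
    """Output length along one axis for an ONNX-style Slice with step 1."""
    if start < 0:
        start += dim
    if end < 0:
        end += dim
    start = min(max(start, 0), dim)
    end = min(max(end, 0), dim)
    return max(end - start, 0)


print(onnx_slice_size(6, 1, 1))  # 0, matching the Slice_1d_zero_size output shape
```

Any window where the clamped end is not past the clamped start produces a 0-size dimension, which is how such slices can appear indirectly.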

@huningxin
Contributor Author

According to my tests, the XNNPACK concat (xnn_define_concatenate2/3/4) and split (xnn_define_even_split2/3/4) operators report an invalid parameter error (xnn_status_invalid_parameter) when an input has a 0-size dimension.

@fdwr
Collaborator

fdwr commented Nov 29, 2023

Bin Miao showed that XNNPack's add supports empty tensors fine, so if it fails on concat, that's a bug in XNNPack. Shall I open a GitHub issue, or do you want to? It doesn't matter either way for the WebNN EP though, because you just skip passing that input tensor to concat, the same as the ORT DML EP does. So if there were 3 inputs (a=[2,3], b=[2,0], c=[2,4]), only pass the non-empty ones to the XNNPack call (inputs = [a, c]).
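A sketch of that workaround, using the shapes from the example: compute the output shape from all inputs (empty ones contribute 0 along the axis), then forward only the non-empty inputs to the backend. `concat_shape_and_inputs` is a hypothetical helper, not an XNNPACK or WebNN API:

```python
from math import prod


def concat_shape_and_inputs(shapes, axis):
    """Return the concat output shape and the indices of non-empty inputs,
    for backends that reject tensors with a 0-size dimension."""
    out = list(shapes[0])
    out[axis] = sum(s[axis] for s in shapes)
    for s in shapes:
        if len(s) != len(out):
            raise ValueError("all inputs must have the same rank")
        for i, (d, o) in enumerate(zip(s, out)):
            if i != axis and d != o:
                raise ValueError("non-axis dimensions must match")
    non_empty = [i for i, s in enumerate(shapes) if prod(s) > 0]
    return out, non_empty


print(concat_shape_and_inputs([[2, 3], [2, 0], [2, 4]], axis=1))  # ([2, 7], [0, 2])
```

If no input is left, the output itself is empty and the node can be skipped entirely; if exactly one is left, the concat degenerates to a copy.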

@huningxin
Contributor Author

Shall I open a GitHub issue, or do you want to?

Opened: google/XNNPACK#5807

It doesn't matter either way for the WebNN EP though

If frameworks can handle that, it would simplify the WebNN implementation.

@sushraja-msft

sushraja-msft commented Mar 18, 2024

Evidently LLaMA is another model that can encounter legal 0 size tensors during concat.

Another case we should consider is WebNN input operands that have 0-size dimensions, which are not allowed today.
However, for TinyLlama the first round of next-token generation requires representing the past key/value tensor as a tensor of shape [1,4,0,64]; the graph then takes the .shape() of that tensor and performs operations on it to determine the sizes of other tensors it creates via generateConstantOfShape.
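A small illustration of why the first step produces a 0-size input: with a growing cache, the past sequence length starts at 0, and downstream shapes are derived from it. The batch/heads/head-size values come from the [1,4,0,64] shape quoted above; the attention-mask shape and the `kv_step_shapes` helper are assumptions for illustration, not taken from the actual TinyLlama graph:

```python
def kv_step_shapes(past_len, batch=1, heads=4, head_size=64):
    """Illustrative shapes for one decoding step with a growing KV cache."""
    past = [batch, heads, past_len, head_size]
    present = [batch, heads, past_len + 1, head_size]  # concat along axis 2
    attention_mask = [batch, past_len + 1]             # sized from past.shape[2]
    return past, present, attention_mask


print(kv_step_shapes(0))  # ([1, 4, 0, 64], [1, 4, 1, 64], [1, 1])
```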

@fdwr
Collaborator

fdwr commented Mar 18, 2024

And from what @guschmue said today, it sounds like this affects YOLOv9 too.

@huningxin
Contributor Author

@sushraja-msft

However, for TinyLlama the first round of next token generation requires representing the past key value tensor as a tensor of dimension [1,4,0,64] the graph then takes the .shape() of that tensor and performs operations on it to determine the size of other tensors it creates via generateConstantOfShape.

Would this mean the shape of the key/value tensors keeps changing in each round of inference? WebNN only supports static shapes, so could this require recompiling the WebNN graph for each round? We met a similar issue with Whisper model inference. A static key-value cache seems to be useful: huggingface/transformers#27931
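A rough sketch of the difference, under the assumption of a hypothetical `max_len` of 16: a growing cache presents a new input shape on every step (including a 0-size dimension on the first), while a pre-allocated static cache keeps one shape and tracks the valid length separately, so a static-shape graph only needs compiling once. Both helpers are hypothetical:

```python
def dynamic_cache_shape(past_len, heads=4, head_size=64):
    # Grows by one token per step, so each step sees a new shape
    # (and step 0 sees a 0-size dimension).
    return (1, heads, past_len, head_size)


def static_cache_shape(past_len, max_len=16, heads=4, head_size=64):
    # Pre-allocated to max_len; past_len only decides which slots are valid
    # (tracked separately, e.g. by a position index or attention mask).
    if past_len >= max_len:
        raise ValueError("cache is full")
    return (1, heads, max_len, head_size)


steps = range(4)
print(len({dynamic_cache_shape(s) for s in steps}))  # 4 distinct input shapes
print(len({static_cache_shape(s) for s in steps}))   # 1 shape, so one compiled graph
```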

inexorabletash added a commit to inexorabletash/webnn that referenced this issue May 20, 2024
Noticed during a review of the Chromium prototype. These are all
pretty obvious except for slice() where there is subtlety for 0-size
dimensions. I added an issue linking to webmachinelearning#391 since the steps will need
to be revised depending on how that issue is resolved.
fdwr pushed a commit that referenced this issue May 23, 2024
* Add missing validation for pad(), slice(), and split()

Noticed during a review of the Chromium prototype. These are all
pretty obvious except for slice() where there is subtlety for 0-size
dimensions. I added an issue linking to #391 since the steps will need
to be revised depending on how that issue is resolved.

* Add another note for split()
@reillyeon
Contributor

@fdwr, you mentioned an ONNX model which depends on this. Can you elaborate on what this is used for in the model?

@fdwr
Collaborator

fdwr commented May 24, 2024

@fdwr, you mentioned an ONNX model which depends on this. Can you elaborate on what this is used for in the model?

@reillyeon It will take some history digging for full context (like which operators in the model hit the issue). Two affected operators I recall were concatenation and slice. ⌛

@bbernhar

@fdwr

If I permit MLBuffer to exist with a 0-size dimension, what does it mean for DML to execute an IDMLCommandRecorder::RecordDispatch using a DML_BUFFER_BINDING::SizeInBytes equal to 0, and can this binding be NULL or left unbound?

@fdwr
Collaborator

fdwr commented Aug 19, 2024

@fdwr, you mentioned an ONNX model which depends on this. Can you elaborate on what this is used for in the model?

@reillyeon It will take some history digging for full context (like which operators in the model hit the issue). Two affected operators I recall were concatenation and slice. ⌛

I know it hit a few more models, but my email search is just turning up RCNN models (like MaskRCNN) with operators {Cast, Xor, Unsqueeze, Concat, Scatter with empty indices which becomes identity}, plus this ORT CUDA CR microsoft/onnxruntime#2337 (but I didn't see the context of impacted models for CUDA EP).

If I permit MLBuffer to exist with 0-dim, what does it mean for DML to execute an IDMLCommandRecorder::RecordDispatch using a DML_BUFFER_BINDING::SizeInBytes equal to 0 and can this binding be NULL or left unbound?

@bbernhar Currently the DML API rejects empty tensors anyway, but I've been thinking of relaxing that (we see that it actually "just works" for a lot of operators when the validation is relaxed), and I think we'd still need the binding even for emptiness (so not unbound).

@bbernhar

@fdwr

it actually "just works" for a lot of operators when the validation is relaxed

So if we create an "empty binding" by giving IDMLCommandRecorder::RecordDispatch a 4B dummy buffer with a DML_BUFFER_BINDING::SizeInBytes of zero, DML API will NOT reject?

@fdwr
Collaborator

fdwr commented Aug 20, 2024

So if we create an "empty binding" by giving IDMLCommandRecorder::RecordDispatch a 4B dummy buffer with a DML_BUFFER_BINDING::SizeInBytes of zero, DML API will NOT reject?

@bbernhar I don't know what would happen in that case, but it's currently a moot point because you cannot create an operator with empty tensors anyway (so you wouldn't even get as far as RecordDispatch). I just don't want to back ourselves into an inoperable corner where we can't support this in the DML API later (empty tensors have also been an issue when DML is called from TensorFlow, PyTorch, and ORT). Note that it probably requires a 16-byte dummy buffer (per DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT = 16).

@bbernhar

@fdwr

Currently, we can't specify MLBuffer using an operator (or "in the graph"), only as an input/output to dispatch(), which basically calls nothing but RecordDispatch. We need to ensure we're not relying on undefined DML behavior; perhaps others on the DML team have thoughts on this "dummy buffer" approach? I believe DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT is the offset alignment requirement for bound buffers, while DML buffer size alignment is 4B per MSDN [1]. Currently, the WebNN runtime disallows MLBuffer from being bound at non-zero offsets, so only the size requirement matters, I think.

[1] https://learn.microsoft.com/en-us/windows/ai/directml/dml-helper-functions#dmlcalcbuffertensorsize

@fdwr fdwr changed the title Allow 0 size dimensions Allow 0 size dimensions (dimensions containing a 0 in the list of sizes, not a rank of 0 which is valid) Aug 20, 2024
@fdwr
Collaborator

fdwr commented Aug 21, 2024

DML buffer size alignment is 4B

@bbernhar : Confirmed. Passing < 16 bytes is okay for DML_BUFFER_TENSOR_DESC::TotalTensorSizeInBytes, but it must be >= 4, or else you get: "The TotalTensorSizeInBytes of '...' for tensor '...' does not meet the minimum size required for this tensor, which is %llu bytes...".
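A sketch of how a runtime might size the dummy binding for an empty tensor given the 4-byte minimum confirmed above: round a packed tensor's byte size up to a multiple of 4 and never go below 4. `binding_size_in_bytes` and `MIN_BUFFER_BYTES` are hypothetical names (not DirectML APIs), and strided layouts are ignored:

```python
from math import prod

MIN_BUFFER_BYTES = 4  # assumption based on the 4-byte minimum reported in this thread


def binding_size_in_bytes(shape, element_size):
    """Bytes to allocate/bind for a packed tensor: rounded up to a multiple
    of 4 and never below 4, so even a 0-element tensor gets a small
    dummy allocation rather than a null binding."""
    raw = prod(shape) * element_size
    rounded = (raw + 3) // 4 * 4
    return max(rounded, MIN_BUFFER_BYTES)


print(binding_size_in_bytes([2, 0, 2], 4))  # 4
```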

7 participants