-
Notifications
You must be signed in to change notification settings - Fork 3.2k
[CoreML] Update Reshape op to support more nodes #24594
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
base: main
Are you sure you want to change the base?
Conversation
… suggestions for comments
// See this issue, https://github.com/apple/coremltools/issues/1003 | ||
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf has maximum texture widths which may be the | ||
// root cause. | ||
static bool CheckShapeForLimit(onnxruntime::VectorInt64& shape); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe the name and/or the comment can be more descriptive. it's not clear what the "limit" is.
@@ -237,6 +237,16 @@ bool ConvOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara | |||
const auto* weight_shape = input_defs[1]->Shape(); | |||
int64_t num_dims = weight_shape ? weight_shape->dim_size() : -1; | |||
|
|||
std::vector<int64_t> weight_shape_vec; | |||
std::vector<int64_t> x_shape_vec; | |||
GetShape(*input_defs[1], weight_shape_vec, logger); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check GetShape
return value. if it fails, it won't write to the output argument.
GetShape(*input_defs[0], x_shape_vec, logger); | ||
|
||
if (!CheckShapeForLimit(weight_shape_vec) || !CheckShapeForLimit(x_shape_vec)) { | ||
LOGS(logger, VERBOSE) << "Conv [" << name << "] has a shape with dimension > 16384. CoreML does not support conv operations with dim > 16384."; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hardcoding 16384 in multiple places is not ideal. can we log the limit value from CheckShapeForLimit
?
@@ -75,11 +75,40 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, | |||
return Status::OK(); | |||
} | |||
|
|||
bool AllPositiveShape(gsl::span<const int64_t> shape) { | |||
return std::all_of(shape.begin(), shape.end(), [](int64_t dim) { return dim > 0 || dim == 0; }); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: it checks for all non-negative instead of "all positive", by the looks of it. is this equivalent to IsStaticShape()
?
// Case 2: If new_shape has exactly one -1 dimension, check if input_shape has at least one -1 | ||
if (negative_one_count == 1) { | ||
return std::any_of(input_shape.begin(), input_shape.end(), | ||
[](int64_t dim) { return dim == -1; }); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if any dimensions of input_shape
are -1, wouldn't it already have failed this check above?
int input_negative_one_count = std::count(input_shape.begin(), input_shape.end(), -1);
if (input_negative_one_count > 0) return false;
bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, | ||
const logging::Logger& logger) const { | ||
const auto& input_defs = node.InputDefs(); | ||
const auto& new_shape_name = input_defs[1]->Name(); | ||
const auto* new_shape_tensor = input_params.graph_viewer.GetConstantInitializer(new_shape_name); | ||
|
||
NodeAttrHelper helper(node); | ||
const bool allow_zero = helper.Get("allow_zero", 0) == 1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const bool allow_zero = helper.Get("allow_zero", 0) == 1; | |
const bool allow_zero = helper.Get("allowzero", 0) == 1; |
} | ||
} | ||
|
||
// first input must be fixed rank OR (first input has variadic rank AND shape only contains positive integers) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"OR (first input has variadic rank AND shape only contains positive integers)"
if I understand correctly, we don't handle this, right? might be worth explicitly saying so.
|
||
// first input must be fixed rank OR (first input has variadic rank AND shape only contains positive integers) | ||
// as per docs, 0 is considered an illegal shape element if the input is variadic | ||
if (!GetShape(*input_defs[0], input_shape, logger)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
instead of calling GetStaticShape
and then GetShape
and populating input_shape
twice, could call GetShape
and then IsStaticShape
on the result.
if (input_shape.empty()) { | ||
LOGS(logger, VERBOSE) << "Reshape does not support empty input shape"; | ||
if (input_shape.empty() && !AllPositiveShape(new_shape)) { | ||
// unknown rank & fails the positive shape check |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does an empty shape mean unknown rank, or that we are dealing with a scalar?
Description
Follow the specs for CoreML Reshape more closely to allow for more Reshape ops to be used on CoreML EP
Folded into this PR, I also moved the dimension limit because it seems to only apply to conv operations (texture memory is typically used for conv operations in the GPU because it has a slow write but fast read -- ChromaDB model had a slice operation with an input > 16384 -- operation worked fine after I had moved the dim check)