Skip to content

Commit

Permalink
【BUAA】【Infer Symbolic Shape】add histogram (#67776)
Browse files Browse the repository at this point in the history
  • Loading branch information
uanu2002 authored Aug 30, 2024
1 parent 4910b6d commit d6f14fb
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -862,12 +862,41 @@ bool HuberLossOpInferSymbolicShape(
return true;
}

// bool HistogramOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }
bool HistogramOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const symbol::ShapeOrDataDimExprs &input_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
int64_t bins = op->attribute<pir::Int64Attribute>("bins").data();
int min = op->attribute<pir::Int32Attribute>("min").data();
int max = op->attribute<pir::Int32Attribute>("max").data();
PADDLE_ENFORCE_GE(bins,
1,
common::errors::InvalidArgument(
"The bins should be greater than or equal to 1."
"But received nbins is %d",
bins));
PADDLE_ENFORCE_GE(
max,
min,
common::errors::InvalidArgument("max must be larger or equal to min."
"But received max is %d, min is %d",
max,
min));
if (op->operand_source(1)) {
const symbol::ShapeOrDataDimExprs &weight_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(1));
size_t ndims_input = input_shape_or_data.shape().size();
for (size_t i = 0; i < ndims_input; ++i) {
infer_context->AddEqualCstr(weight_shape_or_data.shape()[i],
input_shape_or_data.shape()[i]);
}
}
std::vector<symbol::DimExpr> dim_out = {bins};
symbol::ShapeOrDataDimExprs shape_data{
symbol::TensorShapeOrDataDimExprs(dim_out)};
infer_context->SetShapeOrDataForValue(op->result(0), shape_data);
return true;
}

bool IndexSampleOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gather)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GatherNd)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GatherTree)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(HuberLoss)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Histogram)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Histogram)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Isclose)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IndexAdd)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IndexAdd_)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2421,6 +2421,7 @@
optional : weight
kernel :
func : histogram
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : hsigmoid_loss
args : (Tensor x, Tensor label, Tensor w, Tensor bias, Tensor path, Tensor code, int num_classes, bool is_sparse)
Expand Down

0 comments on commit d6f14fb

Please # to comment.