Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

【Infer Symbolic Shape No.77】log_softmax #68025

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1585,6 +1585,51 @@ bool LogsumexpOpInferSymbolicShape(
return details::ReduceInferDim(op, infer_context, axis, keepdim, reduce_all);
}

bool Log_softmaxOpInferSymbolicShape(
Copy link
Contributor

@gongshaotian gongshaotian Sep 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考yaml文件中的评论

pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const std::vector<symbol::DimExpr> &x_shape = x_shape_or_data.shape();

int axis = op->attribute<pir::Int32Attribute>("axis").data();
size_t rank = x_shape.size();

if (rank > 0) {
PADDLE_ENFORCE_GE(axis,
-rank,
common::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X)."));
PADDLE_ENFORCE_LT(axis,
rank,
common::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X)."));
} else if (rank == 0) {
PADDLE_ENFORCE_GE(axis,
-1,
common::errors::InvalidArgument(
"Attr(axis) value should be in range [-1, "
"0] when input is 0D Tensor "));
PADDLE_ENFORCE_LE(axis,
0,
common::errors::InvalidArgument(
"Attr(axis) value should be in range [-1, "
"0] when input is 0D Tensor "));
}

infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)});

return true;
Comment on lines +1589 to +1625
Copy link
Contributor

@gongshaotian gongshaotian Sep 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

麻烦将这部分逻辑包装成一个推导函数,cumprod 和 log_softmax 都直接调用这个推导函数,参考命名:UnchangedCheckAxisInferSymbolicShape ?

}

bool LogSoftmaxOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
return Log_softmaxOpInferSymbolicShape(op, infer_context);
}

bool LuOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(L1Norm_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LpPool2d)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logcumsumexp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsumexp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Log_softmax)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删除这个声明,只需使用下面的驼峰命名方式。
另外,带inplace版本的op是在驼峰命名的最后添加下划线

OP_DECLARE_INFER_SYMBOLIC_SHAPE(LogSoftmax)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lu)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lu_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mode)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2895,6 +2895,7 @@
func : log_softmax
data_type : x
backward : log_softmax_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : logcumsumexp
args : (Tensor x, int axis=-1, bool flatten=false, bool exclusive=false, bool reverse=false)
Expand Down