Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

【Infer Symbolic Shape No.77】log_softmax #68025

Closed
wants to merge 6 commits into from

Conversation

MikhayEeer
Copy link
Contributor

PR Category

CINN

PR Types

Improvements

Description

算子log_softmax符号推导

Copy link

paddle-bot bot commented Sep 5, 2024

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Sep 5, 2024
@luotao1 luotao1 added the HappyOpenSource 快乐开源活动issue与PR label Sep 6, 2024
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

麻烦提交时删除这些修改或拉一下最新的develop分支

@@ -85,6 +85,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(L1Norm_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LpPool2d)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logcumsumexp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsumexp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Log_softmax)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删除这个声明,只需使用下面的驼峰命名方式。
另外,带inplace版本的op是在驼峰命名的最后添加下划线

@@ -1585,6 +1585,51 @@ bool LogsumexpOpInferSymbolicShape(
return details::ReduceInferDim(op, infer_context, axis, keepdim, reduce_all);
}

bool Log_softmaxOpInferSymbolicShape(
Copy link
Contributor

@gongshaotian gongshaotian Sep 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考yaml文件中的评论

Comment on lines +1589 to +1625
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const std::vector<symbol::DimExpr> &x_shape = x_shape_or_data.shape();

int axis = op->attribute<pir::Int32Attribute>("axis").data();
size_t rank = x_shape.size();

if (rank > 0) {
PADDLE_ENFORCE_GE(axis,
-rank,
common::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X)."));
PADDLE_ENFORCE_LT(axis,
rank,
common::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X)."));
} else if (rank == 0) {
PADDLE_ENFORCE_GE(axis,
-1,
common::errors::InvalidArgument(
"Attr(axis) value should be in range [-1, "
"0] when input is 0D Tensor "));
PADDLE_ENFORCE_LE(axis,
0,
common::errors::InvalidArgument(
"Attr(axis) value should be in range [-1, "
"0] when input is 0D Tensor "));
}

infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)});

return true;
Copy link
Contributor

@gongshaotian gongshaotian Sep 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

麻烦将这部分逻辑包装成一个推导函数,cumprod 和 log_softmax 都直接调用这个推导函数,参考命名:UnchangedCheckAxisInferSymbolicShape ?

# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
contributor External developers HappyOpenSource 快乐开源活动issue与PR
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants