-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
【Infer Symbolic Shape No.77】log_softmax #68025
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
third_party/flashattn
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
麻烦提交时删除这些修改或拉一下最新的develop分支
@@ -85,6 +85,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(L1Norm_) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LpPool2d) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logcumsumexp) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsumexp) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Log_softmax) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删除这个声明,只需使用下面的驼峰命名方式。
另外,带inplace版本的op是在驼峰命名的最后添加下划线
@@ -1585,6 +1585,51 @@ bool LogsumexpOpInferSymbolicShape( | |||
return details::ReduceInferDim(op, infer_context, axis, keepdim, reduce_all); | |||
} | |||
|
|||
bool Log_softmaxOpInferSymbolicShape( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参考yaml文件中的评论
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
const auto &x_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
const std::vector<symbol::DimExpr> &x_shape = x_shape_or_data.shape(); | ||
|
||
int axis = op->attribute<pir::Int32Attribute>("axis").data(); | ||
size_t rank = x_shape.size(); | ||
|
||
if (rank > 0) { | ||
PADDLE_ENFORCE_GE(axis, | ||
-rank, | ||
common::errors::InvalidArgument( | ||
"Attr(axis) value should be in range [-R, R-1], " | ||
"R is the rank of Input(X).")); | ||
PADDLE_ENFORCE_LT(axis, | ||
rank, | ||
common::errors::InvalidArgument( | ||
"Attr(axis) value should be in range [-R, R-1], " | ||
"R is the rank of Input(X).")); | ||
} else if (rank == 0) { | ||
PADDLE_ENFORCE_GE(axis, | ||
-1, | ||
common::errors::InvalidArgument( | ||
"Attr(axis) value should be in range [-1, " | ||
"0] when input is 0D Tensor ")); | ||
PADDLE_ENFORCE_LE(axis, | ||
0, | ||
common::errors::InvalidArgument( | ||
"Attr(axis) value should be in range [-1, " | ||
"0] when input is 0D Tensor ")); | ||
} | ||
|
||
infer_context->SetShapeOrDataForValue( | ||
op->result(0), | ||
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)}); | ||
|
||
return true; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
麻烦将这部分逻辑包装成一个推导函数,cumprod 和 log_softmax 都直接调用这个推导函数,参考命名:UnchangedCheckAxisInferSymbolicShape ?
PR Category
CINN
PR Types
Improvements
Description
算子
log_softmax
符号推导