Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【BUAA】【Infer Symbolic Shape】添加 weight_only_linear op的符号推导接口 #67875

Merged
merged 6 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3053,12 +3053,74 @@ bool WarprnntOpInferSymbolicShape(
return true;
}

// bool WeightOnlyLinearOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }
// Infers the symbolic output shape of the weight_only_linear op.
//
// Operands (by index): 0 = x, 1 = weight, 2 = bias (may be absent),
// 3 = weight_scale. Attributes read: "weight_dtype" and "group_size".
//
// The result shape is x's shape with its last dimension replaced by n, where
// n is weight_scale[0] in per-channel mode (group_size == -1) and
// weight_scale[1] in groupwise mode (group_size == 64 or 128).
//
// Constraints registered on the context:
//   * x.shape[-1] == weight.shape[1]
//   * groupwise mode only:
//     weight_scale.shape[0] == ceil(weight.shape[1] / group_size)
//
// Always returns true; invalid attributes or ranks are reported through
// PADDLE_ENFORCE*.
bool WeightOnlyLinearOpInferSymbolicShape(
    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
  const auto &x_shape_or_data =
      infer_context->GetShapeOrDataForValue(op->operand_source(0));
  const auto &weight_shape_or_data =
      infer_context->GetShapeOrDataForValue(op->operand_source(1));
  const auto &bias_shape_or_data =
      infer_context->GetShapeOrDataForValue(op->operand_source(2));
  const auto &weight_scale_shape_or_data =
      infer_context->GetShapeOrDataForValue(op->operand_source(3));
  const std::string &weight_dtype =
      op->attribute<pir::StrAttribute>("weight_dtype").AsString();
  const int group_size =
      op->attribute<pir::Int32Attribute>("group_size").data();

  PADDLE_ENFORCE(
      (group_size == -1 || group_size == 64 || group_size == 128),
      common::errors::InvalidArgument("group_size must be -1, 64 or 128."));
  PADDLE_ENFORCE(weight_dtype == "int8" || weight_dtype == "int4",
                 common::errors::InvalidArgument(
                     "weight_dtype must be 'int8' or 'int4'."));

  // Bind the shapes by const reference instead of copying them (review
  // feedback: use const & wherever possible to avoid copies).
  const ExprVec &x_shape = x_shape_or_data.shape();
  const ExprVec &weight_shape = weight_shape_or_data.shape();
  const ExprVec &weight_scale_shape = weight_scale_shape_or_data.shape();

  PADDLE_ENFORCE_EQ(weight_shape.size(),
                    2UL,
                    common::errors::InvalidArgument(
                        "The input(weight) must be a 2D Tensor."));
  // Guard the last-dimension access below: size() - 1 would wrap around for
  // a rank-0 input.
  PADDLE_ENFORCE(!x_shape.empty(),
                 common::errors::InvalidArgument(
                     "The input(x) must have at least 1 dimension."));

  // TODO(Jeff114514): can not use % between symbol::DimExpr and int, need to
  // make sure weight_shape[0] and weight_shape[1] is divisible by 16
  infer_context->AddEqualCstr(x_shape.back(), weight_shape[1]);

  if (!bias_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) {
    const ExprVec &bias_shape = bias_shape_or_data.shape();
    PADDLE_ENFORCE_EQ(
        bias_shape.size(),
        1UL,
        common::errors::InvalidArgument(
            "The size of Input(Bias)'s dimension should equal to 1, but "
            "received %d.",
            bias_shape.size()));
  }

  // Validate the rank of weight_scale BEFORE indexing into it, so that a
  // wrongly-shaped scale yields a clear error instead of an out-of-range
  // read.
  if (group_size == -1) {
    PADDLE_ENFORCE_EQ(weight_scale_shape.size(),
                      1UL,
                      common::errors::InvalidArgument(
                          "The input(weight_scale) must be a 1D Tensor"
                          " in per-channel mode."));
  } else {
    PADDLE_ENFORCE_EQ(weight_scale_shape.size(),
                      2UL,
                      common::errors::InvalidArgument(
                          "The input(weight_scale) must be a 2D Tensor"
                          " in groupwise mode."));
    // Ceil-division: number of groups along weight's second dimension.
    infer_context->AddEqualCstr(
        weight_scale_shape[0],
        (weight_shape[1] + (group_size - 1)) / group_size);
  }

  // Output feature dimension, taken from the (now validated) scale tensor.
  const symbol::DimExpr &n =
      group_size == -1 ? weight_scale_shape[0] : weight_scale_shape[1];

  // Output keeps all leading dimensions of x; only the last one changes.
  ExprVec out_shape = x_shape;
  out_shape.back() = n;
  infer_context->SetShapeOrDataForValue(
      op->result(0),
      symbol::ShapeOrDataDimExprs{
          symbol::TensorShapeOrDataDimExprs(out_shape)});
  return true;
}

bool WeightedSampleNeighborsOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(HsigmoidLoss)
// One OP_DECLARE_INFER_SYMBOLIC_SHAPE(<Op>) entry per op, kept in
// alphabetical order. NOTE(review): presumably each entry declares the
// matching <Op>OpInferSymbolicShape function — confirm against the macro
// definition.
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ViterbiDecode)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Warpctc)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Warprnnt)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(WeightOnlyLinear)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(WeightOnlyLinear)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(WeightedSampleNeighbors)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Where)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Where_)
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5113,7 +5113,7 @@
data_type : x
optional : bias
backward : weight_only_linear_grad
# interfaces : paddle::dialect::InferSymbolicShapeInterface
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : weight_quantize
args : (Tensor x, str algo = "weight_only_int8", int arch = 80, int group_size = -1)
Expand Down