[CINN][Backend Pass Update No.12] Update transform_gpu_forloop pass #70883
Conversation
std::vector<Expr> iter_values = stmt->iter_values();
for (ir::Expr& iter_value : iter_values) {
  // Rewrite the bound loop var inside each binding expression,
  // e.g. (i_j_fused / 128ll) -> (blockIdx.x / 128ll).
  ir::IRMutator<>::Visit(&iter_value, &iter_value);
}
stmt->set_iter_values(iter_values);
Why does replace_var_with_expr need to replace iter_value?
class ReplaceLoopVarToGpu {
 public:
  void operator()(ir::stmt::BlockRef block) {
    ...
  }

 private:
  void VisitStmt(ir::stmt::For stmt) {
    auto bind_info = stmt->bind_info();
    // Map the bind offset to a GPU axis suffix (x/y/z).
    std::string var_name = "";
    if (bind_info.offset <= 0)
      var_name = "x";
    else if (bind_info.offset == 1)
      var_name = "y";
    else if (bind_info.offset == 2)
      var_name = "z";
    // Replace the loop var with the corresponding builtin GPU index.
    if (stmt->is_gpu_block_binded()) {
      var_name = "blockIdx." + var_name;
      optim::ReplaceVarWithExpr<ir::stmt::StmtRef>(
          stmt, stmt->loop_var(), ir::Expr(ir::Var(var_name)));
    } else if (stmt->is_gpu_thread_binded()) {
      var_name = "threadIdx." + var_name;
      optim::ReplaceVarWithExpr<ir::stmt::StmtRef>(
          stmt, stmt->loop_var(), ir::Expr(ir::Var(var_name)));
    }
    operator()(stmt->body());
  }
};
When optim::ReplaceVarWithExpr<ir::stmt::StmtRef>() is called inside ReplaceLoopVarToGpu, the i_j_fused occurring in the Schedule stmt's iter_value, (i_j_fused / 128ll), should also be replaced with blockIdx.x.
For example:
Before:
{
  Schedule (root_14) {
    attrs(tile_method:TileFirstGeneralTactic)
    thread_bind[blockIdx.x] for (i_j_fused, 0ll, (S0 * 128ll)) {
      Schedule (var_7) {
        i0_14, i1_8 = axis.bind((i_j_fused / 128ll), (i_j_fused % 128ll))
        read_buffers(_var[i0_14(0:S0), i1_8(0:128ll)], _var_1[i0_14(0:S0), i1_8(0:128ll)])
        write_buffers(_var_7[i0_14(0:S0), i1_8(0:128ll)])
        var_7[(i_j_fused / 128ll), (i_j_fused % 128ll)] = (exp(var[(i_j_fused / 128ll), (i_j_fused % 128ll)]) - var_1[(i_j_fused / 128ll), (i_j_fused % 128ll)])
      }
    }
  }
}
After:
{
  Schedule (root_13) {
    attrs(tile_method:TileFirstGeneralTactic)
    thread_bind[blockIdx.x] for (blockIdx.x, 0ll, (S0 * 128ll)) {
      Schedule (var_7) {
        i0_13, i1_7 = axis.bind((blockIdx.x / 128ll), (blockIdx.x % 128ll))
        read_buffers(_var[i0_13(0:S0), i1_7(0:128ll)], _var_1[i0_13(0:S0), i1_7(0:128ll)])
        write_buffers(_var_7[i0_13(0:S0), i1_7(0:128ll)])
        var_7[(blockIdx.x / 128ll), (blockIdx.x % 128ll)] = (exp(var[(blockIdx.x / 128ll), (blockIdx.x % 128ll)]) - var_1[(blockIdx.x / 128ll), (blockIdx.x % 128ll)])
      }
    }
  }
}
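That is exactly what the change quoted at the top of this thread does. A minimal sketch of the handler, assuming it lives in a VisitStmt overload for Schedule stmts (the overload name and placement are an assumption; the body follows the quoted diff):

void VisitStmt(ir::stmt::Schedule stmt) {
  // Binding expressions in iter_values also reference the loop var,
  // so they must be visited too; otherwise (i_j_fused / 128ll) keeps
  // the stale var after the loop var itself becomes blockIdx.x.
  std::vector<Expr> iter_values = stmt->iter_values();
  for (ir::Expr& iter_value : iter_values) {
    ir::IRMutator<>::Visit(&iter_value, &iter_value);
  }
  stmt->set_iter_values(iter_values);
}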
LogicalResult CudaSyncThreadsDropIfThenElsePass::Run(ir::LoweredFunc fn) {
  DropIfThenElseMutator mutator;
  mutator(fn->body_block);
  return LogicalResult::success();
}
This can simply be implemented as a block pass: iterate over the stmts in the block, and whenever a stmt is an IfThenElse that meets the condition (its body is just an Evaluate of a call to syncthreads), replace it directly with its inner stmts; see the sketch below.
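A rough sketch of that shape; the BlockPass interface and the accessor/helper names used here (stmts(), set_stmts(), true_case(), isa<>/as<>, IsSyncThreadsOnly) are assumptions for illustration, not the verified CINN API:

class CudaSyncThreadsDropIfThenElsePass : public BlockPass {
 public:
  LogicalResult Run(ir::stmt::BlockRef block) override {
    std::vector<ir::stmt::StmtRef> new_stmts;
    for (const ir::stmt::StmtRef& stmt : block->stmts()) {
      // An IfThenElse whose body is just Evaluate(call __syncthreads)
      // can be replaced by its inner stmts directly.
      if (stmt.isa<ir::stmt::IfThenElse>()) {
        ir::stmt::IfThenElse if_stmt = stmt.as<ir::stmt::IfThenElse>();
        if (IsSyncThreadsOnly(if_stmt)) {  // hypothetical predicate
          for (const ir::stmt::StmtRef& inner :
               if_stmt->true_case()->stmts()) {
            new_stmts.push_back(inner);
          }
          continue;
        }
      }
      new_stmts.push_back(stmt);
    }
    block->set_stmts(new_stmts);
    return LogicalResult::success();
  }
};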
Sorry to inform you that dffe97a's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
#endif
  },
  [&](std::variant<common::HygonDCUArchHIP, common::HygonDCUArchSYCL>) {
    // optim::EliminateCommonGlobalMemoryRead(&(func_body));
    optim::OptimizeExprGPU(&(func_body));
    optim::EliminateCommonGlobalMemoryRead(&(func_body));
This pass hurt performance; just keep the call commented out.
@@ -392,13 +393,27 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
                 common::ARMArch>) {},
  [&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
    // optim::EliminateCommonGlobalMemoryRead(&(func_body));
    optim::OptimizeExprGPU(&(func_body));
    optim::EliminateCommonGlobalMemoryRead(&(func_body));
Don't enable this optimization for now either; keep it commented out (sketched below).
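For illustration only, the state both review comments ask to restore (a sketch based on the NVGPU hunk above, not a verified patch):

#ifdef CINN_WITH_CUDA
    // Kept disabled per review: enabling this elimination regressed
    // performance, so only the GPU expr optimization runs for now.
    // optim::EliminateCommonGlobalMemoryRead(&(func_body));
    optim::OptimizeExprGPU(&(func_body));
#endif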
OK
PR Category
CINN
PR Types
Improvements
Description
Refactored the transform_gpu_forloop pass.