
[CINN][Backend Pass Update No.12] Update transform_gpu_forloop pass #70883

Merged

Conversation

Albresky
Contributor

PR Category

CINN

PR Types

Improvements

Description

Refactored the transform_gpu_forloop pass.

Comment on lines +121 to +127

std::vector<Expr> iter_values = stmt->iter_values();
for (ir::Expr& iter_value : iter_values) {
ir::IRMutator<>::Visit(&iter_value, &iter_value);
}
stmt->set_iter_values(iter_values);

Contributor

Why does replace_var_with_expr need to replace the iter_value?

Contributor Author

class ReplaceLoopVarToGpu {
 public:
  void operator()(ir::stmt::BlockRef block) {
    ...
  }

 private:
  void VisitStmt(ir::stmt::For stmt) {
    auto bind_info = stmt->bind_info();

    // Map the bind offset to a GPU axis: 0 -> x, 1 -> y, 2 -> z.
    std::string var_name = "";
    if (bind_info.offset <= 0)
      var_name = "x";
    else if (bind_info.offset == 1)
      var_name = "y";
    else if (bind_info.offset == 2)
      var_name = "z";
    // Replace the loop variable with the bound GPU index variable.
    if (stmt->is_gpu_block_binded()) {
      var_name = "blockIdx." + var_name;
      optim::ReplaceVarWithExpr<ir::stmt::StmtRef>(
          stmt, stmt->loop_var(), ir::Expr(ir::Var(var_name)));
    } else if (stmt->is_gpu_thread_binded()) {
      var_name = "threadIdx." + var_name;
      optim::ReplaceVarWithExpr<ir::stmt::StmtRef>(
          stmt, stmt->loop_var(), ir::Expr(ir::Var(var_name)));
    }

    operator()(stmt->body());
  }
};

When optim::ReplaceVarWithExpr&lt;ir::stmt::StmtRef&gt;() is called in ReplaceLoopVarToGpu, the i_j_fused in the iter_value (i_j_fused / 128ll) of the stmt (a Schedule) should also be replaced with blockIdx.x.

For example:

Before:

245: {
245:   Schedule (root_14) {
245:     attrs(tile_method:TileFirstGeneralTactic)
245:     thread_bind[blockIdx.x] for (i_j_fused, 0ll, (S0 * 128ll)) {
245:       Schedule (var_7) {
245:         i0_14, i1_8 = axis.bind((i_j_fused / 128ll), (i_j_fused % 128ll))
245:         read_buffers(_var[i0_14(0:S0), i1_8(0:128ll)], _var_1[i0_14(0:S0), i1_8(0:128ll)])
245:         write_buffers(_var_7[i0_14(0:S0), i1_8(0:128ll)])
245:         var_7[(i_j_fused / 128ll), (i_j_fused % 128ll)] = (exp(var[(i_j_fused / 128ll), (i_j_fused % 128ll)]) - var_1[(i_j_fused / 128ll), (i_j_fused % 128ll)])
245:       }
245:     }
245:   }
245: }

After:

245: {
245:   Schedule (root_13) {
245:     attrs(tile_method:TileFirstGeneralTactic)
245:     thread_bind[blockIdx.x] for (blockIdx.x, 0ll, (S0 * 128ll)) {
245:       Schedule (var_7) {
245:         i0_13, i1_7 = axis.bind((blockIdx.x / 128ll), (blockIdx.x % 128ll))
245:         read_buffers(_var[i0_13(0:S0), i1_7(0:128ll)], _var_1[i0_13(0:S0), i1_7(0:128ll)])
245:         write_buffers(_var_7[i0_13(0:S0), i1_7(0:128ll)])
245:         var_7[(blockIdx.x / 128ll), (blockIdx.x % 128ll)] = (exp(var[(blockIdx.x / 128ll), (blockIdx.x % 128ll)]) - var_1[(blockIdx.x / 128ll), (blockIdx.x % 128ll)])
245:       }
245:     }
245:   }
245: }
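For context, this is exactly what the lines quoted at the top of this thread do: the mutator visits the Schedule stmt's iter_values in addition to its body. A minimal sketch of that visitor, assuming an IRMutator-based ReplaceVarWithExpr (the surrounding recursion via operator() is illustrative, not the final API):

// Sketch only: the Schedule-stmt visitor inside a ReplaceVarWithExpr-style
// mutator; helper names around it are assumptions.
void VisitStmt(ir::stmt::Schedule stmt) {
  // Rewrite the axis-binding expressions first, so that
  // axis.bind((i_j_fused / 128ll), ...) becomes
  // axis.bind((blockIdx.x / 128ll), ...).
  std::vector<Expr> iter_values = stmt->iter_values();
  for (ir::Expr& iter_value : iter_values) {
    ir::IRMutator<>::Visit(&iter_value, &iter_value);
  }
  stmt->set_iter_values(iter_values);
  // Then recurse into the schedule body as usual.
  operator()(stmt->body());
}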

Comment on lines 303 to 307
LogicalResult CudaSyncThreadsDropIfThenElsePass::Run(ir::LoweredFunc fn) {
  DropIfThenElseMutator mutator;
  mutator(fn->body_block);
  return LogicalResult::success();
}
Contributor

This can simply be implemented as a block pass: traverse the stmts in the block, and whenever a stmt is an IfThenElse that satisfies the condition (a guard whose body is just an Evaluate(call syncthreads)), replace it directly with its inner stmts, as in the sketch below.
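A minimal sketch of that block-pass shape, assuming the new stmt IR's BlockRef/StmtRef accessors (stmts(), set_stmts(), isa&lt;&gt;/as&lt;&gt;); the IsSyncThreadsGuard predicate is purely illustrative:

// Hypothetical sketch of the suggested block-pass form; accessor and
// helper names are assumptions, not the final API.
LogicalResult CudaSyncThreadsDropIfThenElsePass::Run(ir::stmt::BlockRef block) {
  std::vector<ir::stmt::StmtRef> new_stmts;
  for (const ir::stmt::StmtRef& stmt : block->stmts()) {
    if (stmt.isa<ir::stmt::IfThenElse>()) {
      ir::stmt::IfThenElse if_stmt = stmt.as<ir::stmt::IfThenElse>();
      // If the branch merely guards an Evaluate(call __syncthreads()),
      // hoist its inner stmts and drop the IfThenElse wrapper.
      if (IsSyncThreadsGuard(if_stmt)) {  // illustrative predicate
        for (const ir::stmt::StmtRef& inner : if_stmt->true_case()->stmts()) {
          new_stmts.push_back(inner);
        }
        continue;
      }
    }
    new_stmts.push_back(stmt);
  }
  block->set_stmts(new_stmts);
  return LogicalResult::success();
}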

paddle-ci-bot commented Feb 10, 2025

Sorry to inform you that dffe97a's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

#endif
        },
        [&](std::variant<common::HygonDCUArchHIP, common::HygonDCUArchSYCL>) {
          // optim::EliminateCommonGlobalMemoryRead(&(func_body));
          optim::OptimizeExprGPU(&(func_body));
          optim::EliminateCommonGlobalMemoryRead(&(func_body));
Contributor

This pass hurts performance; just keep it commented out.

@@ -392,13 +393,27 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
                       common::ARMArch>) {},
        [&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
          // optim::EliminateCommonGlobalMemoryRead(&(func_body));
          optim::OptimizeExprGPU(&(func_body));
          optim::EliminateCommonGlobalMemoryRead(&(func_body));
Contributor

Don't enable this optimization for now; just keep it commented out.

Contributor Author

OK

Hongqing-work merged commit fba2e98 into PaddlePaddle:develop on Feb 19, 2025
31 checks passed