[CINN][Backend Pass Update No.12] Update transform_gpu_forloop pass #70883

Merged
19 changes: 17 additions & 2 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
@@ -40,6 +40,7 @@
#include "paddle/cinn/optim/eliminate_common_global_memory_read.h"
#include "paddle/cinn/optim/schedule_block_dce.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
#include "paddle/cinn/pass/pass_manager.h"
#include "paddle/common/ddim.h"
#include "paddle/common/enforce.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
@@ -393,12 +394,26 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
// optim::EliminateCommonGlobalMemoryRead(&(func_body));
optim::OptimizeExprGPU(&(func_body));
ir::stmt::BlockRef func_body_block =
ir::ConvertExprBlockToStmtBlock(func_body);
LOG(INFO) << "Before OptimizeExprGPU in op_lowering_impl: \n"
<< func_body_block;
optim::OptimizeExprGPU(func_body_block);
LOG(INFO) << "After OptimizeExprGPU in op_lowering_impl: \n"
<< func_body_block;
func_body = ir::ConvertStmtBlockToExprBlock(func_body_block);
#endif
},
[&](std::variant<common::HygonDCUArchHIP, common::HygonDCUArchSYCL>) {
// optim::EliminateCommonGlobalMemoryRead(&(func_body));
optim::OptimizeExprGPU(&(func_body));
ir::stmt::BlockRef func_body_block =
ir::ConvertExprBlockToStmtBlock(func_body);
LOG(INFO) << "Before OptimizeExprGPU in op_lowering_impl: \n"
<< func_body_block;
optim::OptimizeExprGPU(func_body_block);
LOG(INFO) << "After OptimizeExprGPU in op_lowering_impl: \n"
<< func_body_block;
func_body = ir::ConvertStmtBlockToExprBlock(func_body_block);
});
}
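Taken together, this change swaps the old Expr-pointer entry point (optim::OptimizeExprGPU(&func_body)) for a round trip through the new statement IR. A minimal sketch of the new calling convention, using only names that appear in this diff; the wrapper function and its name are hypothetical, added here purely for illustration:

#include "paddle/cinn/optim/transform_gpu_forloop.h"

// Hypothetical helper mirroring what op_lowering_impl.cc now does inline.
void ApplyOptimizeExprGPU(ir::Expr* func_body) {
  // Convert the Expr-based function body into the statement-based IR.
  ir::stmt::BlockRef block = ir::ConvertExprBlockToStmtBlock(*func_body);
  // The pass now takes a BlockRef and mutates it in place.
  optim::OptimizeExprGPU(block);
  // Convert back so downstream Expr-based stages are unaffected.
  *func_body = ir::ConvertStmtBlockToExprBlock(block);
}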

25 changes: 20 additions & 5 deletions paddle/cinn/optim/optimize.cc
@@ -85,9 +85,17 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn,
 #ifdef CINN_WITH_CUDA
         ir::SetCudaAxisInfo(copied);
         if (remove_gpu_for_loops) {
-          RemoveGpuForLoops(copied);
+          LOG(INFO) << "Before removing GPU for loops:\n" << copied;
+          FuncPassManager func_pass_manager;
+          func_pass_manager.AddPass(CreateRemoveGpuForLoopsPass());
+          func_pass_manager.Run(copied);
+          LOG(INFO) << "After removing GPU for loops:\n" << copied;
         }
-        CudaSyncThreadsDropIfThenElse(copied);
+        VLOG(10) << "Before Optimize CudaSyncThreadsDropIfThenElse:" << copied;
+        BlockPassManager blk_pass_manager;
+        blk_pass_manager.AddPass(CreateCudaSyncThreadsDropIfThenElsePass());
+        blk_pass_manager.Run(copied->body_block);
+        VLOG(10) << "After Optimize CudaSyncThreadsDropIfThenElse:" << copied;
         FuncPassManager func_pass_manager;
         VLOG(10) << "Before Optimize TransBufferWithDynamicShape:" << copied;
         func_pass_manager.AddPass(CreateTransBufferWithDynamicShapePass());
@@ -99,10 +107,17 @@
 #ifdef CINN_WITH_HIP
         ir::SetCudaAxisInfo(copied);
         if (remove_gpu_for_loops) {
-          RemoveGpuForLoops(copied);
+          LOG(INFO) << "Before removing GPU for loops:\n" << copied;
+          FuncPassManager func_pass_manager;
+          func_pass_manager.AddPass(CreateRemoveGpuForLoopsPass());
+          func_pass_manager.Run(copied);
+          LOG(INFO) << "After removing GPU for loops:\n" << copied;
         }
-        CudaSyncThreadsDropIfThenElse(copied);
-        // CudaTransBufferWithDynamicShape(&copied);
+        VLOG(10) << "Before Optimize CudaSyncThreadsDropIfThenElse:" << copied;
+        BlockPassManager blk_pass_manager;
+        blk_pass_manager.AddPass(CreateCudaSyncThreadsDropIfThenElsePass());
+        blk_pass_manager.Run(copied->body_block);
+        VLOG(10) << "After Optimize CudaSyncThreadsDropIfThenElse:" << copied;
 #endif
       },
       [&](common::HygonDCUArchSYCL) { CINN_NOT_IMPLEMENTED },
7 changes: 7 additions & 0 deletions paddle/cinn/optim/replace_var_with_expr.cc
@@ -118,6 +118,13 @@ struct ReplaceVarWithExprMutator : public ir::IRMutator<>,
ir::IRMutator<>::Visit(&var->upper_bound, &var->upper_bound);
}
}

std::vector<Expr> iter_values = stmt->iter_values();
for (ir::Expr& iter_value : iter_values) {
ir::IRMutator<>::Visit(&iter_value, &iter_value);
}
stmt->set_iter_values(iter_values);

Contributor commented on lines +121 to +127:

Why does replace_var_with_expr need to replace iter_value?

Contributor Author replied:
class ReplaceLoopVarToGpu {
 public:
  void operator()(ir::stmt::BlockRef block) {
  ...
  }

 private:
  void VisitStmt(ir::stmt::For stmt) {
    auto bind_info = stmt->bind_info();

    std::string var_name = "";
    if (bind_info.offset <= 0)
      var_name = "x";
    else if (bind_info.offset == 1)
      var_name = "y";
    else if (bind_info.offset == 2)
      var_name = "z";
    if (stmt->is_gpu_block_binded()) {
      var_name = "blockIdx." + var_name;
      optim::ReplaceVarWithExpr<ir::stmt::StmtRef>(
          stmt, stmt->loop_var(), ir::Expr(ir::Var(var_name)));
    } else if (stmt->is_gpu_thread_binded()) {
      var_name = "threadIdx." + var_name;
      optim::ReplaceVarWithExpr<ir::stmt::StmtRef>(
          stmt, stmt->loop_var(), ir::Expr(ir::Var(var_name)));
    }

    operator()(stmt->body());
  }
};

When ReplaceLoopVarToGpu calls optim::ReplaceVarWithExpr<ir::stmt::StmtRef>(), the i_j_fused inside the stmt's (Schedule) iter_value (i_j_fused / 128ll) should also be replaced with blockIdx.x.

For example:

Before:

{
  Schedule (root_14) {
    attrs(tile_method:TileFirstGeneralTactic)
    thread_bind[blockIdx.x] for (i_j_fused, 0ll, (S0 * 128ll)) {
      Schedule (var_7) {
        i0_14, i1_8 = axis.bind((i_j_fused / 128ll), (i_j_fused % 128ll))
        read_buffers(_var[i0_14(0:S0), i1_8(0:128ll)], _var_1[i0_14(0:S0), i1_8(0:128ll)])
        write_buffers(_var_7[i0_14(0:S0), i1_8(0:128ll)])
        var_7[(i_j_fused / 128ll), (i_j_fused % 128ll)] = (exp(var[(i_j_fused / 128ll), (i_j_fused % 128ll)]) - var_1[(i_j_fused / 128ll), (i_j_fused % 128ll)])
      }
    }
  }
}

After:

{
  Schedule (root_13) {
    attrs(tile_method:TileFirstGeneralTactic)
    thread_bind[blockIdx.x] for (blockIdx.x, 0ll, (S0 * 128ll)) {
      Schedule (var_7) {
        i0_13, i1_7 = axis.bind((blockIdx.x / 128ll), (blockIdx.x % 128ll))
        read_buffers(_var[i0_13(0:S0), i1_7(0:128ll)], _var_1[i0_13(0:S0), i1_7(0:128ll)])
        write_buffers(_var_7[i0_13(0:S0), i1_7(0:128ll)])
        var_7[(blockIdx.x / 128ll), (blockIdx.x % 128ll)] = (exp(var[(blockIdx.x / 128ll), (blockIdx.x % 128ll)]) - var_1[(blockIdx.x / 128ll), (blockIdx.x % 128ll)])
      }
    }
  }
}
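The same pitfall can be shown outside of CINN: if a loop-variable substitution rewrites the loop body but skips the axis bindings, the iter_values keep referring to the dead variable. A self-contained toy sketch, invented purely for illustration (string-based, not CINN's IR):

#include <iostream>
#include <string>
#include <vector>

// Toy schedule block: iter_values bind block axes to loop-var expressions.
struct ToyBlock {
  std::vector<std::string> iter_values;  // e.g. "(i_j_fused / 128)"
  std::string body;                      // e.g. "var_7[i_j_fused] = ..."
};

// Replace every occurrence of `from` with `to` inside `s`.
void ReplaceAll(std::string* s, const std::string& from, const std::string& to) {
  for (size_t pos = 0; (pos = s->find(from, pos)) != std::string::npos;
       pos += to.size()) {
    s->replace(pos, from.size(), to);
  }
}

// Substitute a loop variable with a GPU axis. Visiting iter_values is the
// step this PR adds; without the loop below only the body would change.
void ReplaceLoopVar(ToyBlock* blk,
                    const std::string& var,
                    const std::string& gpu_axis) {
  ReplaceAll(&blk->body, var, gpu_axis);
  for (std::string& iv : blk->iter_values) {
    ReplaceAll(&iv, var, gpu_axis);
  }
}

int main() {
  ToyBlock blk{{"(i_j_fused / 128)", "(i_j_fused % 128)"},
               "var_7[i_j_fused] = exp(var[i_j_fused])"};
  ReplaceLoopVar(&blk, "i_j_fused", "blockIdx.x");
  for (const std::string& iv : blk.iter_values) std::cout << iv << "\n";
  std::cout << blk.body << "\n";  // all references now use blockIdx.x
  return 0;
}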

std::vector<Expr> new_read_buffers = stmt->read_buffers();
for (Expr& read_buffer : new_read_buffers) {
ir::IRMutator<>::Visit(&read_buffer, &read_buffer);