diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
index 84b2697f55d6ad..6dc911d6708803 100644
--- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
+++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
@@ -40,6 +40,7 @@
 #include "paddle/cinn/optim/eliminate_common_global_memory_read.h"
 #include "paddle/cinn/optim/schedule_block_dce.h"
 #include "paddle/cinn/optim/transform_gpu_forloop.h"
+#include "paddle/cinn/pass/pass_manager.h"
 #include "paddle/common/ddim.h"
 #include "paddle/common/enforce.h"
 #include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
@@ -393,12 +394,26 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
       [&](common::NVGPUArch) {
 #ifdef CINN_WITH_CUDA
         // optim::EliminateCommonGlobalMemoryRead(&(func_body));
-        optim::OptimizeExprGPU(&(func_body));
+        ir::stmt::BlockRef func_body_block =
+            ir::ConvertExprBlockToStmtBlock(func_body);
+        LOG(INFO) << "Before OptimizeExprGPU in op_lowering_impl: \n"
+                  << func_body_block;
+        optim::OptimizeExprGPU(func_body_block);
+        LOG(INFO) << "After OptimizeExprGPU in op_lowering_impl: \n"
+                  << func_body_block;
+        func_body = ir::ConvertStmtBlockToExprBlock(func_body_block);
 #endif
       },
       [&](std::variant) {
         // optim::EliminateCommonGlobalMemoryRead(&(func_body));
-        optim::OptimizeExprGPU(&(func_body));
+        ir::stmt::BlockRef func_body_block =
+            ir::ConvertExprBlockToStmtBlock(func_body);
+        LOG(INFO) << "Before OptimizeExprGPU in op_lowering_impl: \n"
+                  << func_body_block;
+        optim::OptimizeExprGPU(func_body_block);
+        LOG(INFO) << "After OptimizeExprGPU in op_lowering_impl: \n"
+                  << func_body_block;
+        func_body = ir::ConvertStmtBlockToExprBlock(func_body_block);
       });
 }
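Every call site in this file now uses the same convert-run-convert pattern around the new stmt-based `OptimizeExprGPU`. A minimal sketch of that round trip; the wrapper name `RunGpuOptOnBody` is hypothetical, only the converters and the pass entry point come from this diff:

```cpp
// Sketch: the Expr <-> stmt round trip now required at each call site.
ir::Expr RunGpuOptOnBody(ir::Expr func_body) {
  // Lift the legacy Expr-based block into the statement IR.
  ir::stmt::BlockRef body_block = ir::ConvertExprBlockToStmtBlock(func_body);
  // The pass mutates the BlockRef in place.
  cinn::optim::OptimizeExprGPU(body_block);
  // Lower back so downstream Expr-based passes keep working.
  return ir::ConvertStmtBlockToExprBlock(body_block);
}
```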
diff --git a/paddle/cinn/optim/optimize.cc b/paddle/cinn/optim/optimize.cc
index 8f7c8fe52916b4..216771354dd2d3 100644
--- a/paddle/cinn/optim/optimize.cc
+++ b/paddle/cinn/optim/optimize.cc
@@ -85,9 +85,17 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn,
 #ifdef CINN_WITH_CUDA
         ir::SetCudaAxisInfo(copied);
         if (remove_gpu_for_loops) {
-          RemoveGpuForLoops(copied);
+          LOG(INFO) << "Before removing GPU for loops:\n" << copied;
+          FuncPassManager func_pass_manager;
+          func_pass_manager.AddPass(CreateRemoveGpuForLoopsPass());
+          func_pass_manager.Run(copied);
+          LOG(INFO) << "After removing GPU for loops:\n" << copied;
         }
-        CudaSyncThreadsDropIfThenElse(copied);
+        VLOG(10) << "Before Optimize CudaSyncThreadsDropIfThenElse:" << copied;
+        BlockPassManager blk_pass_manager;
+        blk_pass_manager.AddPass(CreateCudaSyncThreadsDropIfThenElsePass());
+        blk_pass_manager.Run(copied->body_block);
+        VLOG(10) << "After Optimize CudaSyncThreadsDropIfThenElse:" << copied;
         FuncPassManager func_pass_manager;
         VLOG(10) << "Before Optimize TransBufferWithDynamicShape:" << copied;
         func_pass_manager.AddPass(CreateTransBufferWithDynamicShapePass());
@@ -99,10 +107,17 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn,
 #ifdef CINN_WITH_HIP
         ir::SetCudaAxisInfo(copied);
         if (remove_gpu_for_loops) {
-          RemoveGpuForLoops(copied);
+          LOG(INFO) << "Before removing GPU for loops:\n" << copied;
+          FuncPassManager func_pass_manager;
+          func_pass_manager.AddPass(CreateRemoveGpuForLoopsPass());
+          func_pass_manager.Run(copied);
+          LOG(INFO) << "After removing GPU for loops:\n" << copied;
         }
-        CudaSyncThreadsDropIfThenElse(copied);
-        // CudaTransBufferWithDynamicShape(&copied);
+        VLOG(10) << "Before Optimize CudaSyncThreadsDropIfThenElse:" << copied;
+        BlockPassManager blk_pass_manager;
+        blk_pass_manager.AddPass(CreateCudaSyncThreadsDropIfThenElsePass());
+        blk_pass_manager.Run(copied->body_block);
+        VLOG(10) << "After Optimize CudaSyncThreadsDropIfThenElse:" << copied;
 #endif
       },
       [&](common::HygonDCUArchSYCL) { CINN_NOT_IMPLEMENTED },
diff --git a/paddle/cinn/optim/replace_var_with_expr.cc b/paddle/cinn/optim/replace_var_with_expr.cc
index 94514ff440f0cf..a66c050461f127 100644
--- a/paddle/cinn/optim/replace_var_with_expr.cc
+++ b/paddle/cinn/optim/replace_var_with_expr.cc
@@ -118,6 +118,13 @@ struct ReplaceVarWithExprMutator : public ir::IRMutator<>,
       ir::IRMutator<>::Visit(&var->upper_bound, &var->upper_bound);
     }
   }
+
+  std::vector<ir::Expr> iter_values = stmt->iter_values();
+  for (ir::Expr& iter_value : iter_values) {
+    ir::IRMutator<>::Visit(&iter_value, &iter_value);
+  }
+  stmt->set_iter_values(iter_values);
+
   std::vector<ir::Expr> new_read_buffers = stmt->read_buffers();
   for (Expr& read_buffer : new_read_buffers) {
     ir::IRMutator<>::Visit(&read_buffer, &read_buffer);
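The `replace_var_with_expr.cc` hunk above closes a hole in the stmt path: a Schedule statement binds block-local iter_vars to expressions over the enclosing loop vars, and those bindings were previously not visited. A sketch of the case it fixes (pseudo-IR, illustrative names):

```cpp
// Before substituting loop var `i` (e.g. with blockIdx.x):
//
//   for (i, 0, 32) {
//     schedule[iter_vars = {i0}, iter_values = {i * 4}] {  // mentions i
//       B[i0] = A[i0];
//     }
//   }
//
// Loop bounds and read/write buffer regions were already rewritten, but
// `i * 4` kept the stale var. The added loop runs the same mutator over a
// copy of each binding and writes it back via set_iter_values(), mirroring
// the adjacent read_buffers handling.
```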
diff --git a/paddle/cinn/optim/transform_gpu_forloop.cc b/paddle/cinn/optim/transform_gpu_forloop.cc
index 225afbb42b19dd..c74298c3b52a52 100644
--- a/paddle/cinn/optim/transform_gpu_forloop.cc
+++ b/paddle/cinn/optim/transform_gpu_forloop.cc
@@ -27,6 +27,7 @@
 #include "paddle/cinn/ir/ir.h"
 #include "paddle/cinn/ir/ir_mutator.h"
 #include "paddle/cinn/ir/ir_printer.h"
+#include "paddle/cinn/ir/stmt_visitors.h"
 #include "paddle/cinn/ir/utils/ir_copy.h"
 #include "paddle/cinn/ir/utils/stmt_converter.h"
 #include "paddle/cinn/optim/eliminate_common_factor_of_local_index.h"
@@ -46,126 +47,159 @@ PD_DECLARE_bool(cinn_longlong2int);
 namespace cinn {
 namespace optim {
 
-void RemoveGpuForLoops(ir::LoweredFunc fn) {
-  struct Mutator : public ir::IRMutator<> {
-    using ir::IRMutator<>::Visit;
-    void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
-
-    explicit Mutator(const ir::CudaAxisInfo &cuda_axis_info)
-        : cuda_axis_info_(cuda_axis_info) {}
+class GPUForLoopsMutator {
+ public:
+  void operator()(ir::stmt::BlockRef block) { VisitBlock(block); }
 
-   private:
-    ir::CudaAxisInfo cuda_axis_info_;
+  explicit GPUForLoopsMutator(const ir::CudaAxisInfo &cuda_axis_info)
+      : cuda_axis_info_(cuda_axis_info) {}
 
-    void Visit(const ir::For *op, Expr *expr) override {
-      switch (op->for_type()) {
-        case ir::ForType::GPUBlock:
-          if (NeedToReplaceForloopWithIfThenElse(op)) {
-            ReplaceForloopWithIfThenElse(expr);
-          } else {
-            *expr = op->body;
+ private:
+  void VisitBlock(ir::stmt::BlockRef block) {
+    std::vector<ir::stmt::StmtRef> stmts = block->stmts();
+    std::vector<ir::stmt::StmtRef> new_stmts;
+    for (ir::stmt::StmtRef &stmt : stmts) {
+      switch (stmt->stmt_type()) {
+        case ir::StmtNodeTy::For: {
+          ir::stmt::For for_stmt = stmt.as<ir::stmt::For>();
+          switch (VisitStmt(for_stmt)) {
+            case 0: {
+              ReplaceForloopWithIfThenElse(stmt);
+              ir::stmt::IfThenElse if_stmt = stmt.as<ir::stmt::IfThenElse>();
+              // Visit true case only
+              VisitBlock(if_stmt->true_case());
+              new_stmts.push_back(if_stmt);
+              break;
+            }
+            case 1: {
+              VisitBlock(for_stmt->body());
+              for (const auto &stmt : for_stmt->body()->stmts()) {
+                new_stmts.push_back(stmt);
+              }
+              break;
+            }
+            case 2: {
+              VisitBlock(for_stmt->body());
+              new_stmts.push_back(for_stmt);
+              break;
+            }
+            default:
+              break;
           }
-          IRMutator<>::Visit(expr, expr);
           break;
-        case ir::ForType::GPUThread:
-          if (NeedToReplaceForloopWithIfThenElse(op)) {
-            ReplaceForloopWithIfThenElse(expr);
-          } else {
-            *expr = op->body;
+        }
+        case ir::StmtNodeTy::Schedule: {
+          ir::stmt::Schedule schedule = stmt.as<ir::stmt::Schedule>();
+          VisitBlock(schedule->body());
+          new_stmts.push_back(stmt);
+          break;
+        }
+        case ir::StmtNodeTy::IfThenElse: {
+          ir::stmt::IfThenElse if_then_else = stmt.as<ir::stmt::IfThenElse>();
+          VisitBlock(if_then_else->true_case());
+          if (if_then_else->false_case().defined()) {
+            VisitBlock(if_then_else->false_case());
           }
-          IRMutator<>::Visit(expr, expr);
+          new_stmts.push_back(stmt);
           break;
+        }
         default:
-          auto *node = expr->As<ir::For>();
-          IRMutator<>::Visit(&node->body, &node->body);
+          new_stmts.push_back(stmt);
           break;
       }
     }
+    block->set_stmts(new_stmts);
+  }
 
-    bool NeedToReplaceForloopWithIfThenElse(const ir::For *n) const {
-      // If the loop doesn't start from 0.
-      if (n->min != cinn::common::make_const(0)) {
-        return true;
+  // NOLINTNEXTLINE(runtime/references)
+  int VisitStmt(const ir::stmt::For &stmt) {
+    if (stmt->for_type() == ir::ForType::GPUBlock ||
+        stmt->for_type() == ir::ForType::GPUThread) {
+      if (NeedToReplaceForloopWithIfThenElse(stmt)) {
+        // Replace the GPU For loop with an IfThenElse.
+        return 0;
+      } else {
+        // Replace the GPU For loop with its body.
+        return 1;
       }
+    }
+    // Keep this For loop, traverse the body of it.
+    return 2;
+  }
 
-      // Get dim_size from the functions's cuda_axis_info as pre-condition.
-      ir::Expr dim_size;
-      switch (n->bind_info().for_type) {
-        case ir::ForType::GPUThread:
-          dim_size = cuda_axis_info_.block_dim(n->bind_info().offset);
-          break;
-        case ir::ForType::GPUBlock:
-          dim_size = cuda_axis_info_.grid_dim(n->bind_info().offset);
-          break;
-      }
-      if (!dim_size.defined()) {
-        return true;
-      }
+  bool NeedToReplaceForloopWithIfThenElse(const ir::stmt::For &stmt) const {
+    // If the loop doesn't start from 0.
+    if (stmt->min() != cinn::common::make_const(0)) {
+      return true;
+    }
 
-      // If we can prove the loop's extent >= dim_size, then it's safe not
-      // to add the IfThenElse guard.
-      common::cas_intervals_t var_intervals =
-          common::CollectVarIntervalsOfExprs({n->extent, dim_size});
-      common::SymbolicExprAnalyzer analyzer{var_intervals};
-      std::optional<bool> proved_ge = analyzer.ProveGE(n->extent, dim_size);
-      if (proved_ge.value_or(false)) {
-        return false;
-      }
+    // Get dim_size from the function's cuda_axis_info as pre-condition.
+    ir::Expr dim_size;
+    switch (stmt->bind_info().for_type) {
+      case ir::ForType::GPUThread:
+        dim_size = cuda_axis_info_.block_dim(stmt->bind_info().offset);
+        break;
+      case ir::ForType::GPUBlock:
+        dim_size = cuda_axis_info_.grid_dim(stmt->bind_info().offset);
+        break;
+    }
+    if (!dim_size.defined()) {
       return true;
     }
 
-    void ReplaceForloopWithIfThenElse(Expr *expr) {
-      auto *for_n = expr->As<ir::For>();
+    // If we can prove the loop's extent >= dim_size, then it's safe not
+    // to add the IfThenElse guard.
+    common::cas_intervals_t var_intervals =
+        common::CollectVarIntervalsOfExprs({stmt->extent(), dim_size});
+    common::SymbolicExprAnalyzer analyzer{var_intervals};
+    std::optional<bool> proved_ge = analyzer.ProveGE(stmt->extent(), dim_size);
+    if (proved_ge.value_or(false)) {
+      return false;
+    }
+    return true;
+  }
 
-      Expr condition;
-      const auto AppendCondition = [&](Expr new_cond) {
-        if (condition.defined()) {
-          condition = ir::And::Make(condition, new_cond);
-        } else {
-          condition = new_cond;
-        }
-      };
+  // NOLINTNEXTLINE(runtime/references)
+  void ReplaceForloopWithIfThenElse(ir::stmt::StmtRef &stmt) {
+    ir::stmt::For for_n = stmt.as<ir::stmt::For>();
 
-      // for(i, 2, 100);
-      //      ^
-      if (for_n->min != cinn::common::make_const(0)) {
-        AppendCondition(ir::GE::Make(for_n->loop_var, for_n->min));
+    Expr condition;
+    const auto AppendCondition = [&](Expr new_cond) {
+      if (condition.defined()) {
+        condition = ir::And::Make(condition, new_cond);
+      } else {
+        condition = new_cond;
       }
-      // for(i, 2, min(M/2, 20)
-      //               ^
-      AppendCondition(ir::LT::Make(for_n->loop_var, for_n->extent));
-
-      PADDLE_ENFORCE_EQ(condition.defined(),
-                        true,
-                        ::common::errors::InvalidArgument(
-                            "Condition is not defined, please check."));
-
-      *expr = ir::IfThenElse::Make(condition, for_n->body);
-    }
-
-    void Visit(const ir::PolyFor *op, Expr *expr) override {
-      const auto msg =
-          "PolyFor is not allowed for GPU, only For nodes are allowed";
-      PADDLE_ENFORCE_EQ(
-          op->for_type() != ir::ForType::GPUBlock,
-          true,
-          ::common::errors::InvalidArgument(
-              "PolyFor is not allowed for GPU, only For nodes are allowed."));
-      PADDLE_ENFORCE_EQ(
-          op->for_type() != ir::ForType::GPUThread,
-          true,
-          ::common::errors::InvalidArgument(
-              "PolyFor is not allowed for GPU, only For nodes are allowed."));
-      PADDLE_ENFORCE_EQ(
-          op->for_type() != ir::ForType::GPULane,
-          true,
-          ::common::errors::InvalidArgument(
-              "PolyFor is not allowed for GPU, only For nodes are allowed."));
-    }
-  };
-
-  Mutator mutator(fn->cuda_axis_info);
-  mutator(&fn->body);
+    };
+
+    // for(i, 2, 100);
+    //      ^
+    if (for_n->min() != cinn::common::make_const(0)) {
+      AppendCondition(ir::GE::Make(for_n->loop_var(), for_n->min()));
+    }
+    // for(i, 2, min(M/2, 20))
+    //               ^
+    AppendCondition(ir::LT::Make(for_n->loop_var(), for_n->extent()));
+
+    PADDLE_ENFORCE_EQ(condition.defined(),
+                      true,
+                      ::common::errors::InvalidArgument(
+                          "Condition is not defined, please check."));
+
+    stmt = ir::stmt::IfThenElse(condition, for_n->body());
+  }
+
+  ir::CudaAxisInfo cuda_axis_info_;
+};
+
+LogicalResult RemoveGpuForLoopsPass::Run(ir::LoweredFunc fn) {
+  GPUForLoopsMutator mutator(fn->cuda_axis_info);
+  mutator(fn->body_block);
+  return LogicalResult::success();
+}
+
+std::unique_ptr<FuncPass> CreateRemoveGpuForLoopsPass() {
+  return std::make_unique<RemoveGpuForLoopsPass>();
 }
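The three return codes of `VisitStmt` above map onto three rewrites, sketched on pseudo-IR (by this point `OptimizeExprGPU` has already renamed the loop var to the GPU axis, so only the loop shell remains):

```cpp
// Case 0: extent not provably >= the launch dim, a guard is required:
//   for<GPUThread>(threadIdx.x, 0, n) { body }
//   => if (threadIdx.x < n) { body }
//
// Case 1: loop starts at 0 and SymbolicExprAnalyzer::ProveGE shows
// extent >= blockDim.x, so the shell is dropped and the body inlined:
//   for<GPUThread>(threadIdx.x, 0, 1024) { body }  // blockDim.x == 1024
//   => body
//
// Case 2: any non-GPU loop is kept; only its body is traversed.
```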
 
 /**
@@ -173,40 +207,82 @@ void RemoveGpuForLoops(ir::LoweredFunc fn) {
  * this is the problem of isl AST output, drop it to make it run in all the
  * threads.
  */
-void CudaSyncThreadsDropIfThenElse(ir::LoweredFunc fn) {
-  struct Mutator : public ir::IRMutator<> {
-    using ir::IRMutator<>::Visit;
-    void operator()(ir::LoweredFunc fn) { Visit(fn.As<ir::_LoweredFunc_>()); }
-
-    void Visit(const ir::IfThenElse *op, Expr *expr) override {
-      blocked_statement_stack.push_back(expr);
-      ir::IRMutator<>::Visit(op, expr);
-      blocked_statement_stack.pop_back();
-    }
-
-    void Visit(const ir::Call *op, Expr *expr) override {
-      if (op->name == runtime::intrinsic::cuda_sync_threads) {
-        if (!blocked_statement_stack.empty()) {
-          auto *last_for = blocked_statement_stack.back()->As<ir::IfThenElse>();
-          if (auto *eq_n = last_for->condition.As<ir::EQ>()) {
-            if (eq_n->b() == cinn::common::make_const(0)) {
-              *blocked_statement_stack.back() = *expr;
+class DropIfThenElseMutator {
+ public:
+  void operator()(ir::stmt::BlockRef block) { VisitBlock(block); }
+
+ private:
+  bool isDropCandidate(const ir::stmt::IfThenElse &stmt) {
+    if (!stmt->condition().defined()) return false;
+    const ir::Expr &cond = stmt->condition();
+    if (auto *eq_n = cond.As<ir::EQ>()) {
+      if (eq_n->b() == cinn::common::make_const(0)) {
+        ir::stmt::BlockRef true_case = stmt->true_case();
+        if (true_case.defined() && true_case->stmts().size() == 1) {
+          auto eval_stmt = true_case->stmts()[0];
+          if (eval_stmt->stmt_type() == ir::StmtNodeTy::Evaluate) {
+            auto eval_expr = eval_stmt.as<ir::stmt::Evaluate>()->value();
+            if (auto *call = eval_expr.As<ir::Call>()) {
+              if (call->name == runtime::intrinsic::cuda_sync_threads) {
+                return true;
+              }
             }
           }
         }
       }
     }
+    return false;
+  }
+
+  void VisitBlock(ir::stmt::BlockRef block) {
+    std::vector<ir::stmt::StmtRef> stmts = block->stmts();
+    std::vector<ir::stmt::StmtRef> new_stmts;
+    for (ir::stmt::StmtRef &stmt : stmts) {
+      switch (stmt->stmt_type()) {
+        case ir::StmtNodeTy::IfThenElse: {
+          const ir::stmt::IfThenElse &if_node = stmt.as<ir::stmt::IfThenElse>();
+          if (isDropCandidate(if_node)) {
+            const ir::stmt::BlockRef true_case = if_node->true_case();
+            for (const auto &true_stmt : true_case->stmts()) {
+              new_stmts.push_back(true_stmt);
+            }
+          } else {
+            new_stmts.push_back(stmt);
+          }
+        } break;
+        case ir::StmtNodeTy::For: {
+          ir::stmt::For for_stmt = stmt.as<ir::stmt::For>();
+          VisitBlock(for_stmt->body());
+          new_stmts.push_back(stmt);
+        } break;
+        case ir::StmtNodeTy::Schedule: {
+          ir::stmt::Schedule schedule = stmt.as<ir::stmt::Schedule>();
+          VisitBlock(schedule->body());
+          new_stmts.push_back(stmt);
+        } break;
+        default:
+          new_stmts.push_back(stmt);
+          break;
+      }
+    }
+    block->set_stmts(new_stmts);
+  }
+};
 
-    // Collect all the statements with Block(include Block) to the statement.
-    std::vector<ir::Expr *> blocked_statement_stack;
-  };
+LogicalResult CudaSyncThreadsDropIfThenElsePass::Run(ir::stmt::BlockRef block) {
+  DropIfThenElseMutator mutator;
+  mutator(block);
+  return LogicalResult::success();
+}
 
-  Mutator()(fn);
+std::unique_ptr<BlockPass> CreateCudaSyncThreadsDropIfThenElsePass() {
+  return std::make_unique<CudaSyncThreadsDropIfThenElsePass>();
 }
 
-class RestructureVarNodes : public ir::IRMutator<> {
+class RestructureVarNodes : public ir::IRMutator<>,
+                            public ir::stmt::StmtMutator<> {
  public:
-  void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
+  void operator()(ir::stmt::BlockRef block) { VisitBlock(block); }
 
  private:
   void Visit(const ir::Load *load, Expr *op) override {
     std::vector<ir::Expr> indices_copied;
     for (const ir::Expr &indice : load->indices) {
       indices_copied.push_back(ir::ir_utils::IRCopy(indice));
     }
     op->As<ir::Load>()->indices = indices_copied;
 
@@ -219,35 +295,74 @@ class RestructureVarNodes : public ir::IRMutator<> {
     IRMutator::Visit(load, op);
   }
 
-  void Visit(const ir::Store *store, Expr *op) override {
+  void VisitStmt(ir::stmt::Store stmt) override {
     std::vector<ir::Expr> indices_copied;
-    for (const ir::Expr &indice : store->indices) {
+    for (const ir::Expr &indice : stmt->indices()) {
       indices_copied.push_back(ir::ir_utils::IRCopy(indice));
     }
-    op->As<ir::Store>()->indices = indices_copied;
+    stmt->set_indices(indices_copied);
 
-    IRMutator::Visit(store, op);
+    ir::Expr value = stmt->value();
+    IRMutator::Visit(&value, &value);
+    stmt->set_value(value);
   }
+
+  void VisitStmt(ir::stmt::For stmt) override { operator()(stmt->body()); }
+
+  void VisitStmt(ir::stmt::IfThenElse stmt) override {
+    operator()(stmt->true_case());
+    if (stmt->false_case().defined()) {
+      operator()(stmt->false_case());
+    }
+  }
+
+  void VisitStmt(ir::stmt::Schedule stmt) override { operator()(stmt->body()); }
+
+  void VisitStmt(ir::stmt::Let stmt) override {
+    ir::Expr body = stmt->body();
+    IRMutator::Visit(&body, &body);
+    stmt->set_body(body);
+  }
+
+  void VisitStmt(ir::stmt::Alloc) override {}
+
+  void VisitStmt(ir::stmt::Evaluate) override {}
+
+  void VisitStmt(ir::stmt::Free) override {}
 };
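Why `RestructureVarNodes` deep-copies every index first — a sketch of the aliasing hazard it removes (illustrative buffers):

```cpp
// After lowering, a load and a store can share the very same index Expr:
//
//   shm[i * 4 + j] = global[i * 4 + j];
//          ^~~~~~~ one Expr node, referenced twice
//
// SharedAxisVisitor later rewrites store indices of GPUShared buffers
// (zeroing blockIdx); without IRCopy the load's index would silently
// change with it. Copying makes each access independent, so later passes
// can mutate one side without corrupting the other.
```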
 
-class ReplaceIndexToBindExpr : public ir::IRMutator<> {
+class ReplaceIndexToBindExpr {
  public:
-  void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
+  void operator()(ir::stmt::BlockRef block) {
+    for (ir::stmt::StmtRef stmt : block->stmts()) {
+      switch (stmt->stmt_type()) {
+        case ir::StmtNodeTy::For: {
+          operator()(stmt.as<ir::stmt::For>()->body());
+          break;
+        }
+        case ir::StmtNodeTy::Schedule: {
+          VisitStmt(stmt.as<ir::stmt::Schedule>());
+          break;
+        }
+        case ir::StmtNodeTy::IfThenElse: {
+          ir::stmt::IfThenElse if_node = stmt.as<ir::stmt::IfThenElse>();
+          operator()(if_node->true_case());
+          if (if_node->false_case().defined()) {
+            operator()(if_node->false_case());
+          }
+          break;
+        }
+        default:
+          break;
+      }
+    }
+  }
 
  private:
-  void Visit(const ir::ScheduleBlockRealize *op, Expr *expr) override {
-    ir::ScheduleBlockRealize *schedule_block_realize =
-        expr->As<ir::ScheduleBlockRealize>();
-    PADDLE_ENFORCE_NOT_NULL(
-        schedule_block_realize->schedule_block.As<ir::ScheduleBlock>(),
-        ::common::errors::InvalidArgument(
-            "The type of schedule block realize should be ScheduleBlock!"));
-    std::vector<ir::Expr> iter_values = schedule_block_realize->iter_values;
-    ir::Expr body =
-        schedule_block_realize->schedule_block.As<ir::ScheduleBlock>()->body;
-    std::vector<ir::Var> iter_vars =
-        schedule_block_realize->schedule_block.As<ir::ScheduleBlock>()
-            ->iter_vars;
+  void VisitStmt(ir::stmt::Schedule stmt) {
+    std::vector<ir::Expr> iter_values = stmt->iter_values();
+    std::vector<ir::Var> iter_vars = stmt->iter_vars();
+    ir::stmt::BlockRef body = stmt->body();
 
     PADDLE_ENFORCE_EQ(iter_values.size(),
                       iter_vars.size(),
@@ -257,24 +372,46 @@ class ReplaceIndexToBindExpr : public ir::IRMutator<> {
                           iter_values.size(),
                           iter_vars.size()));
     for (int idx = 0; idx < iter_values.size(); ++idx) {
-      ReplaceVarWithExpr(&body, iter_vars[idx], iter_values[idx]);
+      ReplaceVarWithExpr(
+          body, iter_vars[idx], iter_values[idx]);
     }
-    ir::IRMutator<>::Visit(&body, &body);
+    stmt->set_body(body);
+    operator()(stmt->body());
   }
 };
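The effect of `ReplaceIndexToBindExpr`, sketched on pseudo-IR:

```cpp
// Before: the schedule body indexes with the block-local iter var i0,
// which the schedule statement binds to `i * 4 + j`:
//
//   schedule[iter_vars = {i0}, iter_values = {i * 4 + j}] {
//     B[i0] = A[i0];
//   }
//
// After substituting each iter_var with its iter_value:
//
//   schedule[...] { B[i * 4 + j] = A[i * 4 + j]; }
//
// so ReplaceLoopVarToGpu below can rewrite `i`/`j` to blockIdx/threadIdx
// directly inside the buffer indices.
```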
 
-class ReplaceLoopVarToGpu : public ir::IRMutator<> {
+class ReplaceLoopVarToGpu {
  public:
-  void operator()(Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
+  void operator()(ir::stmt::BlockRef block) {
+    std::vector<ir::stmt::StmtRef> stmts = block->stmts();
+    for (ir::stmt::StmtRef stmt : stmts) {
+      switch (stmt->stmt_type()) {
+        case ir::StmtNodeTy::For: {
+          VisitStmt(stmt.as<ir::stmt::For>());
+          break;
+        }
+        case ir::StmtNodeTy::Schedule: {
+          operator()(stmt.as<ir::stmt::Schedule>()->body());
+          break;
+        }
+        case ir::StmtNodeTy::IfThenElse: {
+          ir::stmt::IfThenElse if_node = stmt.as<ir::stmt::IfThenElse>();
+          operator()(if_node->true_case());
+          if (if_node->false_case().defined()) {
+            operator()(if_node->false_case());
+          }
+          break;
+        }
+        default:
+          break;
+      }
+    }
+    block->set_stmts(stmts);
+  }
 
  private:
-  void Visit(const ir::For *op, Expr *expr) override {
-    auto for_ir = expr->As<ir::For>();
-    PADDLE_ENFORCE_NOT_NULL(for_ir,
-                            ::common::errors::InvalidArgument(
-                                "The type of expression should be For!"));
-
-    auto bind_info = for_ir->bind_info();
+  void VisitStmt(ir::stmt::For stmt) {
+    auto bind_info = stmt->bind_info();
 
     std::string var_name = "";
     if (bind_info.offset <= 0)
@@ -283,44 +420,49 @@ class ReplaceLoopVarToGpu : public ir::IRMutator<> {
       var_name = "y";
     else if (bind_info.offset == 2)
       var_name = "z";
-    if (for_ir->is_gpu_block_binded()) {
+    if (stmt->is_gpu_block_binded()) {
       var_name = "blockIdx." + var_name;
-      optim::ReplaceVarWithExpr(
-          expr, op->loop_var, ir::Expr(ir::Var(var_name)));
-    } else if (for_ir->is_gpu_thread_binded()) {
+      optim::ReplaceVarWithExpr(
+          stmt, stmt->loop_var(), ir::Expr(ir::Var(var_name)));
+    } else if (stmt->is_gpu_thread_binded()) {
       var_name = "threadIdx." + var_name;
-      optim::ReplaceVarWithExpr(
-          expr, op->loop_var, ir::Expr(ir::Var(var_name)));
+      optim::ReplaceVarWithExpr(
+          stmt, stmt->loop_var(), ir::Expr(ir::Var(var_name)));
     }
-    ir::IRMutator<>::Visit(&for_ir->body, &for_ir->body);
-  }
-  void Visit(const ir::PolyFor *op, Expr *expr) override {
-    PADDLE_THROW(::common::errors::InvalidArgument("Unknown PolyFor!"));
+    operator()(stmt->body());
   }
 };
 
-class SharedAxisVisitor : public ir::IRMutator<> {
+class SharedAxisVisitor : public ir::IRMutator<>,
+                          public ir::stmt::StmtMutator<> {
  public:
   void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
+  void operator()(ir::stmt::BlockRef block) {
+    ir::stmt::StmtMutator<>::VisitBlock(block);
+  }
 
  private:
-  void Visit(const ir::Store *op, Expr *expr) override {
-    auto store = expr->As<ir::Store>();
-    if (!store->tensor.as_tensor_ref()->buffer.defined()) {
+  void VisitStmt(ir::stmt::Store stmt) override {
+    if (!stmt->tensor().as_tensor_ref()->buffer.defined()) {
       return;
     }
 
-    if (store->tensor.as_tensor_ref()->buffer->memory_type ==
+    if (stmt->tensor().as_tensor_ref()->buffer->memory_type ==
         ir::MemoryType::GPUShared) {
-      for (auto &indice : store->indices) {
-        for (auto axis : gpu_axis) {
-          optim::ReplaceVarWithExpr(&indice, ir::Var(axis), ir::Expr(0));
+      std::vector<ir::Expr> indices = stmt->indices();
+      for (ir::Expr &index : indices) {
+        for (const std::string &axis : gpu_axis) {
+          optim::ReplaceVarWithExpr(
+              &index, ir::Var(axis), ir::Expr(0));
         }
-        indice = cinn::common::AutoSimplify(indice);
+        index = cinn::optim::ArithSimplify(index);
       }
+      stmt->set_indices(indices);
     }
-    ir::IRMutator<>::Visit(op, expr);
+    ir::Expr value = stmt->value();
+    ir::IRMutator<>::Visit(&value, &value);
+    stmt->set_value(value);
   }
 
   void Visit(const ir::Load *op, Expr *expr) override {
@@ -334,41 +476,92 @@ class SharedAxisVisitor : public ir::IRMutator<> {
 
     if (load->tensor.as_tensor_ref()->buffer->memory_type ==
         ir::MemoryType::GPUShared) {
-      for (auto &indice : load->indices) {
-        for (auto axis : gpu_axis) {
-          optim::ReplaceVarWithExpr(&indice, ir::Var(axis), ir::Expr(0));
+      for (auto &index : load->indices) {
+        for (const std::string &axis : gpu_axis) {
+          optim::ReplaceVarWithExpr(
+              &index, ir::Var(axis), ir::Expr(0));
        }
-        indice = cinn::common::AutoSimplify(indice);
+        index = cinn::optim::ArithSimplify(index);
      }
    }
    ir::IRMutator<>::Visit(op, expr);
  }
 
+  void VisitStmt(ir::stmt::For stmt) override {
+    ir::Expr min = stmt->min();
+    ir::Expr extent = stmt->extent();
+    operator()(&min);
+    operator()(&extent);
+    stmt->set_min(min);
+    stmt->set_extent(extent);
+    operator()(stmt->body());
+  }
+
+  void VisitStmt(ir::stmt::IfThenElse stmt) override {
+    ir::Expr condition = stmt->condition();
+    operator()(&condition);
+    stmt->set_condition(condition);
+
+    operator()(stmt->true_case());
+    if (stmt->false_case().defined()) {
+      operator()(stmt->false_case());
+    }
+  }
+
+  void VisitStmt(ir::stmt::Schedule stmt) override {
+    std::vector<ir::Expr> iter_values = stmt->iter_values();
+    for (ir::Expr &iter_value : iter_values) {
+      operator()(&iter_value);
+    }
+    stmt->set_iter_values(iter_values);
+    operator()(stmt->body());
+  }
+
+  void VisitStmt(ir::stmt::Let stmt) override {
+    ir::Expr body = stmt->body();
+    ir::IRMutator<>::Visit(&body, &body);
+    stmt->set_body(body);
+  }
+
+  void VisitStmt(ir::stmt::Alloc) override {}
+
+  void VisitStmt(ir::stmt::Evaluate) override {}
+
+  void VisitStmt(ir::stmt::Free) override {}
+
   const std::vector<std::string> gpu_axis = {
       "blockIdx.x", "blockIdx.y", "blockIdx.z"};
 };
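What `SharedAxisVisitor` changes, sketched with an illustrative buffer:

```cpp
// A GPUShared buffer is allocated per block, so a blockIdx term in its
// index can never address another block's copy; it is pinned to zero and
// the index re-simplified:
//
//   shm[blockIdx.x * 256 + threadIdx.x]   // before
//   shm[threadIdx.x]                      // after ArithSimplify
//
// LocalAxisVisitor below applies the same idea to GPULocal buffers, also
// zeroing threadIdx.*, since registers are private to one thread.
```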
 
-class LocalAxisVisitor : public ir::IRMutator<> {
+class LocalAxisVisitor : public ir::IRMutator<>,
+                         public ir::stmt::StmtMutator<> {
  public:
   void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
+  void operator()(ir::stmt::BlockRef block) {
+    ir::stmt::StmtMutator<>::VisitBlock(block);
+  }
 
  private:
-  void Visit(const ir::Store *op, Expr *expr) override {
-    auto store = expr->As<ir::Store>();
+  void VisitStmt(ir::stmt::Store stmt) override {
+    ir::Expr value = stmt->value();
+    operator()(&value);
+    stmt->set_value(value);
 
-    ir::IRMutator<>::Visit(op, expr);
-    if (!store->tensor.as_tensor_ref()->buffer.defined()) {
+    if (!stmt->tensor().as_tensor_ref()->buffer.defined()) {
       return;
     }
 
-    if (store->tensor.as_tensor_ref()->buffer->memory_type ==
+    if (stmt->tensor().as_tensor_ref()->buffer->memory_type ==
         ir::MemoryType::GPULocal) {
-      for (auto &indice : store->indices) {
-        for (auto axis : gpu_axis) {
-          optim::ReplaceVarWithExpr(&indice, ir::Var(axis), ir::Expr(0));
+      std::vector<ir::Expr> indices = stmt->indices();
+      for (ir::Expr &index : indices) {
+        for (const std::string &axis : gpu_axis) {
+          optim::ReplaceVarWithExpr(
+              &index, ir::Var(axis), ir::Expr(0));
         }
-        indice = cinn::common::AutoSimplify(indice);
+        index = cinn::optim::ArithSimplify(index);
       }
+      stmt->set_indices(indices);
     }
   }
 
@@ -384,16 +577,46 @@ class LocalAxisVisitor : public ir::IRMutator<> {
 
     if (load->tensor.as_tensor_ref()->buffer->memory_type ==
         ir::MemoryType::GPULocal) {
-      for (auto &indice : load->indices) {
-        for (auto axis : gpu_axis) {
-          optim::ReplaceVarWithExpr(&indice, ir::Var(axis), ir::Expr(0));
+      for (ir::Expr &index : load->indices) {
+        for (const std::string &axis : gpu_axis) {
+          optim::ReplaceVarWithExpr(&index, ir::Var(axis), ir::Expr(0));
        }
-        indice = cinn::common::AutoSimplify(indice);
+        index = cinn::optim::ArithSimplify(index);
      }
    }
    ir::IRMutator<>::Visit(op, expr);
  }
 
+  void VisitStmt(ir::stmt::For stmt) override { operator()(stmt->body()); }
+
+  void VisitStmt(ir::stmt::IfThenElse stmt) override {
+    operator()(stmt->true_case());
+    if (stmt->false_case().defined()) {
+      operator()(stmt->false_case());
+    }
+  }
+
+  void VisitStmt(ir::stmt::Schedule stmt) override {
+    std::vector<ir::Expr> iter_values = stmt->iter_values();
+    for (ir::Expr &iter_value : iter_values) {
+      operator()(&iter_value);
+    }
+    stmt->set_iter_values(iter_values);
+    operator()(stmt->body());
+  }
+
+  void VisitStmt(ir::stmt::Let stmt) override {
+    ir::Expr body = stmt->body();
+    ir::IRMutator<>::Visit(&body, &body);
+    stmt->set_body(body);
+  }
+
+  void VisitStmt(ir::stmt::Alloc) override {}
+
+  void VisitStmt(ir::stmt::Evaluate) override {}
+
+  void VisitStmt(ir::stmt::Free) override {}
+
   const std::vector<std::string> gpu_axis = {"blockIdx.x",
                                              "blockIdx.y",
                                              "blockIdx.z",
@@ -402,25 +625,32 @@ class LocalAxisVisitor : public ir::IRMutator<> {
                                              "threadIdx.y",
                                              "threadIdx.z"};
 };
 
-class ReplaceUnitVarToZero : public ir::IRMutator<> {
+class ReplaceUnitVarToZero : public ir::IRMutator<>,
+                             public ir::stmt::StmtMutator<> {
  public:
   void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
+  void operator()(ir::stmt::BlockRef block) {
+    ir::stmt::StmtMutator<>::VisitBlock(block);
+  }
 
  private:
-  void Visit(const ir::Store *op, Expr *expr) override {
-    auto store = expr->As<ir::Store>();
-    if (!store->tensor.as_tensor_ref()->buffer.defined()) {
+  void VisitStmt(ir::stmt::Store stmt) override {
+    if (!stmt->tensor().as_tensor_ref()->buffer.defined()) {
       return;
     }
 
-    auto &indices = store->indices;
-    for (auto &indice : indices) {
-      for (auto var_ : loop_var_) {
-        optim::ReplaceVarWithExpr(&indice, ir::Var(var_), ir::Expr(0));
+    std::vector<ir::Expr> indices = stmt->indices();
+    for (ir::Expr &index : indices) {
+      for (const std::string &var_ : loop_var_) {
+        optim::ReplaceVarWithExpr(
+            &index, ir::Var(var_), ir::Expr(0));
       }
-      indice = cinn::common::AutoSimplify(indice);
+      index = cinn::optim::ArithSimplify(index);
     }
-    ir::IRMutator<>::Visit(op, expr);
+    stmt->set_indices(indices);
+    ir::Expr value = stmt->value();
+    operator()(&value);
+    stmt->set_value(value);
   }
 
   void Visit(const ir::Load *op, Expr *expr) override {
@@ -430,87 +660,116 @@ class ReplaceUnitVarToZero : public ir::IRMutator<> {
     }
 
     auto &indices = load->indices;
-    for (auto &indice : indices) {
-      for (auto var_ : loop_var_) {
-        optim::ReplaceVarWithExpr(&indice, ir::Var(var_), ir::Expr(0));
+    for (auto &index : indices) {
+      for (const std::string &var_ : loop_var_) {
+        optim::ReplaceVarWithExpr(
+            &index, ir::Var(var_), ir::Expr(0));
      }
-      indice = cinn::common::AutoSimplify(indice);
+      index = cinn::optim::ArithSimplify(index);
    }
    ir::IRMutator<>::Visit(op, expr);
  }
 
-  void Visit(const ir::For *op, Expr *expr) override {
-    PADDLE_ENFORCE_NOT_NULL(expr->As<ir::For>(),
-                            ::common::errors::InvalidArgument(
-                                "The type of expression should be For!"));
-    auto for_ir = expr->As<ir::For>();
-    auto var_name = for_ir->loop_var->name;
-    auto extent_i = for_ir->extent;
+  void VisitStmt(ir::stmt::For stmt) override {
+    auto var_name = stmt->loop_var()->name;
+    auto extent_i = stmt->extent();
 
     if (extent_i.is_constant() && extent_i.as_int64() == 1)
       loop_var_.insert(var_name);
-    ir::IRMutator<>::Visit(op, expr);
+    operator()(stmt->body());
     loop_var_.erase(var_name);
   }
+
+  void VisitStmt(ir::stmt::IfThenElse stmt) override {
+    operator()(stmt->true_case());
+    if (stmt->false_case().defined()) {
+      operator()(stmt->false_case());
+    }
+  }
+
+  void VisitStmt(ir::stmt::Schedule stmt) override {
+    std::vector<ir::Expr> iter_values = stmt->iter_values();
+    for (ir::Expr &iter_value : iter_values) {
+      operator()(&iter_value);
+    }
+    stmt->set_iter_values(iter_values);
+    operator()(stmt->body());
+  }
+
+  void VisitStmt(ir::stmt::Let stmt) override {
+    ir::Expr body = stmt->body();
+    ir::IRMutator<>::Visit(&body, &body);
+    stmt->set_body(body);
+  }
+
+  void VisitStmt(ir::stmt::Alloc) override {}
+
+  void VisitStmt(ir::stmt::Evaluate) override {}
+
+  void VisitStmt(ir::stmt::Free) override {}
+
   std::unordered_set<std::string> loop_var_;
 };
 
-void OptimizeExprGPU(Expr *expr) {
-  VLOG(4) << "Before Optimize Expr:\n" << *expr;
+void OptimizeExprGPU(ir::stmt::BlockRef block) {
+  VLOG(4) << "Before Optimize Expr:\n" << block;
 
   // Make independent copies for each load/store's indices to prevent cross
   // modification in later passes.
   RestructureVarNodes restructure_var_nodes;
-  restructure_var_nodes(expr);
+  restructure_var_nodes(block);
 
-  // Replace iter_vars used in ScheduleBlocks to their corresponding iter_values
-  // in ScheduleBlockRealizes.
+  // Replace iter_vars used in ScheduleBlocks to their corresponding
+  // iter_values in ScheduleBlockRealizes.
   ReplaceIndexToBindExpr replace_index_to_bind_expr;
-  replace_index_to_bind_expr(expr);
+  replace_index_to_bind_expr(block);
 
   // resize buffer axis
   BlockPassManager pass_manager;
-  ir::stmt::BlockRef _block = ir::ConvertExprBlockToStmtBlock(*expr);
   pass_manager.AddPass(optim::CreateUpdateBufferAxisPass());
-  pass_manager.Run(_block);
-  ir::Expr new_expr = ir::ConvertStmtBlockToExprBlock(_block);
-  *expr = new_expr;
+  pass_manager.Run(block);
 
-  // Replace variables bound on block/thread to the actual blockIdx/threadIdx.
+  // Replace variables bound on block/thread to the actual
+  // blockIdx/threadIdx.
+  LOG(INFO) << "Before ReplaceLoopVarToGpu: \n" << block;
   ReplaceLoopVarToGpu replace_loop_var_to_gpu;
-  replace_loop_var_to_gpu(expr);
+  replace_loop_var_to_gpu(block);
+  LOG(INFO) << "After ReplaceLoopVarToGpu: \n" << block;
 
-  // Replace blockIdx in shared memory's indices to zero, because shared memory
-  // cannot be accessed from another block.
+  // Replace blockIdx in shared memory's indices to zero, because shared
+  // memory cannot be accessed from another block.
   SharedAxisVisitor shared_axis_visitor;
-  shared_axis_visitor(expr);
+  shared_axis_visitor(block);
 
-  // Replace blockIdx/threadIdx in local buffer's indices to zero, because local
-  // buffers cannot be accessed from another block/thread.
+  // Replace blockIdx/threadIdx in local buffer's indices to zero, because
+  // local buffers cannot be accessed from another block/thread.
   LocalAxisVisitor local_axis_visitor;
-  local_axis_visitor(expr);
+  local_axis_visitor(block);
 
   // Replace variables that are in range [0, 1) to zero.
   ReplaceUnitVarToZero replace_unit_var_to_zero;
-  replace_unit_var_to_zero(expr);
-  VLOG(10) << "After ReplaceUnitVarToZero: \n" << *expr;
-  ir::stmt::BlockRef func_body = ir::ConvertExprBlockToStmtBlock(*expr);
-  EliminateCommonFactorOfLocalIndex(func_body);
-  *expr = ir::ConvertStmtBlockToExprBlock(func_body);
-  VLOG(10) << "After EliminateCommonFactorOfLocalIndex: \n" << *expr;
+  replace_unit_var_to_zero(block);
+
+  EliminateCommonFactorOfLocalIndex(block);
+  VLOG(10) << "After EliminateCommonFactorOfLocalIndex: \n" << block;
+
+  ir::Expr expr = ir::ConvertStmtBlockToExprBlock(block);
 
-  ResizeBufferToMaxVarRange(expr);
+  ResizeBufferToMaxVarRange(&expr);
+
+  // Write the resized body back through the caller-visible BlockRef;
+  // rebinding the local `block` would leave the caller's block unchanged.
+  block->set_stmts(ir::ConvertExprBlockToStmtBlock(expr)->stmts());
 
   if (FLAGS_cinn_longlong2int) {
-    ir::stmt::BlockRef block = ir::ConvertExprBlockToStmtBlock(*expr);
     VLOG(10) << "Before CastLonglong2Int: \n" << block;
     TryCastLonglong2Int(block);
     VLOG(10) << "After CastLonglong2Int: \n" << block;
-    *expr = ir::ConvertStmtBlockToExprBlock(block);
   }
 
-  VLOG(4) << "After Optimize Expr: \n" << *expr;
+  VLOG(4) << "After Optimize Expr: \n" << block;
 }
 
 }  // namespace optim
 }  // namespace cinn
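How the two rewrites are driven after this change — a minimal sketch mirroring the `optimize.cc` hunks above (assumes `fn` is an `ir::LoweredFunc` whose `cuda_axis_info` has already been set):

```cpp
// Function-level pass: guard insertion needs fn->cuda_axis_info
// (grid/block dims), so it runs on the whole LoweredFunc.
FuncPassManager func_pm;
func_pm.AddPass(cinn::optim::CreateRemoveGpuForLoopsPass());
func_pm.Run(fn);

// Block-level pass: dropping the `if (xxx == 0) __syncthreads();`
// wrapper is purely structural, so it runs on the body block alone.
BlockPassManager block_pm;
block_pm.AddPass(cinn::optim::CreateCudaSyncThreadsDropIfThenElsePass());
block_pm.Run(fn->body_block);
```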
diff --git a/paddle/cinn/optim/transform_gpu_forloop.h b/paddle/cinn/optim/transform_gpu_forloop.h
index c9f3aadaccb01b..b48162f145b1ae 100644
--- a/paddle/cinn/optim/transform_gpu_forloop.h
+++ b/paddle/cinn/optim/transform_gpu_forloop.h
@@ -13,14 +13,10 @@
 // limitations under the License.
 
 #pragma once
-#include
-#include
-#include
 
 #include "paddle/cinn/ir/ir.h"
 #include "paddle/cinn/ir/lowered_func.h"
-#include "paddle/cinn/poly/isl_utils.h"
-#include "paddle/cinn/poly/stage.h"
+#include "paddle/cinn/pass/pass.h"
 
 namespace cinn {
 namespace optim {
@@ -40,7 +36,7 @@ namespace optim {
  * 2) Buffer and Memory Access Optimization
  * 3) Expression Simplification and Type Casting
  */
-void OptimizeExprGPU(Expr* expr);
+void OptimizeExprGPU(ir::stmt::BlockRef func_body);
 
 /**
  * Remove the GPU block/thread-bound For loops, add IfThenElse guards if needed.
@@ -70,7 +66,13 @@ void OptimizeExprGPU(Expr* expr);
  *
  * @param fn The LoweredFunc to process.
  */
-void RemoveGpuForLoops(ir::LoweredFunc fn);
+class RemoveGpuForLoopsPass : public FuncPass {
+ public:
+  RemoveGpuForLoopsPass() : FuncPass("remove_gpu_for_loops") {}
+
+  LogicalResult Run(ir::LoweredFunc fn) override;
+};
+std::unique_ptr<FuncPass> CreateRemoveGpuForLoopsPass();
 
 /**
  * Removes conditional wrappers around CUDA thread synchronization calls.
@@ -101,7 +103,14 @@ void RemoveGpuForLoops(ir::LoweredFunc fn);
  * =>
 * if (xxxx > 0) { __syncthreads(); }
  */
-void CudaSyncThreadsDropIfThenElse(ir::LoweredFunc fn);
+class CudaSyncThreadsDropIfThenElsePass : public BlockPass {
+ public:
+  CudaSyncThreadsDropIfThenElsePass()
+      : BlockPass("cuda_sync_threads_drop_ifthenelse") {}
+
+  LogicalResult Run(ir::stmt::BlockRef block) override;
+};
+std::unique_ptr<BlockPass> CreateCudaSyncThreadsDropIfThenElsePass();
 
 }  // namespace optim
 }  // namespace cinn
diff --git a/paddle/cinn/pybind/lang.cc b/paddle/cinn/pybind/lang.cc
index e486b1915afa46..76bc54f0ba34ac 100644
--- a/paddle/cinn/pybind/lang.cc
+++ b/paddle/cinn/pybind/lang.cc
@@ -23,6 +23,7 @@
 #include "paddle/cinn/ir/schedule/ir_schedule.h"
 #include "paddle/cinn/ir/schedule/ir_schedule_util.h"
 #include "paddle/cinn/ir/tensor.h"
+#include "paddle/cinn/ir/utils/stmt_converter.h"
 #include "paddle/cinn/lang/buffer.h"
 #include "paddle/cinn/lang/builtin.h"
 #include "paddle/cinn/lang/compute.h"
@@ -134,7 +135,15 @@ void BindModule(py::module *m) {
       [&](common::NVGPUArch) {
 #ifdef CINN_WITH_CUDA
         ir::SetCudaAxisInfo(func);
-        optim::OptimizeExprGPU(&(func->body));
+        ir::stmt::BlockRef func_body_block =
+            ir::ConvertExprBlockToStmtBlock(func->body);
+        VLOG(6) << "Before OptimizeExprGPU in lang: \n"
+                << func_body_block;
+        optim::OptimizeExprGPU(func_body_block);
+        VLOG(6) << "After OptimizeExprGPU in lang: \n"
+                << func_body_block;
+        func->body =
+            ir::ConvertStmtBlockToExprBlock(func_body_block);
 #endif
       },
       [&](std::variant