From 43d14f2774cd541822453a1db6962ff3498f0ada Mon Sep 17 00:00:00 2001 From: ZhouXin Date: Thu, 19 Dec 2024 02:27:26 +0000 Subject: [PATCH 1/5] Update remove schedule block --- paddle/cinn/optim/CMakeLists.txt | 2 +- paddle/cinn/optim/optimize.cc | 5 +- .../cinn/optim/remove_schedule_block_pass.cc | 58 +++++++++ .../cinn/optim/remove_schedule_block_pass.h | 29 +++++ paddle/cinn/optim/replace_var_with_expr.cc | 112 ++++++++++++++++++ paddle/cinn/optim/replace_var_with_expr.h | 10 ++ 6 files changed, 213 insertions(+), 3 deletions(-) create mode 100644 paddle/cinn/optim/remove_schedule_block_pass.cc create mode 100644 paddle/cinn/optim/remove_schedule_block_pass.h diff --git a/paddle/cinn/optim/CMakeLists.txt b/paddle/cinn/optim/CMakeLists.txt index 793e843023f47a..6d2ae9b159df89 100755 --- a/paddle/cinn/optim/CMakeLists.txt +++ b/paddle/cinn/optim/CMakeLists.txt @@ -24,7 +24,7 @@ gather_srcs( lower_intrin.cc cast_bool_to_int8.cc var_mod_simplify.cc - remove_schedule_block.cc + remove_schedule_block_pass.cc replace_cross_block_reduction.cc replace_cross_thread_reduction.cc replace_mod_to_max.cc diff --git a/paddle/cinn/optim/optimize.cc b/paddle/cinn/optim/optimize.cc index d69b4d3e06ad2c..7b5dd89d045690 100644 --- a/paddle/cinn/optim/optimize.cc +++ b/paddle/cinn/optim/optimize.cc @@ -31,7 +31,7 @@ #include "paddle/cinn/optim/lower_intrin.h" #include "paddle/cinn/optim/map_extern_call.h" #include "paddle/cinn/optim/rearrange_load_instruction.h" -#include "paddle/cinn/optim/remove_schedule_block.h" +#include "paddle/cinn/optim/remove_schedule_block_pass.h" #include "paddle/cinn/optim/replace_const_param_to_integer.h" #include "paddle/cinn/optim/replace_cross_block_reduction.h" #include "paddle/cinn/optim/replace_cross_thread_reduction.h" @@ -116,7 +116,8 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn, Simplify(&copied->body); VLOG(10) << "After Optimize Simplify" << copied; - RemoveScheduleBlock(&copied->body); + pass_manager.AddPass(CreateRemoveScheduleBlockPass()); + pass_manager.Run(copied); VLOG(10) << "After RemoveScheduleBlock:" << copied; LowerIntrin(&copied->body, target); diff --git a/paddle/cinn/optim/remove_schedule_block_pass.cc b/paddle/cinn/optim/remove_schedule_block_pass.cc new file mode 100644 index 00000000000000..0144a81395c809 --- /dev/null +++ b/paddle/cinn/optim/remove_schedule_block_pass.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/optim/remove_schedule_block_pass.h" +#include "paddle/cinn/optim/replace_var_with_expr.h" +namespace cinn { +namespace optim { + +LogicalResult RemoveScheduleBlockPass::Run(ir::stmt::BlockRef block) { + auto merge_stmt_vector = [&](std::vector& dest, + const std::vector& source) { + dest.insert(dest.end(), source.begin(), source.end()); + }; + + const auto& stmts = block->stmts(); + std::vector new_stmts; + for (auto& stmt : stmts) { + if (!stmt.isa()) { + new_stmts.push_back(stmt); + continue; + } + auto schedule_stmt = stmt.as(); + auto iter_values = schedule_stmt->iter_values(); + auto body = schedule_stmt->body(); + auto iter_vars = schedule_stmt->iter_vars(); + PADDLE_ENFORCE_EQ(iter_vars.size(), + iter_values.size(), + ::common::errors::InvalidArgument( + "The size of iter vars and iter values is not equal," + "where iter vars:%d but iter values:%d.", + iter_vars.size(), + iter_values.size())); + for (int i = 0; i < iter_vars.size(); i++) { + optim::ReplaceVarWithExprInBlock(body, iter_vars[i], iter_values[i], ""); + } + merge_stmt_vector(new_stmts, body->stmts()); + } + block->set_stmts(new_stmts); + return LogicalResult::success(); +} + +std::unique_ptr CreateRemoveScheduleBlockPass() { + return std::make_unique(); +} + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/remove_schedule_block_pass.h b/paddle/cinn/optim/remove_schedule_block_pass.h new file mode 100644 index 00000000000000..920336b555a6b2 --- /dev/null +++ b/paddle/cinn/optim/remove_schedule_block_pass.h @@ -0,0 +1,29 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/cinn/pass/pass.h" + +namespace cinn { +namespace optim { +class RemoveScheduleBlockPass : public BlockPass { + public: + RemoveScheduleBlockPass() : BlockPass("remove_schedule_block") {} + LogicalResult Run(ir::stmt::BlockRef block) override; +}; + +std::unique_ptr CreateRemoveScheduleBlockPass(); + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/replace_var_with_expr.cc b/paddle/cinn/optim/replace_var_with_expr.cc index 61f4db4e7b89ac..333f377da4c88f 100644 --- a/paddle/cinn/optim/replace_var_with_expr.cc +++ b/paddle/cinn/optim/replace_var_with_expr.cc @@ -27,6 +27,102 @@ namespace cinn { namespace optim { +struct ReplaceVarWithExprStmtMutator : public ir::stmt::StmtMutator<> { + ReplaceVarWithExprStmtMutator(const Var& var, + const Expr& expr, + const std::string& tensor_name) + : var_(var), expr_(expr), tensor_name_(tensor_name) {} + + void operator()(ir::stmt::StmtRef stmt) { + if (tensor_name_.empty()) visit_all_ = true; + ir::stmt::StmtMutator<>::VisitStmt(stmt); + } + + void operator()(ir::stmt::BlockRef block) { + if (tensor_name_.empty()) visit_all_ = true; + ir::stmt::StmtMutator<>::VisitBlock(block); + } + + private: + bool ShouldReplaceExpr(const Expr& expr) { + if (!expr.is_var()) { + return false; + } + if (expr.as_var()->name == var_->name && (do_replace_ || visit_all_)) { + return true; + } + return false; + } + + void VisitExpr(ir::Expr* expr) { + if (expr->as_var()->name == var_->name && (do_replace_ || visit_all_)) { + auto copied = ir::ir_utils::IRCopy(expr_); + *expr = copied; + } + } + + void VisitStmt(ir::stmt::Let stmt) override { return; } + + void VisitStmt(ir::stmt::Store stmt) override { + auto* tensor = stmt->tensor().as_tensor(); + if (tensor && tensor->name == tensor_name_) { + do_replace_ = true; + } else { + do_replace_ = false; + } + std::vector new_indices = stmt->indices(); + for (size_t i = 0; i < new_indices.size(); ++i) { + if (ShouldReplaceExpr(new_indices[i])) { + auto copied = ir::ir_utils::IRCopy(expr_); + new_indices[i] = copied; + } + } + stmt->set_indices(new_indices); + do_replace_ = false; + if (ShouldReplaceExpr(stmt->tensor())) { + auto copied = ir::ir_utils::IRCopy(expr_); + stmt->set_tensor(copied); + } + + if (ShouldReplaceExpr(stmt->value())) { + auto copied = ir::ir_utils::IRCopy(expr_); + stmt->set_value(copied); + } + } + + void VisitStmt(ir::stmt::For stmt) override { + if (ShouldReplaceExpr(stmt->min())) { + auto copied = ir::ir_utils::IRCopy(expr_); + stmt->set_min(copied); + } + if (ShouldReplaceExpr(stmt->extent())) { + auto copied = ir::ir_utils::IRCopy(expr_); + stmt->set_extent(copied); + } + VisitBlock(stmt->body()); + if (stmt->loop_var()->name == var_->name && expr_.as_var() && visit_all_) { + stmt->set_loop_var(expr_.as_var_ref()); + } + } + + void VisitStmt(ir::stmt::Alloc stmt) override { return; } + + void VisitStmt(ir::stmt::Free stmt) override { return; } + + void VisitStmt(ir::stmt::IfThenElse stmt) override { return; } + + void VisitStmt(ir::stmt::Evaluate) override { return; } + + void VisitStmt(ir::stmt::Schedule stmt) override { return; } + + private: + bool do_replace_{false}; + bool visit_all_{false}; + const Var& var_; + const Expr& expr_; + const std::string& tensor_name_; +}; + struct ReplaceVarWithExprMutator : public ir::IRMutator<> { ReplaceVarWithExprMutator(const Var& var, const Expr& expr, @@ -107,6 +203,22 @@ struct ReplaceVarWithExprMutator : public ir::IRMutator<> { const std::string& tensor_name_; }; +void ReplaceVarWithExprInStmt(ir::stmt::StmtRef source, + const Var& var, + const Expr& expr, + const std::string& tensor_name) { + ReplaceVarWithExprStmtMutator mutator(var, expr, tensor_name); + mutator(source); +} + +void ReplaceVarWithExprInBlock(ir::stmt::BlockRef source, + const Var& var, + const Expr& expr, + const std::string& tensor_name) { + ReplaceVarWithExprStmtMutator mutator(var, expr, tensor_name); + mutator(source); +} + void ReplaceVarWithExpr(Expr* source, const Var& var, const Expr& expr, diff --git a/paddle/cinn/optim/replace_var_with_expr.h b/paddle/cinn/optim/replace_var_with_expr.h index c56848f358052d..c226598bddd587 100644 --- a/paddle/cinn/optim/replace_var_with_expr.h +++ b/paddle/cinn/optim/replace_var_with_expr.h @@ -19,6 +19,7 @@ #include #include "paddle/cinn/ir/ir.h" +#include "paddle/cinn/ir/stmt.h" namespace cinn { namespace optim { @@ -59,6 +60,15 @@ void ReplaceVarWithExpr(Expr *source, const Expr &expr, const std::string &tensor_name = ""); +void ReplaceVarWithExprInStmt(ir::stmt::StmtRef source, + const Var &var, + const Expr &expr, + const std::string &tensor_name = ""); +void ReplaceVarWithExprInBlock(ir::stmt::BlockRef source, + const Var &var, + const Expr &expr, + const std::string &tensor_name = ""); + /** * Collect the specific tensor's indices. * @param tensor_name The specific tensor's name. From bc5c665dcfd657e005d30c45d52ca4b8d64d311a Mon Sep 17 00:00:00 2001 From: ZhouXin Date: Thu, 19 Dec 2024 03:27:55 +0000 Subject: [PATCH 2/5] delete old version of removeScheduleBlock --- paddle/cinn/optim/remove_schedule_block.cc | 63 ---------------------- paddle/cinn/optim/remove_schedule_block.h | 60 --------------------- 2 files changed, 123 deletions(-) delete mode 100644 paddle/cinn/optim/remove_schedule_block.cc delete mode 100644 paddle/cinn/optim/remove_schedule_block.h diff --git a/paddle/cinn/optim/remove_schedule_block.cc b/paddle/cinn/optim/remove_schedule_block.cc deleted file mode 100644 index 397b8f9399b379..00000000000000 --- a/paddle/cinn/optim/remove_schedule_block.cc +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) 2021 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/cinn/optim/remove_schedule_block.h" - -#include "paddle/cinn/ir/ir_mutator.h" -#include "paddle/cinn/ir/ir_printer.h" -#include "paddle/cinn/ir/module.h" -#include "paddle/cinn/optim/replace_var_with_expr.h" - -namespace cinn { -namespace optim { - -struct ScheduleBlockRemover : public ir::IRMutator { - void operator()(Expr* expr) { ir::IRMutator::Visit(expr, expr); } - - private: - void Visit(const ir::ScheduleBlockRealize* op, Expr* expr) override { - auto* node = expr->As(); - PADDLE_ENFORCE_NOT_NULL( - node, - ::common::errors::InvalidArgument( - "The expression could not be cast to ir::ScheduleBlockRealize. " - "Please check the expression type.")); - auto& iter_values = node->iter_values; - auto* schedule_block = node->schedule_block.As(); - PADDLE_ENFORCE_NOT_NULL( - schedule_block, - ::common::errors::InvalidArgument( - "The schedule block could not be cast to ir::ScheduleBlock. Please " - "check the schedule block type.")); - auto& iter_vars = schedule_block->iter_vars; - Expr body = schedule_block->body; - PADDLE_ENFORCE_EQ(iter_vars.size(), - iter_values.size(), - ::common::errors::InvalidArgument( - "The size of iter vars and iter values is not equal," - "where iter vars:%d but iter values:%d.", - iter_vars.size(), - iter_values.size())); - for (int i = 0; i < iter_vars.size(); i++) { - optim::ReplaceVarWithExpr(&body, iter_vars[i], iter_values[i]); - } - *expr = body; - IRMutator::Visit(expr, expr); - } -}; - -void RemoveScheduleBlock(ir::Expr* expr) { ScheduleBlockRemover()(expr); } - -} // namespace optim -} // namespace cinn diff --git a/paddle/cinn/optim/remove_schedule_block.h b/paddle/cinn/optim/remove_schedule_block.h deleted file mode 100644 index 73ad280acb30c9..00000000000000 --- a/paddle/cinn/optim/remove_schedule_block.h +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) 2021 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/** - * This file implements the strategy to remove the unnecessary nested block. - */ -#pragma once -#include - -#include "paddle/cinn/common/common.h" -#include "paddle/cinn/ir/ir.h" -#include "paddle/common/enforce.h" -namespace cinn { -namespace optim { - -/** - * Removes ScheduleBlock nodes from the IR tree. - * - * This pass is applicable in scenarios where ScheduleBlock nodes are present in - * the IR tree but are no longer needed for further optimization. - * - * When applied, this pass will traverse the IR tree and replace each - * ScheduleBlockRealize node with its body. During this process, it will also - * replace the iter_vars in the body with their corresponding iter_values. This - * effectively removes the ScheduleBlock structure while preserving the - * computational logic within it. - * - * Performance impact: This pass addresses the overhead of maintaining - * ScheduleBlock structures in the IR. By removing these structures, it - * simplifies the IR, which can lead to faster subsequent passes and potentially - * more efficient code generation. - * - * Examples: - * 1. Basic ScheduleBlock removal: - * Input IR: - * ScheduleBlockRealize { - * iter_vars: [i, j] - * iter_values: [0, 1] - * ScheduleBlock { - * body: A[i, j] = B[i, j] + C[i, j] - * } - * } - * Output IR: - * A[0, 1] = B[0, 1] + C[0, 1] - */ -void RemoveScheduleBlock(ir::Expr *expr); - -} // namespace optim -} // namespace cinn From ad82946473a55cd70f1adadb7f6ee240a78ca43c Mon Sep 17 00:00:00 2001 From: ZhouXin Date: Fri, 20 Dec 2024 02:44:25 +0000 Subject: [PATCH 3/5] Fix Expr case in ReplaceVarWithExpr --- .../cinn/optim/remove_schedule_block_pass.cc | 1 + paddle/cinn/optim/replace_var_with_expr.cc | 140 ++++++++++++++---- 2 files changed, 109 insertions(+), 32 deletions(-) diff --git a/paddle/cinn/optim/remove_schedule_block_pass.cc b/paddle/cinn/optim/remove_schedule_block_pass.cc index 0144a81395c809..5488d16efd799e 100644 --- a/paddle/cinn/optim/remove_schedule_block_pass.cc +++ b/paddle/cinn/optim/remove_schedule_block_pass.cc @@ -18,6 +18,7 @@ namespace cinn { namespace optim { LogicalResult RemoveScheduleBlockPass::Run(ir::stmt::BlockRef block) { + VLOG(4) << "RemoveScheduleBlockPass Run"; auto merge_stmt_vector = [&](std::vector& dest, const std::vector& source) { dest.insert(dest.end(), source.begin(), source.end()); diff --git a/paddle/cinn/optim/replace_var_with_expr.cc b/paddle/cinn/optim/replace_var_with_expr.cc index 333f377da4c88f..e6577c8837f8d2 100644 --- a/paddle/cinn/optim/replace_var_with_expr.cc +++ b/paddle/cinn/optim/replace_var_with_expr.cc @@ -34,13 +34,17 @@ struct ReplaceVarWithExprStmtMutator : public ir::stmt::StmtMutator<> { : var_(var), expr_(expr), tensor_name_(tensor_name) {} void operator()(ir::stmt::StmtRef stmt) { + VLOG(4) << "Enter ReplaceVarWithExprStmtMutator::operator()"; if (tensor_name_.empty()) visit_all_ = true; ir::stmt::StmtMutator<>::VisitStmt(stmt); + VLOG(4) << "Exit ReplaceVarWithExprStmtMutator::operator()"; } void operator()(ir::stmt::BlockRef block) { + VLOG(4) << "Enter ReplaceVarWithExprStmtMutator::operator()"; if (tensor_name_.empty()) visit_all_ = true; ir::stmt::StmtMutator<>::VisitBlock(block); + VLOG(4) << "Exit ReplaceVarWithExprStmtMutator::operator()"; } private: @@ -55,15 +59,37 @@ struct ReplaceVarWithExprStmtMutator : public ir::stmt::StmtMutator<> { } void VisitExpr(ir::Expr* expr) { - if (expr->as_var()->name == var_->name && (do_replace_ || visit_all_)) { - auto copied = ir::ir_utils::IRCopy(expr_); - *expr = copied; + VLOG(4) << "Enter VisitExpr(ir::Expr* expr)"; + if (expr->is_var()) { + VLOG(4) << "Hit var: " << expr->as_var()->name; + if (expr->as_var()->name == var_->name && (do_replace_ || visit_all_)) { + auto copied = ir::ir_utils::IRCopy(expr_); + *expr = copied; + } + return; } + for (auto field : (*expr)->expr_fields()) { + VLOG(4) << "field: " << *field; + VisitExpr(field); + } + VLOG(4) << "Exit VisitExpr(ir::Expr* expr)"; } - void VisitStmt(ir::stmt::Let stmt) override { return; } + void VisitStmt(ir::stmt::Let stmt) override { + VLOG(4) << "Enter VisitStmt(ir::stmt::Let stmt)"; + Expr symbol = stmt->symbol(); + VisitExpr(&symbol); + stmt->set_symbol(symbol); + if (stmt->body().defined()) { + Expr body = stmt->body(); + VisitExpr(&body); + stmt->set_body(body); + } + VLOG(4) << "Exit VisitStmt(ir::stmt::Let stmt)"; + } void VisitStmt(ir::stmt::Store stmt) override { + VLOG(4) << "Enter VisitStmt(ir::stmt::Store stmt)"; auto* tensor = stmt->tensor().as_tensor(); if (tensor && tensor->name == tensor_name_) { do_replace_ = true; @@ -71,49 +97,95 @@ struct ReplaceVarWithExprStmtMutator : public ir::stmt::StmtMutator<> { do_replace_ = false; } std::vector new_indices = stmt->indices(); - for (size_t i = 0; i < new_indices.size(); ++i) { - if (ShouldReplaceExpr(new_indices[i])) { - auto copied = ir::ir_utils::IRCopy(expr_); - new_indices[i] = copied; - } + for (Expr& index : new_indices) { + VLOG(4) << "index: " << index; + VisitExpr(&index); } stmt->set_indices(new_indices); do_replace_ = false; - if (ShouldReplaceExpr(stmt->tensor())) { - auto copied = ir::ir_utils::IRCopy(expr_); - stmt->set_tensor(copied); - } - - if (ShouldReplaceExpr(stmt->value())) { - auto copied = ir::ir_utils::IRCopy(expr_); - stmt->set_value(copied); - } + VLOG(4) << "stmt->tensor(): " << stmt->tensor(); + Expr tensor_expr = stmt->tensor(); + VisitExpr(&tensor_expr); + stmt->set_tensor(tensor_expr); + + Expr value = stmt->value(); + VisitExpr(&value); + stmt->set_value(value); + VLOG(4) << "Exit VisitStmt(ir::stmt::Store stmt)"; } void VisitStmt(ir::stmt::For stmt) override { - if (ShouldReplaceExpr(stmt->min())) { - auto copied = ir::ir_utils::IRCopy(expr_); - stmt->set_min(copied); - } - if (ShouldReplaceExpr(stmt->extent())) { - auto copied = ir::ir_utils::IRCopy(expr_); - stmt->set_extent(copied); - } + VLOG(4) << "Enter VisitStmt(ir::stmt::For stmt)"; + VLOG(4) << "stmt->min(): " << stmt->min(); + Expr min = stmt->min(); + VisitExpr(&min); + VLOG(4) << "stmt->extent(): " << stmt->extent(); + Expr extent = stmt->extent(); + VisitExpr(&extent); + VLOG(4) << "stmt->body(): " << stmt->body(); VisitBlock(stmt->body()); if (stmt->loop_var()->name == var_->name && expr_.as_var() && visit_all_) { - stmt->set_loop_var(expr_.as_var_ref()); + auto copied = ir::ir_utils::IRCopy(expr_); + stmt->set_loop_var(copied.as_var_ref()); } + VLOG(4) << "Exit VisitStmt(ir::stmt::For stmt)"; + } + + void VisitStmt(ir::stmt::Alloc stmt) override { + VLOG(4) << "Enter VisitStmt(ir::stmt::Alloc stmt)"; + return; } - void VisitStmt(ir::stmt::Alloc stmt) override { return; } + void VisitStmt(ir::stmt::Free stmt) override { + VLOG(4) << "Enter VisitStmt(ir::stmt::Free stmt)"; + return; + } - void VisitStmt(ir::stmt::Free stmt) override { return; } + void VisitStmt(ir::stmt::IfThenElse stmt) override { + VLOG(4) << "Enter VisitStmt(ir::stmt::IfThenElse stmt)"; + Expr condition = stmt->condition(); + VisitExpr(&condition); + ir::stmt::BlockRef true_case = stmt->true_case(); + VisitBlock(true_case); + stmt->set_true_case(true_case); + if (stmt->false_case().defined()) { + ir::stmt::BlockRef false_case = stmt->false_case(); + VisitBlock(false_case); + stmt->set_false_case(false_case); + } + } - void VisitStmt(ir::stmt::IfThenElse stmt) override { return; } + void VisitStmt(ir::stmt::Evaluate) override { + VLOG(4) << "Enter VisitStmt(ir::stmt::Evaluate)"; + return; + } - void VisitStmt(ir::stmt::Evaluate) override { return; } + void VisitStmt(ir::stmt::Schedule stmt) override { + VLOG(4) << "Enter VisitStmt(ir::stmt::Schedule stmt)"; + std::vector vars = stmt->iter_vars(); + for (ir::Var& var : vars) { + if (var->lower_bound.defined() && ShouldReplaceExpr(var->lower_bound)) { + auto copied = ir::ir_utils::IRCopy(expr_); + var->lower_bound = copied; + } + if (var->upper_bound.defined() && ShouldReplaceExpr(var->upper_bound)) { + auto copied = ir::ir_utils::IRCopy(expr_); + var->upper_bound = copied; + } + } + std::vector new_read_buffers = stmt->read_buffers(); + for (Expr& read_buffer : new_read_buffers) { + VisitExpr(&read_buffer); + } + stmt->set_read_buffers(new_read_buffers); - void VisitStmt(ir::stmt::Schedule stmt) override { return; } + std::vector new_write_buffers = stmt->write_buffers(); + for (Expr& write_buffer : new_write_buffers) { + VisitExpr(&write_buffer); + } + stmt->set_write_buffers(new_write_buffers); + VisitBlock(stmt->body()); + } private: bool do_replace_{false}; @@ -215,6 +287,10 @@ void ReplaceVarWithExprInBlock(ir::stmt::BlockRef source, const Var& var, const Expr& expr, const std::string& tensor_name) { + VLOG(4) << "Enter ReplaceVarWithExprInBlock"; + VLOG(4) << "source: " << source; + VLOG(4) << "var: " << var; + VLOG(4) << "expr: " << expr; ReplaceVarWithExprStmtMutator mutator(var, expr, tensor_name); mutator(source); } From 38dee97d91c5b48e485c65304f60fd954371c26a Mon Sep 17 00:00:00 2001 From: ZhouXin Date: Fri, 20 Dec 2024 08:30:39 +0000 Subject: [PATCH 4/5] Remove logs, and merge tow replaceVarMutator --- .../cinn/optim/remove_schedule_block_pass.cc | 25 ++-- .../cinn/optim/remove_schedule_block_pass.h | 30 ++++ paddle/cinn/optim/replace_var_with_expr.cc | 130 ++++++------------ paddle/cinn/optim/replace_var_with_expr.h | 16 +-- 4 files changed, 91 insertions(+), 110 deletions(-) diff --git a/paddle/cinn/optim/remove_schedule_block_pass.cc b/paddle/cinn/optim/remove_schedule_block_pass.cc index 5488d16efd799e..c035d5c5a33b0a 100644 --- a/paddle/cinn/optim/remove_schedule_block_pass.cc +++ b/paddle/cinn/optim/remove_schedule_block_pass.cc @@ -16,25 +16,26 @@ #include "paddle/cinn/optim/replace_var_with_expr.h" namespace cinn { namespace optim { +using ir::stmt::BlockRef; +using ir::stmt::StmtRef; LogicalResult RemoveScheduleBlockPass::Run(ir::stmt::BlockRef block) { - VLOG(4) << "RemoveScheduleBlockPass Run"; - auto merge_stmt_vector = [&](std::vector& dest, - const std::vector& source) { + const auto& MergeStmtVector = [&](std::vector& dest, + const std::vector& source) { dest.insert(dest.end(), source.begin(), source.end()); }; - const auto& stmts = block->stmts(); - std::vector new_stmts; - for (auto& stmt : stmts) { + const std::vector& stmts = block->stmts(); + std::vector new_stmts; + for (const StmtRef& stmt : stmts) { if (!stmt.isa()) { new_stmts.push_back(stmt); continue; } - auto schedule_stmt = stmt.as(); - auto iter_values = schedule_stmt->iter_values(); - auto body = schedule_stmt->body(); - auto iter_vars = schedule_stmt->iter_vars(); + const ir::stmt::Schedule schedule_stmt = stmt.as(); + const std::vector iter_values = schedule_stmt->iter_values(); + const BlockRef body = schedule_stmt->body(); + const std::vector iter_vars = schedule_stmt->iter_vars(); PADDLE_ENFORCE_EQ(iter_vars.size(), iter_values.size(), ::common::errors::InvalidArgument( @@ -43,9 +44,9 @@ LogicalResult RemoveScheduleBlockPass::Run(ir::stmt::BlockRef block) { iter_vars.size(), iter_values.size())); for (int i = 0; i < iter_vars.size(); i++) { - optim::ReplaceVarWithExprInBlock(body, iter_vars[i], iter_values[i], ""); + optim::ReplaceVarWithExpr(body, iter_vars[i], iter_values[i], ""); } - merge_stmt_vector(new_stmts, body->stmts()); + MergeStmtVector(new_stmts, body->stmts()); } block->set_stmts(new_stmts); return LogicalResult::success(); diff --git a/paddle/cinn/optim/remove_schedule_block_pass.h b/paddle/cinn/optim/remove_schedule_block_pass.h index 920336b555a6b2..16ea2237ac129b 100644 --- a/paddle/cinn/optim/remove_schedule_block_pass.h +++ b/paddle/cinn/optim/remove_schedule_block_pass.h @@ -23,6 +23,36 @@ class RemoveScheduleBlockPass : public BlockPass { LogicalResult Run(ir::stmt::BlockRef block) override; }; +/** + * Removes ScheduleBlock nodes from the IR tree. + * + * This pass is applicable in scenarios where ScheduleBlock nodes are present in + * the IR tree but are no longer needed for further optimization. + * + * When applied, this pass will traverse the IR tree and replace each + * ScheduleBlockRealize node with its body. During this process, it will also + * replace the iter_vars in the body with their corresponding iter_values. This + * effectively removes the ScheduleBlock structure while preserving the + * computational logic within it. + * + * Performance impact: This pass addresses the overhead of maintaining + * ScheduleBlock structures in the IR. By removing these structures, it + * simplifies the IR, which can lead to faster subsequent passes and potentially + * more efficient code generation. + * + * Examples: + * 1. Basic ScheduleBlock removal: + * Input IR: + * ScheduleBlock { + * iter_vars: [i, j] + * iter_values: [0, 1] + * body { + * body: A[i, j] = B[i, j] + C[i, j] + * } + * } + * Output IR: + * A[0, 1] = B[0, 1] + C[0, 1] + */ std::unique_ptr CreateRemoveScheduleBlockPass(); } // namespace optim diff --git a/paddle/cinn/optim/replace_var_with_expr.cc b/paddle/cinn/optim/replace_var_with_expr.cc index e6577c8837f8d2..51fb70354262f7 100644 --- a/paddle/cinn/optim/replace_var_with_expr.cc +++ b/paddle/cinn/optim/replace_var_with_expr.cc @@ -27,56 +27,50 @@ namespace cinn { namespace optim { -struct ReplaceVarWithExprStmtMutator : public ir::stmt::StmtMutator<> { - ReplaceVarWithExprStmtMutator(const Var& var, - const Expr& expr, - const std::string& tensor_name) +struct ReplaceVarWithExprMutator : public ir::IRMutator<>, + public ir::stmt::StmtMutator<> { + ReplaceVarWithExprMutator(const Var& var, + const Expr& expr, + const std::string& tensor_name) : var_(var), expr_(expr), tensor_name_(tensor_name) {} + void operator()(Expr* expr) { + if (tensor_name_.empty()) visit_all_ = true; + IRMutator::Visit(expr, expr); + } + void operator()(ir::stmt::StmtRef stmt) { - VLOG(4) << "Enter ReplaceVarWithExprStmtMutator::operator()"; if (tensor_name_.empty()) visit_all_ = true; ir::stmt::StmtMutator<>::VisitStmt(stmt); - VLOG(4) << "Exit ReplaceVarWithExprStmtMutator::operator()"; } void operator()(ir::stmt::BlockRef block) { - VLOG(4) << "Enter ReplaceVarWithExprStmtMutator::operator()"; if (tensor_name_.empty()) visit_all_ = true; ir::stmt::StmtMutator<>::VisitBlock(block); - VLOG(4) << "Exit ReplaceVarWithExprStmtMutator::operator()"; } private: - bool ShouldReplaceExpr(const Expr& expr) { - if (!expr.is_var()) { - return false; - } - if (expr.as_var()->name == var_->name && (do_replace_ || visit_all_)) { - return true; + void VisitVar(Var* var) { + if (var->get()->name == var_->name && (do_replace_ || visit_all_)) { + Expr copied = ir::ir_utils::IRCopy(expr_); + *var = copied; } - return false; } void VisitExpr(ir::Expr* expr) { - VLOG(4) << "Enter VisitExpr(ir::Expr* expr)"; if (expr->is_var()) { - VLOG(4) << "Hit var: " << expr->as_var()->name; if (expr->as_var()->name == var_->name && (do_replace_ || visit_all_)) { - auto copied = ir::ir_utils::IRCopy(expr_); + Expr copied = ir::ir_utils::IRCopy(expr_); *expr = copied; } return; } for (auto field : (*expr)->expr_fields()) { - VLOG(4) << "field: " << *field; VisitExpr(field); } - VLOG(4) << "Exit VisitExpr(ir::Expr* expr)"; } void VisitStmt(ir::stmt::Let stmt) override { - VLOG(4) << "Enter VisitStmt(ir::stmt::Let stmt)"; Expr symbol = stmt->symbol(); VisitExpr(&symbol); stmt->set_symbol(symbol); @@ -85,25 +79,24 @@ struct ReplaceVarWithExprStmtMutator : public ir::stmt::StmtMutator<> { VisitExpr(&body); stmt->set_body(body); } - VLOG(4) << "Exit VisitStmt(ir::stmt::Let stmt)"; } void VisitStmt(ir::stmt::Store stmt) override { - VLOG(4) << "Enter VisitStmt(ir::stmt::Store stmt)"; auto* tensor = stmt->tensor().as_tensor(); if (tensor && tensor->name == tensor_name_) { do_replace_ = true; } else { do_replace_ = false; } + std::vector new_indices = stmt->indices(); for (Expr& index : new_indices) { - VLOG(4) << "index: " << index; VisitExpr(&index); } stmt->set_indices(new_indices); + do_replace_ = false; - VLOG(4) << "stmt->tensor(): " << stmt->tensor(); + Expr tensor_expr = stmt->tensor(); VisitExpr(&tensor_expr); stmt->set_tensor(tensor_expr); @@ -111,38 +104,21 @@ struct ReplaceVarWithExprStmtMutator : public ir::stmt::StmtMutator<> { Expr value = stmt->value(); VisitExpr(&value); stmt->set_value(value); - VLOG(4) << "Exit VisitStmt(ir::stmt::Store stmt)"; } void VisitStmt(ir::stmt::For stmt) override { - VLOG(4) << "Enter VisitStmt(ir::stmt::For stmt)"; - VLOG(4) << "stmt->min(): " << stmt->min(); Expr min = stmt->min(); VisitExpr(&min); - VLOG(4) << "stmt->extent(): " << stmt->extent(); Expr extent = stmt->extent(); VisitExpr(&extent); - VLOG(4) << "stmt->body(): " << stmt->body(); VisitBlock(stmt->body()); if (stmt->loop_var()->name == var_->name && expr_.as_var() && visit_all_) { - auto copied = ir::ir_utils::IRCopy(expr_); + Expr copied = ir::ir_utils::IRCopy(expr_); stmt->set_loop_var(copied.as_var_ref()); } - VLOG(4) << "Exit VisitStmt(ir::stmt::For stmt)"; - } - - void VisitStmt(ir::stmt::Alloc stmt) override { - VLOG(4) << "Enter VisitStmt(ir::stmt::Alloc stmt)"; - return; - } - - void VisitStmt(ir::stmt::Free stmt) override { - VLOG(4) << "Enter VisitStmt(ir::stmt::Free stmt)"; - return; } void VisitStmt(ir::stmt::IfThenElse stmt) override { - VLOG(4) << "Enter VisitStmt(ir::stmt::IfThenElse stmt)"; Expr condition = stmt->condition(); VisitExpr(&condition); ir::stmt::BlockRef true_case = stmt->true_case(); @@ -155,22 +131,14 @@ struct ReplaceVarWithExprStmtMutator : public ir::stmt::StmtMutator<> { } } - void VisitStmt(ir::stmt::Evaluate) override { - VLOG(4) << "Enter VisitStmt(ir::stmt::Evaluate)"; - return; - } - void VisitStmt(ir::stmt::Schedule stmt) override { - VLOG(4) << "Enter VisitStmt(ir::stmt::Schedule stmt)"; std::vector vars = stmt->iter_vars(); for (ir::Var& var : vars) { - if (var->lower_bound.defined() && ShouldReplaceExpr(var->lower_bound)) { - auto copied = ir::ir_utils::IRCopy(expr_); - var->lower_bound = copied; + if (var->lower_bound.defined()) { + VisitExpr(&var->lower_bound); } - if (var->upper_bound.defined() && ShouldReplaceExpr(var->upper_bound)) { - auto copied = ir::ir_utils::IRCopy(expr_); - var->upper_bound = copied; + if (var->upper_bound.defined()) { + VisitExpr(&var->upper_bound); } } std::vector new_read_buffers = stmt->read_buffers(); @@ -187,29 +155,9 @@ struct ReplaceVarWithExprStmtMutator : public ir::stmt::StmtMutator<> { VisitBlock(stmt->body()); } - private: - bool do_replace_{false}; - bool visit_all_{false}; - const Var& var_; - const Expr& expr_; - const std::string& tensor_name_; -}; - -struct ReplaceVarWithExprMutator : public ir::IRMutator<> { - ReplaceVarWithExprMutator(const Var& var, - const Expr& expr, - const std::string& tensor_name) - : var_(var), expr_(expr), tensor_name_(tensor_name) {} - - void operator()(Expr* expr) { - if (tensor_name_.empty()) visit_all_ = true; - IRMutator::Visit(expr, expr); - } - - private: void Visit(const ir::_Var_* expr, Expr* op) override { if (expr->name == var_->name && (do_replace_ || visit_all_)) { - auto copied = ir::ir_utils::IRCopy(expr_); + Expr copied = ir::ir_utils::IRCopy(expr_); *op = copied; } } @@ -267,6 +215,12 @@ struct ReplaceVarWithExprMutator : public ir::IRMutator<> { ir::IRMutator<>::Visit(&node->tensor, &node->tensor); } + void VisitStmt(ir::stmt::Alloc stmt) override { return; } + + void VisitStmt(ir::stmt::Free stmt) override { return; } + + void VisitStmt(ir::stmt::Evaluate) override { return; } + private: bool do_replace_{false}; bool visit_all_{false}; @@ -275,23 +229,19 @@ struct ReplaceVarWithExprMutator : public ir::IRMutator<> { const std::string& tensor_name_; }; -void ReplaceVarWithExprInStmt(ir::stmt::StmtRef source, - const Var& var, - const Expr& expr, - const std::string& tensor_name) { - ReplaceVarWithExprStmtMutator mutator(var, expr, tensor_name); +void ReplaceVarWithExpr(ir::stmt::StmtRef source, + const Var& var, + const Expr& expr, + const std::string& tensor_name) { + ReplaceVarWithExprMutator mutator(var, expr, tensor_name); mutator(source); } -void ReplaceVarWithExprInBlock(ir::stmt::BlockRef source, - const Var& var, - const Expr& expr, - const std::string& tensor_name) { - VLOG(4) << "Enter ReplaceVarWithExprInBlock"; - VLOG(4) << "source: " << source; - VLOG(4) << "var: " << var; - VLOG(4) << "expr: " << expr; - ReplaceVarWithExprStmtMutator mutator(var, expr, tensor_name); +void ReplaceVarWithExpr(ir::stmt::BlockRef source, + const Var& var, + const Expr& expr, + const std::string& tensor_name) { + ReplaceVarWithExprMutator mutator(var, expr, tensor_name); mutator(source); } diff --git a/paddle/cinn/optim/replace_var_with_expr.h b/paddle/cinn/optim/replace_var_with_expr.h index c226598bddd587..53f4e89cf7af8f 100644 --- a/paddle/cinn/optim/replace_var_with_expr.h +++ b/paddle/cinn/optim/replace_var_with_expr.h @@ -60,14 +60,14 @@ void ReplaceVarWithExpr(Expr *source, const Expr &expr, const std::string &tensor_name = ""); -void ReplaceVarWithExprInStmt(ir::stmt::StmtRef source, - const Var &var, - const Expr &expr, - const std::string &tensor_name = ""); -void ReplaceVarWithExprInBlock(ir::stmt::BlockRef source, - const Var &var, - const Expr &expr, - const std::string &tensor_name = ""); +void ReplaceVarWithExpr(ir::stmt::StmtRef source, + const Var &var, + const Expr &expr, + const std::string &tensor_name = ""); +void ReplaceVarWithExpr(ir::stmt::BlockRef source, + const Var &var, + const Expr &expr, + const std::string &tensor_name = ""); /** * Collect the specific tensor's indices. From 406eb93a97cf10b55ac725fafaf230c160f6b53d Mon Sep 17 00:00:00 2001 From: ZhouXin Date: Fri, 20 Dec 2024 11:56:06 +0000 Subject: [PATCH 5/5] Reuse old version of VisitExpr --- paddle/cinn/optim/replace_var_with_expr.cc | 96 ++++++++-------------- paddle/cinn/optim/replace_var_with_expr.h | 12 +-- 2 files changed, 37 insertions(+), 71 deletions(-) diff --git a/paddle/cinn/optim/replace_var_with_expr.cc b/paddle/cinn/optim/replace_var_with_expr.cc index 51fb70354262f7..94514ff440f0cf 100644 --- a/paddle/cinn/optim/replace_var_with_expr.cc +++ b/paddle/cinn/optim/replace_var_with_expr.cc @@ -32,51 +32,28 @@ struct ReplaceVarWithExprMutator : public ir::IRMutator<>, ReplaceVarWithExprMutator(const Var& var, const Expr& expr, const std::string& tensor_name) - : var_(var), expr_(expr), tensor_name_(tensor_name) {} - - void operator()(Expr* expr) { + : var_(var), expr_(expr), tensor_name_(tensor_name) { if (tensor_name_.empty()) visit_all_ = true; - IRMutator::Visit(expr, expr); } + void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } + void operator()(ir::stmt::StmtRef stmt) { - if (tensor_name_.empty()) visit_all_ = true; ir::stmt::StmtMutator<>::VisitStmt(stmt); } void operator()(ir::stmt::BlockRef block) { - if (tensor_name_.empty()) visit_all_ = true; ir::stmt::StmtMutator<>::VisitBlock(block); } private: - void VisitVar(Var* var) { - if (var->get()->name == var_->name && (do_replace_ || visit_all_)) { - Expr copied = ir::ir_utils::IRCopy(expr_); - *var = copied; - } - } - - void VisitExpr(ir::Expr* expr) { - if (expr->is_var()) { - if (expr->as_var()->name == var_->name && (do_replace_ || visit_all_)) { - Expr copied = ir::ir_utils::IRCopy(expr_); - *expr = copied; - } - return; - } - for (auto field : (*expr)->expr_fields()) { - VisitExpr(field); - } - } - void VisitStmt(ir::stmt::Let stmt) override { Expr symbol = stmt->symbol(); - VisitExpr(&symbol); + ir::IRMutator<>::Visit(&symbol, &symbol); stmt->set_symbol(symbol); if (stmt->body().defined()) { Expr body = stmt->body(); - VisitExpr(&body); + ir::IRMutator<>::Visit(&body, &body); stmt->set_body(body); } } @@ -91,26 +68,26 @@ struct ReplaceVarWithExprMutator : public ir::IRMutator<>, std::vector new_indices = stmt->indices(); for (Expr& index : new_indices) { - VisitExpr(&index); + ir::IRMutator<>::Visit(&index, &index); } stmt->set_indices(new_indices); do_replace_ = false; Expr tensor_expr = stmt->tensor(); - VisitExpr(&tensor_expr); + ir::IRMutator<>::Visit(&tensor_expr, &tensor_expr); stmt->set_tensor(tensor_expr); Expr value = stmt->value(); - VisitExpr(&value); + ir::IRMutator<>::Visit(&value, &value); stmt->set_value(value); } void VisitStmt(ir::stmt::For stmt) override { Expr min = stmt->min(); - VisitExpr(&min); + ir::IRMutator<>::Visit(&min, &min); Expr extent = stmt->extent(); - VisitExpr(&extent); + ir::IRMutator<>::Visit(&extent, &extent); VisitBlock(stmt->body()); if (stmt->loop_var()->name == var_->name && expr_.as_var() && visit_all_) { Expr copied = ir::ir_utils::IRCopy(expr_); @@ -120,7 +97,7 @@ struct ReplaceVarWithExprMutator : public ir::IRMutator<>, void VisitStmt(ir::stmt::IfThenElse stmt) override { Expr condition = stmt->condition(); - VisitExpr(&condition); + ir::IRMutator<>::Visit(&condition, &condition); ir::stmt::BlockRef true_case = stmt->true_case(); VisitBlock(true_case); stmt->set_true_case(true_case); @@ -135,26 +112,32 @@ struct ReplaceVarWithExprMutator : public ir::IRMutator<>, std::vector vars = stmt->iter_vars(); for (ir::Var& var : vars) { if (var->lower_bound.defined()) { - VisitExpr(&var->lower_bound); + ir::IRMutator<>::Visit(&var->lower_bound, &var->lower_bound); } if (var->upper_bound.defined()) { - VisitExpr(&var->upper_bound); + ir::IRMutator<>::Visit(&var->upper_bound, &var->upper_bound); } } std::vector new_read_buffers = stmt->read_buffers(); for (Expr& read_buffer : new_read_buffers) { - VisitExpr(&read_buffer); + ir::IRMutator<>::Visit(&read_buffer, &read_buffer); } stmt->set_read_buffers(new_read_buffers); std::vector new_write_buffers = stmt->write_buffers(); for (Expr& write_buffer : new_write_buffers) { - VisitExpr(&write_buffer); + ir::IRMutator<>::Visit(&write_buffer, &write_buffer); } stmt->set_write_buffers(new_write_buffers); VisitBlock(stmt->body()); } + void VisitStmt(ir::stmt::Alloc stmt) override { return; } + + void VisitStmt(ir::stmt::Free stmt) override { return; } + + void VisitStmt(ir::stmt::Evaluate) override { return; } + void Visit(const ir::_Var_* expr, Expr* op) override { if (expr->name == var_->name && (do_replace_ || visit_all_)) { Expr copied = ir::ir_utils::IRCopy(expr_); @@ -215,12 +198,6 @@ struct ReplaceVarWithExprMutator : public ir::IRMutator<>, ir::IRMutator<>::Visit(&node->tensor, &node->tensor); } - void VisitStmt(ir::stmt::Alloc stmt) override { return; } - - void VisitStmt(ir::stmt::Free stmt) override { return; } - - void VisitStmt(ir::stmt::Evaluate) override { return; } - private: bool do_replace_{false}; bool visit_all_{false}; @@ -229,29 +206,26 @@ struct ReplaceVarWithExprMutator : public ir::IRMutator<>, const std::string& tensor_name_; }; -void ReplaceVarWithExpr(ir::stmt::StmtRef source, - const Var& var, - const Expr& expr, - const std::string& tensor_name) { - ReplaceVarWithExprMutator mutator(var, expr, tensor_name); - mutator(source); -} - -void ReplaceVarWithExpr(ir::stmt::BlockRef source, - const Var& var, - const Expr& expr, - const std::string& tensor_name) { - ReplaceVarWithExprMutator mutator(var, expr, tensor_name); - mutator(source); -} - -void ReplaceVarWithExpr(Expr* source, +template +void ReplaceVarWithExpr(SourceType source, const Var& var, const Expr& expr, const std::string& tensor_name) { ReplaceVarWithExprMutator mutator(var, expr, tensor_name); mutator(source); } +template void ReplaceVarWithExpr(Expr*, + const Var&, + const Expr&, + const std::string&); +template void ReplaceVarWithExpr(ir::stmt::StmtRef, + const Var&, + const Expr&, + const std::string&); +template void ReplaceVarWithExpr(ir::stmt::BlockRef, + const Var&, + const Expr&, + const std::string&); struct CollectTensorIndexMutator : public ir::IRMutator<> { explicit CollectTensorIndexMutator(const std::string& tensor_name) diff --git a/paddle/cinn/optim/replace_var_with_expr.h b/paddle/cinn/optim/replace_var_with_expr.h index 53f4e89cf7af8f..f480e575aca53a 100644 --- a/paddle/cinn/optim/replace_var_with_expr.h +++ b/paddle/cinn/optim/replace_var_with_expr.h @@ -55,16 +55,8 @@ namespace optim { * for(j, 0, 10) * B[k,j] = A[k,j] */ -void ReplaceVarWithExpr(Expr *source, - const Var &var, - const Expr &expr, - const std::string &tensor_name = ""); - -void ReplaceVarWithExpr(ir::stmt::StmtRef source, - const Var &var, - const Expr &expr, - const std::string &tensor_name = ""); -void ReplaceVarWithExpr(ir::stmt::BlockRef source, +template +void ReplaceVarWithExpr(SourceType source, const Var &var, const Expr &expr, const std::string &tensor_name = "");