Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[CINN][Backend Pass Update No.9] Update RemoveScheduleBlock pass #70334

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/cinn/optim/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ gather_srcs(
lower_intrin.cc
cast_bool_to_int8.cc
var_mod_simplify.cc
remove_schedule_block.cc
remove_schedule_block_pass.cc
replace_cross_block_reduction.cc
replace_cross_thread_reduction.cc
replace_mod_to_max.cc
Expand Down
5 changes: 3 additions & 2 deletions paddle/cinn/optim/optimize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#include "paddle/cinn/optim/lower_intrin.h"
#include "paddle/cinn/optim/map_extern_call.h"
#include "paddle/cinn/optim/rearrange_load_instruction.h"
#include "paddle/cinn/optim/remove_schedule_block.h"
#include "paddle/cinn/optim/remove_schedule_block_pass.h"
#include "paddle/cinn/optim/replace_const_param_to_integer.h"
#include "paddle/cinn/optim/replace_cross_block_reduction.h"
#include "paddle/cinn/optim/replace_cross_thread_reduction.h"
Expand Down Expand Up @@ -116,7 +116,8 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn,
Simplify(&copied->body);
VLOG(10) << "After Optimize Simplify" << copied;

RemoveScheduleBlock(&copied->body);
pass_manager.AddPass(CreateRemoveScheduleBlockPass());
pass_manager.Run(copied);
VLOG(10) << "After RemoveScheduleBlock:" << copied;

LowerIntrin(&copied->body, target);
Expand Down
63 changes: 0 additions & 63 deletions paddle/cinn/optim/remove_schedule_block.cc

This file was deleted.

60 changes: 0 additions & 60 deletions paddle/cinn/optim/remove_schedule_block.h

This file was deleted.

59 changes: 59 additions & 0 deletions paddle/cinn/optim/remove_schedule_block_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/optim/remove_schedule_block_pass.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
namespace cinn {
namespace optim {

LogicalResult RemoveScheduleBlockPass::Run(ir::stmt::BlockRef block) {
VLOG(4) << "RemoveScheduleBlockPass Run";
auto merge_stmt_vector = [&](std::vector<ir::stmt::StmtRef>& dest,
const std::vector<ir::stmt::StmtRef>& source) {
dest.insert(dest.end(), source.begin(), source.end());
};

const auto& stmts = block->stmts();
std::vector<ir::stmt::StmtRef> new_stmts;
for (auto& stmt : stmts) {
if (!stmt.isa<ir::stmt::Schedule>()) {
new_stmts.push_back(stmt);
continue;
}
auto schedule_stmt = stmt.as<ir::stmt::Schedule>();
auto iter_values = schedule_stmt->iter_values();
auto body = schedule_stmt->body();
auto iter_vars = schedule_stmt->iter_vars();
PADDLE_ENFORCE_EQ(iter_vars.size(),
iter_values.size(),
::common::errors::InvalidArgument(
"The size of iter vars and iter values is not equal,"
"where iter vars:%d but iter values:%d.",
iter_vars.size(),
iter_values.size()));
for (int i = 0; i < iter_vars.size(); i++) {
optim::ReplaceVarWithExprInBlock(body, iter_vars[i], iter_values[i], "");
}
merge_stmt_vector(new_stmts, body->stmts());
}
block->set_stmts(new_stmts);
return LogicalResult::success();
}

std::unique_ptr<BlockPass> CreateRemoveScheduleBlockPass() {
return std::make_unique<RemoveScheduleBlockPass>();
}

} // namespace optim
} // namespace cinn
29 changes: 29 additions & 0 deletions paddle/cinn/optim/remove_schedule_block_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "paddle/cinn/pass/pass.h"

namespace cinn {
namespace optim {
class RemoveScheduleBlockPass : public BlockPass {
public:
RemoveScheduleBlockPass() : BlockPass("remove_schedule_block") {}
LogicalResult Run(ir::stmt::BlockRef block) override;
};

std::unique_ptr<BlockPass> CreateRemoveScheduleBlockPass();

} // namespace optim
} // namespace cinn
Loading