From 3d14cdac6d43afe99fd8721f81bfbc406c0915c9 Mon Sep 17 00:00:00 2001
From: Naren Dasan
Date: Thu, 21 Jan 2021 19:37:23 -0800
Subject: [PATCH] feat(//core/lowering): Adding a new pass to handle new dim
 checks for batchnorm

Signed-off-by: Naren Dasan
Signed-off-by: Naren Dasan
---
 core/lowering/lowering.cpp                   |  1 +
 core/lowering/passes/BUILD                   |  1 +
 core/lowering/passes/passes.h                |  1 +
 core/lowering/passes/remove_bn_dim_check.cpp | 88 ++++++++++++++++++++
 4 files changed, 91 insertions(+)
 create mode 100644 core/lowering/passes/remove_bn_dim_check.cpp

diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index ab1f5561ef..7ee86f0d4c 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -40,6 +40,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
   passes::Conv2DToConvolution(g);
   passes::Conv3DToConvolution(g);
   passes::FuseAddMMBranches(g);
+  passes::RemoveBNDimCheck(g);
   torch::jit::EliminateCommonSubexpression(g);
   // torch::jit::UnrollLoops(g);
   torch::jit::EliminateCommonSubexpression(g);
diff --git a/core/lowering/passes/BUILD b/core/lowering/passes/BUILD
index be6f3fcf42..9d3f328e20 100644
--- a/core/lowering/passes/BUILD
+++ b/core/lowering/passes/BUILD
@@ -18,6 +18,7 @@ cc_library(
         "exception_elimination.cpp",
         "fuse_addmm_branches.cpp",
         "fuse_flatten_linear.cpp",
+        "remove_bn_dim_check.cpp",
         "remove_contiguous.cpp",
         "remove_dropout.cpp",
         "remove_to.cpp",
diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h
index f1c72c2aca..977533a3d6 100644
--- a/core/lowering/passes/passes.h
+++ b/core/lowering/passes/passes.h
@@ -12,6 +12,7 @@ void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
 void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
 void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
+void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
 void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveTo(std::shared_ptr<torch::jit::Graph> graph);
diff --git a/core/lowering/passes/remove_bn_dim_check.cpp b/core/lowering/passes/remove_bn_dim_check.cpp
new file mode 100644
index 0000000000..92e48137e4
--- /dev/null
+++ b/core/lowering/passes/remove_bn_dim_check.cpp
@@ -0,0 +1,88 @@
+#include "torch/csrc/jit/ir/alias_analysis.h"
+#include "torch/csrc/jit/jit_log.h"
+#include "torch/csrc/jit/passes/constant_propagation.h"
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/guard_elimination.h"
+#include "torch/csrc/jit/passes/peephole.h"
+#include "torch/csrc/jit/runtime/graph_executor.h"
+
+#include "core/util/prelude.h"
+
+#include <vector>
+
+namespace trtorch {
+namespace core {
+namespace lowering {
+namespace passes {
+namespace {
+using namespace torch::jit;
+struct BNDimCheckRemoval {
+  BNDimCheckRemoval(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {}
+
+  void run() {
+    findBNDimCheckNodes(graph_->block());
+    torch::jit::EliminateDeadCode(graph_);
+    LOG_GRAPH("Post batch norm dim check removal: " << *graph_);
+  }
+
+ private:
+  bool isBNDimCheckNodes(Node* n) {
+    /// Check if this Node hosts a pattern like so:
+    ///   %290 : bool = aten::ne(%289, %9)
+    ///    = prim::If(%290)
+    ///     block0():
+    ///       %291 : str = aten::format(%10, %289)
+    ///        = prim::RaiseException(%291)
+    ///       -> ()
+    ///     block1():
+    ///       -> ()
+
+    if (n->blocks().size() != 2) {
+      return false;
+    }
+    auto arm1 = n->blocks()[0];
+    auto arm2 = n->blocks()[1];
+    if (arm1->outputs().size() != 0 || arm2->outputs().size() != 0) {
+      // Make sure that the node doesn't actually produce any Value that is
+      // used by other nodes
+      return false;
+    }
+
+    auto arm1_start = arm1->nodes().begin();
+
+    if ((*arm1_start)->kind() != c10::Symbol::fromQualString("aten::format") &&
+        (*(++arm1_start))->kind() != prim::RaiseException && (*(++arm1_start))->kind() != prim::Return) {
+      // Make sure that block0 is solely the exception and the return
+      return false;
+    }
+
+    if ((*(arm2->nodes().begin()))->kind() != prim::Return) {
+      // Make sure that block1 is solely the return
+      return false;
+    }
+
+    return true;
+  }
+
+  void findBNDimCheckNodes(Block* b) {
+    for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
+      auto n = *it;
+      if (n->kind() == prim::If && isBNDimCheckNodes(n)) {
+        LOG_GRAPH("Found that node " << *n << " is a batch norm dim check node (EliminateChecks)" << std::endl);
+        it.destroyCurrent();
+      }
+    }
+  }
+
+  std::shared_ptr<Graph> graph_;
+};
+} // namespace
+
+void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph) {
+  BNDimCheckRemoval bndcr(std::move(graph));
+  bndcr.run();
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace trtorch