
Commit 46ac757

Committed Apr 4, 2022
feat(aten::Int): Adding a new pass to remove single use 0D Tensors

Now we remove select more complex aten::Int cases found in models such as BERT, like the following:

```
graph(%0: int):
  %1: Tensor = prim::Constant[value={8}]()
  %2: int = prim::Constant[value=1]()
  %3: Tensor = prim::NumToTensor(%0)
  %4: Tensor = aten::add(%1, %3, %2)
  %5: int = aten::Int(%4)
  %6: int = aten::add(%5, %5)
  return (%6)

graph(%0: int):
  %1: int = prim::Constant[value=8]()
  %4: int = aten::add(%1, %0)
  %6: int = aten::add(%4, %4)
  return (%6)
```

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
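As context, the new pass is exercised end-to-end by the tests added in this commit. Below is a minimal sketch (not part of the commit) of that usage, mirroring the test setup: `parseIR` parses `[8]` as a 1-dim Tensor constant, so the tests (and this sketch) first swap it for a true 0-dim Tensor via `c10::scalar_to_tensor` before running the pass. The include choices are assumptions based on the repo's headers.

```cpp
#include <memory>
#include <string>

#include "torch/csrc/jit/ir/irparser.h"

#include "core/lowering/passes/passes.h"
#include "core/util/prelude.h"

void remove_single_use_0d_tensors_example() {
  std::string source_graph = R"IR(
    graph(%0: int):
      %1: Tensor = prim::Constant[value=[8]]()
      %2: int = prim::Constant[value=1]()
      %3: Tensor = prim::NumToTensor(%0)
      %4: Tensor = aten::add(%1, %3, %2)
      %5: int = aten::Int(%4)
      %6: int = aten::add(%5, %5)
      return (%6))IR";

  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());

  // parseIR produces a 1-dim Tensor for `[8]`, so swap it for a true 0-dim
  // Tensor constant first (same setup as the tests added in this commit).
  auto first_op = *(sg->block()->nodes().begin());
  torch::jit::WithInsertPoint guard(first_op);
  torch::jit::Value* r =
      sg->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, first_op->scope());
  r->copyMetadata(first_op->output());
  r->setType(c10::TensorType::get());
  first_op->output()->replaceAllUsesWith(r);
  first_op->destroy();

  // Run the new lowering pass: the Constant / NumToTensor / aten::Int chain
  // collapses into scalar arithmetic, as shown in the commit message above.
  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(sg);
  LOG_GRAPH(*sg);
}
```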
1 parent 908340f commit 46ac757


5 files changed: +232 −3 lines

 

‎core/lowering/lowering.cpp

+1

```diff
@@ -64,6 +64,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::RemoveNOPs(g);
   passes::AliasOperators(g);
   passes::SiluToSigmoidMultipication(g);
+  passes::RemoveSingleUse0DTensors(g);
   passes::RemoveUnnecessaryCasts(g);
   LOG_GRAPH(*g);
 }
```

‎core/lowering/passes/passes.h

+1

```diff
@@ -27,6 +27,7 @@ void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
 void ViewToReshape(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
+void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g);
 void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
```
New file (+35)

```diff
@@ -0,0 +1,35 @@
+#include <stack>
+#include <unordered_set>
+
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include "core/lowering/passes/passes.h"
+#include "core/util/prelude.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void RemoveSetAttrs(const torch::jit::Module& mod, std::string method_name) {
+  auto g = mod.get_method(method_name).graph();
+
+  std::string set_attr_pattern = R"IR(
+    graph(%self, %0):
+        None = prim::SetAttr[name="_has_warned"](%self, %0)
+        return ())IR";
+  std::string no_set_attr_pattern = R"IR(
+    graph(%self, %0):
+        return ())IR";
+
+  // remove contiguous
+  torch::jit::SubgraphRewriter remove_set_attr;
+  remove_set_attr.RegisterRewritePattern(set_attr_pattern, no_set_attr_pattern);
+  remove_set_attr.runOnGraph(g);
+  LOG_GRAPH("Post remove contiguous: " << *g);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
```

‎core/lowering/passes/remove_unnecessary_casts.cpp

+114

```diff
@@ -1,4 +1,5 @@
 #include "torch/csrc/jit/passes/subgraph_rewrite.h"
+#include "torch/csrc/jit/ir/constants.h"
 
 #include "core/util/prelude.h"
 
@@ -55,6 +56,119 @@ void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph) {
   LOG_GRAPH("After RemoveUnnecessaryCasts: " << *graph);
 }
 
+void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
+  for (auto it = g->block()->nodes().begin(), end = g->block()->nodes().end(); it != end; ++it) {
+    if (it->kind() == torch::jit::prim::Constant) {
+      // Going from a constant and is single use means we can fuse
+      if (it->output()->type()->isSubtypeOf(c10::TensorType::get())) {
+        // Get the tensor stored in constant
+        at::Tensor t = *torch::jit::constant_as<at::Tensor>(it->output());
+        // If shape is 0D
+        if (t.sizes() == std::vector<int64_t>({})) {
+          LOG_GRAPH("Found a 0D Tensor: " << it->output()->debugName());
+          LOG_GRAPH("Number of uses: " << it->output()->uses().size());
+          // If the tensor is only used once
+          if (it->output()->uses().size() == 1) {
+            auto use = it->output()->uses()[0];
+            auto user = use.user;
+
+            // Is a NumToTensor / aten::[Int/Float] case
+            if (user->outputs().size() == 1 && user->outputs()[0]->type()->isSubtypeOf(c10::TensorType::get())) {
+              if (user->output()->uses().size() == 1) {
+                auto potential_cast = user->output()->uses()[0].user;
+                // The downstream user is aten::Int
+                if (potential_cast->kind() == c10::Symbol::fromQualString("aten::Int")
+                    || potential_cast->kind() == c10::Symbol::fromQualString("aten::Float")) {
+                  LOG_GRAPH("Downstream user is aten::Int/aten::Float");
+                  auto arg = use.offset;
+
+                  for (size_t k = 0; k < user->inputs().size(); ++k) {
+                    if (k != arg) {
+                      if (user->inputs()[k]->type()->isSubtypeOf(c10::TensorType::get())) {
+                        LOG_GRAPH("Input " << k << " is a Tensor");
+                        if (user->inputs()[k]->node()->kind() == c10::Symbol::fromQualString("prim::NumToTensor")) {
+                          auto num_to_tensor = user->inputs()[k]->node();
+
+                          LOG_GRAPH("Found a prim::NumToTensor / aten::[Int/Float] pair with an intermediate operation:\n "
+                                    << *(*it)
+                                    << *num_to_tensor
+                                    << *user
+                                    << *potential_cast);
+
+                          // Replace the Tensor Constant with a scalar constant
+                          LOG_GRAPH("Deleting 0-dim Tensor: " << **it);
+                          torch::jit::WithInsertPoint gaurd(*it);
+
+                          auto new_const_val = g->insertConstant(t.item(), c10::nullopt, it->scope());
+                          new_const_val->copyMetadata(it->output());
+                          // How to determine the internal scalar type instead of assuming?
+                          if (potential_cast->kind() == c10::aten::Int) {
+                            new_const_val->setType(c10::IntType::get());
+                          } else if (potential_cast->kind() == c10::aten::Float) {
+                            new_const_val->setType(c10::FloatType::get());
+                          }
+                          it->output()->replaceAllUsesWith(new_const_val);
+                          it.destroyCurrent();
+
+                          LOG_GRAPH("New constant: " << *new_const_val->node());
+
+                          // Delete NumToTensor
+                          LOG_GRAPH("Deleting NumToTensor: " << *num_to_tensor);
+                          num_to_tensor->output()->replaceAllUsesWith(num_to_tensor->inputs()[0]);
+                          num_to_tensor->destroy();
+
+                          // Change intermediate op output type
+                          LOG_GRAPH(user->schema());
+
+                          torch::jit::Node* new_node;
+                          switch (user->kind()) {
+                            // Use this to handle special cases where the scalar version of the intermediate operator
+                            // has a different schema than the original
+                            case c10::aten::add:
+                              new_node = g->create(
+                                  user->kind(),
+                                  torch::jit::ArrayRef<torch::jit::Value*>({user->inputs()[0], user->inputs()[1]}),
+                                  1);
+                              new_node->insertAfter(user);
+                              new_node->outputs()[0]->setType(c10::IntType::get());
+                              user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
+                              user->destroy();
+                              break;
+                            default:
+                              new_node = g->create(
+                                  user->kind(),
+                                  user->inputs(),
+                                  1);
+                              new_node->insertAfter(user);
+                              new_node->outputs()[0]->setType(c10::IntType::get());
+                              user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
+                              user->destroy();
+                              break;
+                          }
+
+                          LOG_GRAPH("New intermediate operation: " << *new_node);
+                          LOG_GRAPH(new_node->schema());
+
+                          // Delete aten::Int
+                          LOG_GRAPH("Deleting aten::[Int/Float]: " << *potential_cast);
+                          potential_cast->output()->replaceAllUsesWith(potential_cast->inputs()[0]);
+                          potential_cast->destroy();
+                        }
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  LOG_ERROR("Post removing single use 0-dim Tensor operations: " << *g);
+}
+
+
 } // namespace passes
 } // namespace lowering
 } // namespace core
```
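One detail worth noting from the pass above: the dedicated `case c10::aten::add` exists because the Tensor overload of add, `aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor`, carries a trailing `alpha` argument that the scalar overload `aten::add.int(int a, int b) -> int` does not accept, so the intermediate node is rebuilt from only its first two inputs. A standalone sketch of just that rebuild step follows; the free-function framing and name are illustrative, not part of the commit:

```cpp
#include "torch/csrc/jit/ir/ir.h"

// Illustrative only: rebuild an intermediate aten::add so it can resolve to
// the scalar overload once its Tensor operands have been replaced by scalars.
// `g` is the graph being lowered; `user` is the original Tensor-typed aten::add.
torch::jit::Node* rebuild_add_as_scalar(std::shared_ptr<torch::jit::Graph>& g, torch::jit::Node* user) {
  torch::jit::Node* new_node = g->create(
      user->kind(),
      {user->inputs()[0], user->inputs()[1]}, // drop the trailing alpha input
      /*num_outputs=*/1);
  new_node->insertAfter(user);
  new_node->outputs()[0]->setType(c10::IntType::get());
  user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
  user->destroy();
  return new_node;
}
```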

‎tests/core/lowering/test_remove_unnecessary_casts.cpp

+81 -3

```diff
@@ -22,7 +22,7 @@ TEST(LoweringPasses, RemoveUnnecessaryCastIntCorrectly) {
       torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
   auto sg = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(source_graph, sg.get());
-  torch_tensorrt::core::lowering::passes::RemoveContiguous(sg);
+  torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);
 
   auto tg = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(target_graph, tg.get());
@@ -46,7 +46,7 @@ TEST(LoweringPasses, RemoveUnnecessaryCastFloatCorrectly) {
       torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
   auto sg = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(source_graph, sg.get());
-  torch_tensorrt::core::lowering::passes::RemoveContiguous(sg);
+  torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);
 
   auto tg = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(target_graph, tg.get());
@@ -70,7 +70,85 @@ TEST(LoweringPasses, RemoveUnnecessaryCastBoolCorrectly) {
       torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
   auto sg = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(source_graph, sg.get());
-  torch_tensorrt::core::lowering::passes::RemoveContiguous(sg);
+  torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+TEST(LoweringPasses, RemoveSingleUse0DTensorsIntCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%0: int):
+      %1: Tensor = prim::Constant[value=[8]]()
+      %2: int = prim::Constant[value=1]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %4: Tensor = aten::add(%1, %3, %2)
+      %5: int = aten::Int(%4)
+      %6: int = aten::add(%5, %5)
+      return (%6))IR";
+  std::string target_graph = R"IR(
+    graph(%0: int):
+      %1: int = prim::Constant[value=8]()
+      %4: int = aten::add(%1, %0)
+      %6: int = aten::add(%4, %4)
+      return (%6))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+
+  auto first_op = *(sg->block()->nodes().begin());
+  torch::jit::WithInsertPoint guard(first_op);
+  torch::jit::Value* r = sg->insertConstant(
+      c10::scalar_to_tensor(8), c10::nullopt, first_op->scope());
+  r->copyMetadata(first_op->output());
+  r->setType(c10::TensorType::get());
+  first_op->output()->replaceAllUsesWith(r);
+  first_op->destroy();
+
+  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+TEST(LoweringPasses, RemoveSingleUse0DTensorsFloatCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%0: float):
+      %1: Tensor = prim::Constant[value=[8.]]()
+      %2: float = prim::Constant[value=1.]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %4: Tensor = aten::add(%1, %3, %2)
+      %5: float = aten::Float(%4)
+      %6: float = aten::add(%5, %5)
+      return (%6))IR";
+  std::string target_graph = R"IR(
+    graph(%0: float):
+      %1: float = prim::Constant[value=8.]()
+      %4: float = aten::add(%1, %0)
+      %6: float = aten::add(%4, %4)
+      return (%6))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+
+  auto first_op = *(sg->block()->nodes().begin());
+  torch::jit::WithInsertPoint guard(first_op);
+  torch::jit::Value* r = sg->insertConstant(
+      c10::scalar_to_tensor(8.0), c10::nullopt, first_op->scope());
+  r->copyMetadata(first_op->output());
+  r->setType(c10::TensorType::get());
+  first_op->output()->replaceAllUsesWith(r);
+  first_op->destroy();
+
+  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(sg);
 
   auto tg = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(target_graph, tg.get());
```
