1
- #include " torch/csrc/jit/passes/subgraph_rewrite.h"
1
+
2
+ #include < torch/csrc/jit/runtime/operator.h>
3
+ #include " torch/csrc/jit/ir/alias_analysis.h"
4
+ #include " torch/csrc/jit/jit_log.h"
5
+ #include " torch/csrc/jit/passes/constant_propagation.h"
6
+ #include " torch/csrc/jit/passes/dead_code_elimination.h"
7
+ #include " torch/csrc/jit/passes/guard_elimination.h"
8
+ #include " torch/csrc/jit/passes/peephole.h"
9
+ #include " torch/csrc/jit/runtime/graph_executor.h"
2
10
3
11
#include " core/util/prelude.h"
12
+ #include " torch/csrc/jit/passes/subgraph_rewrite.h"
4
13
5
14
namespace trtorch {
6
15
namespace core {
7
16
namespace lowering {
8
17
namespace passes {
9
18
19
+ void replaceLinearWithBiasNonePattern (std::shared_ptr<torch::jit::Graph> graph) {
20
+ // Define the decomposition function for aten::linear for the case where bias (mat2) is None.
21
+ static torch::jit::CompilationUnit decompose_funcs (R"SCRIPT(
22
+ def linear(self: Tensor, mat1: Tensor, mat2: Tensor):
23
+ return torch.matmul(self, mat1.t())
24
+ )SCRIPT" );
25
+
26
+ // Iterate through nodes and search for aten::linear nodes where bias is not a Tensor (includes bias=None case)
27
+ auto block = graph->block ();
28
+ for (auto it = block->nodes ().begin (); it != block->nodes ().end (); it++) {
29
+ auto n = *it;
30
+ if (n->kind ().toQualString () == std::string (" aten::linear" )) {
31
+ auto input_values = n->inputs ();
32
+ // input_values[2] is the bias. If none, replace it with the decomposed linear graph.
33
+ if (input_values[2 ]->type ()->isSubtypeOf (c10::TensorType::get ())) {
34
+ continue ;
35
+ } else {
36
+ torch::jit::WithInsertPoint guard (*it);
37
+ std::shared_ptr<torch::jit::Graph> d_graph = decompose_funcs.get_function (" linear" ).graph ();
38
+ torch::jit::Value* new_output = insertGraph (*it->owningGraph (), *d_graph, it->inputs ()).at (0 );
39
+ new_output->setType (it->output ()->type ());
40
+ it->output ()->replaceAllUsesWith (new_output);
41
+ it.destroyCurrent ();
42
+ }
43
+ }
44
+ }
45
+ }
46
+
10
47
void LinearToAddMM (std::shared_ptr<torch::jit::Graph>& graph) {
11
48
// TensorRT implicitly adds a flatten layer in front of FC layers if necessary
12
49
std::string flatten_linear_pattern = R"IR(
13
50
graph(%input, %weight, %bias):
14
51
%res = aten::linear(%input, %weight, %bias)
15
52
return (%res))IR" ;
16
- std::string flatten_linear_bias_none_pattern = R"IR(
17
- graph(%input, %weight):
18
- %bias: Tensor? = prim::Constant()
19
- %res = aten::linear(%input, %weight, %bias)
20
- return (%res))IR" ;
21
53
22
54
std::string fused_linear = R"IR(
23
55
graph(%input, %weight_t, %bias):
@@ -27,20 +59,13 @@ void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
27
59
%b_f: Tensor = trt::const(%bias)
28
60
%out: Tensor = aten::add(%b_f, %mm, %1)
29
61
return (%out))IR" ;
30
- std::string fused_linear_bias_none = R"IR(
31
- graph(%input, %weight_t):
32
- %weight = aten::t(%weight_t)
33
- %mm: Tensor = aten::matmul(%input, %weight)
34
- return (%mm))IR" ;
62
+
63
+ // First find and replace aten::linear nodes with non-tensor bias values.
64
+ replaceLinearWithBiasNonePattern (graph);
35
65
36
66
torch::jit::SubgraphRewriter flatten_linear_to_linear;
37
67
flatten_linear_to_linear.RegisterRewritePattern (flatten_linear_pattern, fused_linear);
38
68
flatten_linear_to_linear.runOnGraph (graph);
39
-
40
- torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear;
41
- flatten_linear_bias_none_to_linear.RegisterRewritePattern (flatten_linear_bias_none_pattern, fused_linear_bias_none);
42
- flatten_linear_bias_none_to_linear.runOnGraph (graph);
43
- LOG_GRAPH (" Post linear to addmm: " << *graph);
44
69
}
45
70
46
71
} // namespace passes
0 commit comments