Skip to content

Commit 930d582

Browse files
committed
fix: Fix linear lowering pass, lift layer_norm scale layer restriction and matmul layer nbdims restriction
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
1 parent 3cb4917 commit 930d582

File tree

3 files changed

+71
-24
lines changed

3 files changed

+71
-24
lines changed

core/conversion/converters/impl/layer_norm.cpp

+24-5
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,31 @@ auto layer_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().
117117
}
118118

119119
auto power = Weights(ctx, at::ones(expand_size));
120-
auto scale_nd = ctx->net->addScaleNd(
121-
*div_out, nvinfer1::ScaleMode::kELEMENTWISE, beta_weights.data, gamma_weights.data, power.data, 1);
122-
scale_nd->setName((util::node_info(n) + "_scale_nd").c_str());
123-
auto scale_nd_out = scale_nd->getOutput(0);
124120

125-
ctx->AssociateValueAndTensor(n->outputs()[0], scale_nd_out);
121+
auto gamma_tensor = ctx->net->addConstant(gamma_weights.shape, gamma_weights.data)->getOutput(0);
122+
auto scale_l = add_elementwise(
123+
ctx, nvinfer1::ElementWiseOperation::kPROD, div_out, gamma_tensor, (util::node_info(n) + "_scale").c_str());
124+
125+
auto beta_tensor = ctx->net->addConstant(beta_weights.shape, beta_weights.data)->getOutput(0);
126+
auto shift_l = add_elementwise(
127+
ctx,
128+
nvinfer1::ElementWiseOperation::kSUM,
129+
scale_l->getOutput(0),
130+
beta_tensor,
131+
(util::node_info(n) + "_shift").c_str());
132+
133+
auto power_tensor = ctx->net->addConstant(power.shape, power.data)->getOutput(0);
134+
auto power_l = add_elementwise(
135+
ctx,
136+
nvinfer1::ElementWiseOperation::kPOW,
137+
shift_l->getOutput(0),
138+
power_tensor,
139+
(util::node_info(n) + "_power").c_str());
140+
141+
power_l->setName((util::node_info(n) + "_scale_nd").c_str());
142+
auto power_l_out = power_l->getOutput(0);
143+
144+
ctx->AssociateValueAndTensor(n->outputs()[0], power_l_out);
126145
return true;
127146
}});
128147

core/conversion/converters/impl/matrix_multiply.cpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "core/conversion/converters/converter_util.h"
12
#include "core/conversion/converters/converters.h"
23
#include "core/util/prelude.h"
34

@@ -16,10 +17,12 @@ auto mm_registrations TRTORCH_UNUSED =
1617
LOG_DEBUG("self tensor shape: " << self->getDimensions());
1718

1819
auto other = args[1].ITensorOrFreeze(ctx);
19-
LOG_DEBUG("other tensor shape: " << other->getDimensions());
20+
// "other" tensor should have same nbDims as self
21+
auto wt_tensor = addPadding(ctx, n, other, self->getDimensions().nbDims, false, false);
22+
LOG_DEBUG("other tensor shape: " << wt_tensor->getDimensions());
2023

2124
auto mm_layer = ctx->net->addMatrixMultiply(
22-
*self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);
25+
*self, nvinfer1::MatrixOperation::kNONE, *wt_tensor, nvinfer1::MatrixOperation::kNONE);
2326
TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
2427
mm_layer->setName(util::node_info(n).c_str());
2528
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
@@ -73,4 +76,4 @@ auto mm_registrations TRTORCH_UNUSED =
7376
} // namespace converters
7477
} // namespace conversion
7578
} // namespace core
76-
} // namespace trtorch
79+
} // namespace trtorch

core/lowering/passes/linear_to_addmm.cpp

+41-16
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,55 @@
1-
#include "torch/csrc/jit/passes/subgraph_rewrite.h"
1+
2+
#include <torch/csrc/jit/runtime/operator.h>
3+
#include "torch/csrc/jit/ir/alias_analysis.h"
4+
#include "torch/csrc/jit/jit_log.h"
5+
#include "torch/csrc/jit/passes/constant_propagation.h"
6+
#include "torch/csrc/jit/passes/dead_code_elimination.h"
7+
#include "torch/csrc/jit/passes/guard_elimination.h"
8+
#include "torch/csrc/jit/passes/peephole.h"
9+
#include "torch/csrc/jit/runtime/graph_executor.h"
210

311
#include "core/util/prelude.h"
12+
#include "torch/csrc/jit/passes/subgraph_rewrite.h"
413

514
namespace trtorch {
615
namespace core {
716
namespace lowering {
817
namespace passes {
918

19+
void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph) {
20+
// Define the decomposition function for aten::linear for the case where bias (mat2) is None.
21+
static torch::jit::CompilationUnit decompose_funcs(R"SCRIPT(
22+
def linear(self: Tensor, mat1: Tensor, mat2: Tensor):
23+
return torch.matmul(self, mat1.t())
24+
)SCRIPT");
25+
26+
// Iterate through nodes and search for aten::linear nodes where bias is not a Tensor (includes bias=None case)
27+
auto block = graph->block();
28+
for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
29+
auto n = *it;
30+
if (n->kind().toQualString() == std::string("aten::linear")) {
31+
auto input_values = n->inputs();
32+
// input_values[2] is the bias. If none, replace it with the decomposed linear graph.
33+
if (input_values[2]->type()->isSubtypeOf(c10::TensorType::get())) {
34+
continue;
35+
} else {
36+
torch::jit::WithInsertPoint guard(*it);
37+
std::shared_ptr<torch::jit::Graph> d_graph = decompose_funcs.get_function("linear").graph();
38+
torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
39+
new_output->setType(it->output()->type());
40+
it->output()->replaceAllUsesWith(new_output);
41+
it.destroyCurrent();
42+
}
43+
}
44+
}
45+
}
46+
1047
void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
1148
// TensorRT implicitly adds a flatten layer in front of FC layers if necessary
1249
std::string flatten_linear_pattern = R"IR(
1350
graph(%input, %weight, %bias):
1451
%res = aten::linear(%input, %weight, %bias)
1552
return (%res))IR";
16-
std::string flatten_linear_bias_none_pattern = R"IR(
17-
graph(%input, %weight):
18-
%bias: Tensor? = prim::Constant()
19-
%res = aten::linear(%input, %weight, %bias)
20-
return (%res))IR";
2153

2254
std::string fused_linear = R"IR(
2355
graph(%input, %weight_t, %bias):
@@ -27,20 +59,13 @@ void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
2759
%b_f: Tensor = trt::const(%bias)
2860
%out: Tensor = aten::add(%b_f, %mm, %1)
2961
return (%out))IR";
30-
std::string fused_linear_bias_none = R"IR(
31-
graph(%input, %weight_t):
32-
%weight = aten::t(%weight_t)
33-
%mm: Tensor = aten::matmul(%input, %weight)
34-
return (%mm))IR";
62+
63+
// First find and replace aten::linear nodes with non-tensor bias values.
64+
replaceLinearWithBiasNonePattern(graph);
3565

3666
torch::jit::SubgraphRewriter flatten_linear_to_linear;
3767
flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear);
3868
flatten_linear_to_linear.runOnGraph(graph);
39-
40-
torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear;
41-
flatten_linear_bias_none_to_linear.RegisterRewritePattern(flatten_linear_bias_none_pattern, fused_linear_bias_none);
42-
flatten_linear_bias_none_to_linear.runOnGraph(graph);
43-
LOG_GRAPH("Post linear to addmm: " << *graph);
4469
}
4570

4671
} // namespace passes

0 commit comments

Comments
 (0)