
Commit 8024ea2

feat: Implement fast approximation of Gelu as a lowering pass to improve performance
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>

chore: refactor converters
chore: Upload reduce_gelu.cpp
chore: Add files
1 parent: da15fa5

File tree: 8 files changed, +99 −32 lines changed

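This commit removes the plugin-based aten::gelu converter and instead decomposes GELU during lowering into the tanh-based fast approximation (the same formula quoted in the converter test below):

    Gelu(x) ≈ 0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x * x)))

where 0.7978845608 ≈ sqrt(2/pi); both constants appear as prim::Constant values in the rewrite pattern.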

core/conversion/converters/impl/activation.cpp (−32)
@@ -166,39 +166,7 @@ auto acthardtanh TORCHTRT_UNUSED =
         auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
         LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
         return true;
-      }})
-    .pattern({"aten::gelu(Tensor self) -> (Tensor)",
-              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-                auto in = args[0].ITensorOrFreeze(ctx);
-                nvinfer1::DataType type = in->getType();
-                TORCHTRT_CHECK(
-                    type == nvinfer1::DataType::kFLOAT || type == nvinfer1::DataType::kHALF,
-                    "gelu only supports kFLOAT and kHALF");
-                std::string pluginName = "CustomGeluPluginDynamic";
-                nvinfer1::PluginFieldCollection fc;
-                std::vector<nvinfer1::PluginField> f;
-                // REVIEW is this right?
-                int type_id = ctx->settings.enabled_precisions.find(nvinfer1::DataType::kHALF) ==
-                        ctx->settings.enabled_precisions.end()
-                    ? 0
-                    : 1; // Integer encoding the DataType (0: FP32, 1: FP16)
-                f.emplace_back(nvinfer1::PluginField("type_id", &type_id, nvinfer1::PluginFieldType::kINT32, 1));
-                fc.nbFields = f.size();
-                fc.fields = f.data();
-
-                auto creator = getPluginRegistry()->getPluginCreator("CustomGeluPluginDynamic", "1", "");
-                auto gelu_plugin = creator->createPlugin("gelu", &fc);
-
-                TORCHTRT_CHECK(gelu_plugin, "Unable to create gelu plugin from TensorRT plugin registry" << *n);
-                auto new_layer =
-                    ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *gelu_plugin);
-                new_layer->setName(util::node_info(n).c_str());
-                auto out_tensor = new_layer->getOutput(0);
-                out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
-                LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
-                return true;
       }});
-
 } // namespace
 } // namespace impl
 } // namespace converters

core/lowering/lowering.cpp (+1)
@@ -43,6 +43,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::UnpackHardSwish(g);
   passes::EliminateExceptionOrPassPattern(g);
   passes::ReduceToOperation(g);
+  passes::ReduceGelu(g);
   passes::RemoveContiguous(g);
   passes::RemoveDropout(g);
   passes::LinearToAddMM(g);
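Registering ReduceGelu in LowerGraph means every aten::gelu node is decomposed into pointwise operations before conversion, which is what allows the CustomGeluPluginDynamic converter above to be removed.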

core/lowering/passes/BUILD (+1)
@@ -17,6 +17,7 @@ cc_library(
         "module_fallback.cpp",
         "op_aliasing.cpp",
         "reduce_to.cpp",
+        "reduce_gelu.cpp",
         "remove_bn_dim_check.cpp",
         "remove_contiguous.cpp",
         "remove_dropout.cpp",

core/lowering/passes/passes.h (+1)
@@ -20,6 +20,7 @@ void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
 void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
 void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph);
+void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph);
 void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_delims);
 void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
 void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);

core/lowering/passes/reduce_gelu.cpp (+44, new file)
@@ -0,0 +1,44 @@
+#include <torch/csrc/jit/passes/subgraph_rewrite.h>
+#include "core/util/prelude.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string gelu_pattern = R"IR(
+        graph(%x):
+            %out : Tensor = aten::gelu(%x)
+            return (%out))IR";
+
+  std::string gelu_reduce_pattern = R"IR(
+    graph(%x.1 : Tensor):
+        %6 : float = prim::Constant[value=0.044714999999999998]()
+        %5 : float = prim::Constant[value=0.79788456080000003]()
+        %4 : float = prim::Constant[value=1.]()
+        %3 : float = prim::Constant[value=0.5]()
+        %2 : int = prim::Constant[value=1]()
+        %7 : Tensor = aten::mul(%x.1, %3)
+        %8 : Tensor = aten::mul(%x.1, %5)
+        %9 : Tensor = aten::mul(%x.1, %6)
+        %10 : Tensor = aten::mul(%9, %x.1)
+        %11 : Tensor = aten::add(%10, %4, %2)
+        %12 : Tensor = aten::mul(%8, %11)
+        %13 : Tensor = aten::tanh(%12)
+        %14 : Tensor = aten::add(%13, %4, %2)
+        %15 : Tensor = aten::mul(%7, %14)
+        return (%15))IR";
+
+  // replace aten::gelu with pointwise operations
+  torch::jit::SubgraphRewriter map_gelu_to_pointwise_ops;
+  map_gelu_to_pointwise_ops.RegisterRewritePattern(gelu_pattern, gelu_reduce_pattern);
+  map_gelu_to_pointwise_ops.runOnGraph(graph);
+
+  LOG_GRAPH("Post lowering of [aten::gelu] -> " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
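For context, a minimal standalone sketch (illustrative only, not part of this commit) of how the tanh approximation used in the rewrite pattern above compares against the exact erf-based GELU:

#include <algorithm>
#include <cmath>
#include <cstdio>

// Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
double gelu_exact(double x) {
  return 0.5 * x * (1.0 + std::erf(x / std::sqrt(2.0)));
}

// Fast approximation mirroring the lowered graph:
// 0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x * x)))
double gelu_fast(double x) {
  return 0.5 * x * (1.0 + std::tanh(0.7978845608 * x * (1.0 + 0.044715 * x * x)));
}

int main() {
  double max_err = 0.0;
  for (double x = -5.0; x <= 5.0; x += 0.01) {
    max_err = std::max(max_err, std::fabs(gelu_exact(x) - gelu_fast(x)));
  }
  std::printf("max abs error on [-5, 5]: %g\n", max_err);
  return 0;
}

The deviation stays well under 1e-2 across the interval, which is why the decomposition is treated as a safe drop-in for the plugin.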

tests/core/conversion/converters/test_activation.cpp (+5)
@@ -1,5 +1,6 @@
 #include <string>
 #include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
 #include "gtest/gtest.h"
 #include "tests/util/util.h"
 #include "torch/csrc/jit/ir/irparser.h"
@@ -211,6 +212,10 @@ TEST(Converters, ATenGELUConvertsCorrectly) {
 
   auto in = at::randint(-5, 5, {5}, {at::kCUDA});
 
+  // Lower aten::gelu to pointwise operators using the fast approximation
+  // Gelu(x) = 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
+  torch_tensorrt::core::lowering::passes::ReduceGelu(g);
+
   auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
   auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

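With the plugin converter removed from activation.cpp, the conversion stage no longer handles raw aten::gelu nodes, so this converter test lowers the graph with ReduceGelu before running conversion.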
tests/core/lowering/BUILD (+5)
@@ -42,6 +42,10 @@ lowering_test(
     name = "test_reduce_to_pass",
 )
 
+lowering_test(
+    name = "test_reduce_gelu",
+)
+
 lowering_test(
     name = "test_remove_detach_pass",
 )
@@ -73,6 +77,7 @@ test_suite(
     ":test_remove_detach_pass",
     ":test_remove_dropout_pass",
     ":test_reduce_to_pass",
+    ":test_reduce_gelu",
     ":test_unpack_hardswish",
     ":test_unpack_reduce_ops"
 ],
tests/core/lowering/test_reduce_gelu.cpp (+42, new file)
@@ -0,0 +1,42 @@
+#include <string>
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+
+TEST(LoweringPasses, ReduceGeluCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%x):
+        %out : Tensor = aten::gelu(%x)
+        return (%out))IR";
+  std::string target_graph = R"IR(
+    graph(%x.1 : Tensor):
+        %6 : float = prim::Constant[value=0.044714999999999998]()
+        %5 : float = prim::Constant[value=0.79788456080000003]()
+        %4 : float = prim::Constant[value=1.]()
+        %3 : float = prim::Constant[value=0.5]()
+        %2 : int = prim::Constant[value=1]()
+        %7 : Tensor = aten::mul(%x.1, %3)
+        %8 : Tensor = aten::mul(%x.1, %5)
+        %9 : Tensor = aten::mul(%x.1, %6)
+        %10 : Tensor = aten::mul(%9, %x.1)
+        %11 : Tensor = aten::add(%10, %4, %2)
+        %12 : Tensor = aten::mul(%8, %11)
+        %13 : Tensor = aten::tanh(%12)
+        %14 : Tensor = aten::add(%13, %4, %2)
+        %15 : Tensor = aten::mul(%7, %14)
+        return (%15))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  torch_tensorrt::core::lowering::passes::ReduceGelu(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
