Skip to content

Commit 259aa4c

Browse files
committed
feat(//core/conversion/converters/impl/reduce): Mean reduce converter
Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 79c909c commit 259aa4c

File tree

4 files changed

+97
-0
lines changed

4 files changed

+97
-0
lines changed

core/conversion/converters/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ cc_library(
1616
"impl/element_wise.cpp",
1717
"impl/linear.cpp",
1818
"impl/pooling.cpp",
19+
"impl/reduce.cpp",
1920
"impl/softmax.cpp",
2021
"impl/unary.cpp",
2122
],
+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#include "core/util/prelude.h"
2+
#include "core/conversion/converters/converters.h"
3+
4+
namespace trtorch {
5+
namespace core {
6+
namespace conversion {
7+
namespace converters {
8+
namespace impl {
9+
namespace {
10+
auto reduced_registrations = RegisterNodeConversionPatterns()
11+
.pattern({
12+
"aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)",
13+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
14+
auto in_tensor = args[0].ITensor();
15+
auto dim = args[1].unwrapToIntList();
16+
auto keepdim = args[2].unwrapToBool();
17+
18+
uint32_t axis_mask = 1 << dim[0];
19+
20+
LOG_WARNING("Mean converter disregards dtype");
21+
auto mean_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kAVG, axis_mask, keepdim);
22+
mean_layer->setName(util::node_info(n).c_str());
23+
24+
auto out_value = n->outputs()[0];
25+
auto out_tensor = mean_layer->getOutput(0);
26+
out_tensor->setName(out_value->debugName().c_str());
27+
ctx->value_tensor_map[out_value] = out_tensor;
28+
29+
return true;
30+
}
31+
});
32+
} // namespace
33+
} // namespace impl
34+
} // namespace converters
35+
} // namespace conversion
36+
} // namespace core
37+
} // namespace trtorch
38+

tests/core/converters/BUILD

+5
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ converter_test(
2828
name = "test_conv"
2929
)
3030

31+
converter_test(
32+
name = "test_reduce"
33+
)
34+
3135
test_suite(
3236
name = "test_converters",
3337
tests = [
@@ -38,6 +42,7 @@ test_suite(
3842
":test_linear",
3943
":test_element_wise",
4044
":test_conv",
45+
":test_reduce"
4146
]
4247
)
4348

tests/core/converters/test_reduce.cpp

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#include <string>
2+
#include "gtest/gtest.h"
3+
#include "torch/csrc/jit/irparser.h"
4+
#include "tests/util/util.h"
5+
#include "core/compiler.h"
6+
7+
TEST(Converters, ATenMeanConvertsCorrectly) {
8+
const auto graph = R"IR(
9+
graph(%0 : Tensor):
10+
%1 : int = prim::Constant[value=1]()
11+
%2 : int[] = prim::ListConstruct(%1)
12+
%3 : bool = prim::Constant[value=0]()
13+
%4 : None = prim::Constant()
14+
%5 : Tensor = aten::mean(%0, %2, %3, %4)
15+
return (%5))IR";
16+
17+
auto g = std::make_shared<torch::jit::Graph>();
18+
torch::jit::script::parseIR(graph, &*g);
19+
20+
auto in = at::randint(-5, 5, {4, 4}, at::kCUDA);
21+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
22+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
23+
24+
in = at::clone(in);
25+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
26+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
27+
28+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
29+
}
30+
31+
TEST(Converters, ATenMeanKeepDimsConvertsCorrectly) {
32+
const auto graph = R"IR(
33+
graph(%0 : Tensor):
34+
%1 : int = prim::Constant[value=1]()
35+
%2 : int[] = prim::ListConstruct(%1)
36+
%3 : bool = prim::Constant[value=1]()
37+
%4 : None = prim::Constant()
38+
%5 : Tensor = aten::mean(%0, %2, %3, %4)
39+
return (%5))IR";
40+
41+
auto g = std::make_shared<torch::jit::Graph>();
42+
torch::jit::script::parseIR(graph, &*g);
43+
44+
auto in = at::randint(-5, 5, {4, 4}, at::kCUDA);
45+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
46+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
47+
48+
in = at::clone(in);
49+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
50+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
51+
52+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
53+
}

0 commit comments

Comments
 (0)