
Commit ccab7b9

feat(//core/conversion/converters/impl): Non dimensional reduce converter

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>

1 parent de8659b

File tree

2 files changed: +126 -8 lines changed

core/conversion/converters/impl/reduce.cpp (+78 -2)
@@ -9,6 +9,23 @@ namespace impl {
 namespace {
 auto reduced_registrations = RegisterNodeConversionPatterns()
   .pattern({
+    "aten::mean(Tensor self, *, ScalarType? dtype=None) -> (Tensor)",
+    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+      auto in_tensor = args[0].ITensor();
+      auto in_dims = util::toVec(in_tensor->getDimensions());
+      LOG_WARNING("Mean Converter disregards dtype");
+
+      uint32_t axis_mask = (uint32_t)(((uint64_t)1 << in_dims.size()) - 1);
+
+      auto mean_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kAVG, axis_mask, false);
+
+      TRTORCH_CHECK(mean_layer, "Unable to create mean layer from node: " << *n);
+
+      mean_layer->setName(util::node_info(n).c_str());
+      ctx->AssociateValueAndTensor(n->outputs()[0], mean_layer->getOutput(0));
+      return true;
+    }
+  }).pattern({
     "aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       auto in_tensor = args[0].ITensor();
@@ -23,7 +40,7 @@ auto reduced_registrations = RegisterNodeConversionPatterns()
       TRTORCH_CHECK(mean_layer, "Unable to create mean layer from node: " << *n);

       mean_layer->setName(util::node_info(n).c_str());
-      associate_value_and_tensor(ctx, n->outputs()[0], mean_layer->getOutput(0));
+      ctx->AssociateValueAndTensor(n->outputs()[0], mean_layer->getOutput(0));
       return true;
     }
   });
@@ -32,5 +49,64 @@ auto reduced_registrations = RegisterNodeConversionPatterns()
 } // namespace converters
 } // namespace conversion
 } // namespace core
-} // namespace trtorch
+} // namespace trtorch
+
+// #include "core/util/prelude.h"
+// #include "core/conversion/converters/converters.h"
+
+// namespace trtorch {
+// namespace core {
+// namespace conversion {
+// namespace converters {
+// namespace impl {
+// namespace {
+
+// #define convert(unary, trt_type)                                             \
+//   auto unary##_registrations TRTORCH_UNUSED =                                \
+//       RegisterNodeConversionPatterns().pattern(                              \
+//           {"aten::" #unary "(Tensor self) -> Tensor",                        \
+//            [](ConversionCtx *ctx, const torch::jit::Node *n,                 \
+//               args &args) -> bool {                                          \
+//              auto in = args[0].ITensor();                                    \
+//              auto unary =                                                    \
+//                  ctx->net->addUnary(*in, nvinfer1::UnaryOperation::trt_type); \
+//                                                                              \
+//              TRTORCH_CHECK(                                                  \
+//                  unary,                                                      \
+//                  "Unable to create " #unary " layer from node: " << *n);     \
+//                                                                              \
+//              unary->setName(util::node_info(n).c_str());                     \
+//              auto out_tensor = ctx->AssociateValueAndTensor(                 \
+//                  n->outputs()[0],                                            \
+//                  unary->getOutput(0));                                       \
+//              LOG_DEBUG(                                                      \
+//                  "Output tensor shape: " << out_tensor->getDimensions());    \
+//                                                                              \
+//              return true;                                                    \
+//            }});
+
+// convert(cos, kCOS);
+// convert(acos, kACOS);
+// convert(cosh, kCOSH);
+// convert(sin, kSIN);
+// convert(asin, kASIN);
+// convert(sinh, kSINH);
+// convert(tan, kTAN);
+// convert(atan, kATAN);
+// convert(abs, kABS);
+// convert(floor, kFLOOR);
+// convert(reciprocal, kRECIP);
+// convert(log, kLOG);
+// convert(ceil, kCEIL);
+// convert(sqrt, kSQRT);
+// convert(exp, kEXP);
+// convert(neg, kNEG);
+
+// #undef convert

+// } // namespace
+// } // namespace impl
+// } // namespace converters
+// } // namespace conversion
+// } // namespace core
+// } // namespace trtorch
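Aside: the core of the new non-dimensional pattern is the axis mask handed to addReduce, one bit per input dimension, so TensorRT's IReduceLayer averages over every axis at once. A minimal standalone sketch of that arithmetic (the helper name reduce_all_axes_mask is illustrative only, not part of this commit):

#include <cassert>
#include <cstdint>

// Illustrative only: reproduces the converter's mask computation, setting one
// bit per dimension so that every axis of the input is reduced.
static uint32_t reduce_all_axes_mask(uint32_t rank) {
  return (uint32_t)(((uint64_t)1 << rank) - 1);
}

int main() {
  assert(reduce_all_axes_mask(2) == 0x3); // {4, 4} input        -> reduce both axes
  assert(reduce_all_axes_mask(4) == 0xF); // {4, 4, 4, 4} input  -> reduce all four axes
  return 0;
}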

tests/core/converters/test_reduce.cpp (+48 -6)
@@ -5,13 +5,55 @@
 #include "core/compiler.h"

 TEST(Converters, ATenMeanConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %4 : None = prim::Constant()
+      %5 : Tensor = aten::mean(%0, %4)
+      return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::script::parseIR(graph, &*g);
+
+  auto in = at::randint(-5, 5, {4, 4}, at::kCUDA);
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
+}
+
+TEST(Converters, ATenMeanHigherDimensionConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %4 : None = prim::Constant()
+      %5 : Tensor = aten::mean(%0, %4)
+      return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::script::parseIR(graph, &*g);
+
+  auto in = at::randint(-5, 5, {4, 4, 4, 4}, at::kCUDA);
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
+}
+
+TEST(Converters, ATenMeanRowConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%0 : Tensor):
       %1 : int = prim::Constant[value=1]()
-      %2 : int[] = prim::ListConstruct(%1)
+      %2 : int[] = prim::ListConstruct(%1)
       %3 : bool = prim::Constant[value=0]()
       %4 : None = prim::Constant()
-      %5 : Tensor = aten::mean(%0, %2, %3, %4)
+      %5 : Tensor = aten::mean(%0, %2, %3, %4)
       return (%5))IR";

   auto g = std::make_shared<torch::jit::Graph>();
@@ -24,18 +66,18 @@ TEST(Converters, ATenMeanConvertsCorrectly) {
   in = at::clone(in);
   params = trtorch::core::conversion::get_named_params(g->inputs(), {});
   auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
-
+
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
 }

 TEST(Converters, ATenMeanKeepDimsConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%0 : Tensor):
       %1 : int = prim::Constant[value=1]()
-      %2 : int[] = prim::ListConstruct(%1)
+      %2 : int[] = prim::ListConstruct(%1)
       %3 : bool = prim::Constant[value=1]()
       %4 : None = prim::Constant()
-      %5 : Tensor = aten::mean(%0, %2, %3, %4)
+      %5 : Tensor = aten::mean(%0, %2, %3, %4)
       return (%5))IR";

   auto g = std::make_shared<torch::jit::Graph>();
@@ -48,6 +90,6 @@ TEST(Converters, ATenMeanKeepDimsConvertsCorrectly) {
   in = at::clone(in);
   params = trtorch::core::conversion::get_named_params(g->inputs(), {});
   auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
-
+
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
 }
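For what it's worth, further coverage could reuse the same helpers these tests rely on; a hypothetical rank-3 case (illustrative only, not part of this commit) might look like:

// Hypothetical extra case, sketched for illustration: exercises the same
// non-dimensional aten::mean schema with a rank-3 input of uneven dimensions.
TEST(Converters, ATenMeanRank3ConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%0 : Tensor):
      %1 : None = prim::Constant()
      %2 : Tensor = aten::mean(%0, %1)
      return (%2))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::script::parseIR(graph, &*g);

  // Random integer input on the GPU, matching the style of the committed tests
  auto in = at::randint(-5, 5, {2, 3, 4}, at::kCUDA);
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

  in = at::clone(in);
  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
}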
