5
5
#include " core/compiler.h"
6
6
7
7
TEST (Converters, ATenMeanConvertsCorrectly) {
8
+ const auto graph = R"IR(
9
+ graph(%0 : Tensor):
10
+ %4 : None = prim::Constant()
11
+ %5 : Tensor = aten::mean(%0, %4)
12
+ return (%5))IR" ;
13
+
14
+ auto g = std::make_shared<torch::jit::Graph>();
15
+ torch::jit::script::parseIR (graph, &*g);
16
+
17
+ auto in = at::randint (-5 , 5 , {4 , 4 }, at::kCUDA );
18
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
19
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
20
+
21
+ in = at::clone (in);
22
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
23
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
24
+
25
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ]));
26
+ }
27
+
28
+ TEST (Converters, ATenMeanHigherDimensionConvertsCorrectly) {
29
+ const auto graph = R"IR(
30
+ graph(%0 : Tensor):
31
+ %4 : None = prim::Constant()
32
+ %5 : Tensor = aten::mean(%0, %4)
33
+ return (%5))IR" ;
34
+
35
+ auto g = std::make_shared<torch::jit::Graph>();
36
+ torch::jit::script::parseIR (graph, &*g);
37
+
38
+ auto in = at::randint (-5 , 5 , {4 , 4 , 4 , 4 }, at::kCUDA );
39
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
40
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
41
+
42
+ in = at::clone (in);
43
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
44
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
45
+
46
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ]));
47
+ }
48
+
49
+ TEST (Converters, ATenMeanRowConvertsCorrectly) {
8
50
const auto graph = R"IR(
9
51
graph(%0 : Tensor):
10
52
%1 : int = prim::Constant[value=1]()
11
- %2 : int[] = prim::ListConstruct(%1)
53
+ %2 : int[] = prim::ListConstruct(%1)
12
54
%3 : bool = prim::Constant[value=0]()
13
55
%4 : None = prim::Constant()
14
- %5 : Tensor = aten::mean(%0, %2, %3, %4)
56
+ %5 : Tensor = aten::mean(%0, %2, %3, %4)
15
57
return (%5))IR" ;
16
58
17
59
auto g = std::make_shared<torch::jit::Graph>();
@@ -24,18 +66,18 @@ TEST(Converters, ATenMeanConvertsCorrectly) {
24
66
in = at::clone (in);
25
67
params = trtorch::core::conversion::get_named_params (g->inputs (), {});
26
68
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
27
-
69
+
28
70
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ]));
29
71
}
30
72
31
73
TEST (Converters, ATenMeanKeepDimsConvertsCorrectly) {
32
74
const auto graph = R"IR(
33
75
graph(%0 : Tensor):
34
76
%1 : int = prim::Constant[value=1]()
35
- %2 : int[] = prim::ListConstruct(%1)
77
+ %2 : int[] = prim::ListConstruct(%1)
36
78
%3 : bool = prim::Constant[value=1]()
37
79
%4 : None = prim::Constant()
38
- %5 : Tensor = aten::mean(%0, %2, %3, %4)
80
+ %5 : Tensor = aten::mean(%0, %2, %3, %4)
39
81
return (%5))IR" ;
40
82
41
83
auto g = std::make_shared<torch::jit::Graph>();
@@ -48,6 +90,6 @@ TEST(Converters, ATenMeanKeepDimsConvertsCorrectly) {
48
90
in = at::clone (in);
49
91
params = trtorch::core::conversion::get_named_params (g->inputs (), {});
50
92
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
51
-
93
+
52
94
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ]));
53
95
}
0 commit comments