1
+ #include < string>
2
+ #include " gtest/gtest.h"
3
+ #include " torch/csrc/jit/irparser.h"
4
+ #include " tests/util/util.h"
5
+ #include " core/compiler.h"
6
+
7
+ TEST (Converters, ATenMeanConvertsCorrectly) {
8
+ const auto graph = R"IR(
9
+ graph(%0 : Tensor):
10
+ %1 : int = prim::Constant[value=1]()
11
+ %2 : int[] = prim::ListConstruct(%1)
12
+ %3 : bool = prim::Constant[value=0]()
13
+ %4 : None = prim::Constant()
14
+ %5 : Tensor = aten::mean(%0, %2, %3, %4)
15
+ return (%5))IR" ;
16
+
17
+ auto g = std::make_shared<torch::jit::Graph>();
18
+ torch::jit::script::parseIR (graph, &*g);
19
+
20
+ auto in = at::randint (-5 , 5 , {4 , 4 }, at::kCUDA );
21
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
22
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
23
+
24
+ in = at::clone (in);
25
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
26
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
27
+
28
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ]));
29
+ }
30
+
31
+ TEST (Converters, ATenMeanKeepDimsConvertsCorrectly) {
32
+ const auto graph = R"IR(
33
+ graph(%0 : Tensor):
34
+ %1 : int = prim::Constant[value=1]()
35
+ %2 : int[] = prim::ListConstruct(%1)
36
+ %3 : bool = prim::Constant[value=1]()
37
+ %4 : None = prim::Constant()
38
+ %5 : Tensor = aten::mean(%0, %2, %3, %4)
39
+ return (%5))IR" ;
40
+
41
+ auto g = std::make_shared<torch::jit::Graph>();
42
+ torch::jit::script::parseIR (graph, &*g);
43
+
44
+ auto in = at::randint (-5 , 5 , {4 , 4 }, at::kCUDA );
45
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
46
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
47
+
48
+ in = at::clone (in);
49
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
50
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
51
+
52
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ]));
53
+ }
0 commit comments