@@ -134,6 +134,58 @@ converts_keepdims_correctly(mean, Mean);
134
134
135
135
#undef converts_keepdims_correctly
136
136
137
+ TEST (Converters, ATenSumDimNegOneIndexConvertsCorrectly) {
138
+ const auto graph = R"IR(
139
+ graph(%0 : Tensor):
140
+ %1 : int = prim::Constant[value=-1]()
141
+ %2 : int[] = prim::ListConstruct(%1)
142
+ %3 : bool = prim::Constant[value=0]()
143
+ %4 : None = prim::Constant()
144
+ %5 : Tensor = aten::sum(%0, %2, %3, %4)
145
+ return (%5))IR" ;
146
+ auto in = at::randint (-5 , 5 , {4 , 4 , 4 }, at::kCUDA );
147
+ test_body (graph, in);
148
+ }
149
+
150
+ TEST (Converters, ATenSumDimNegOneIndexKeepDimsConvertsCorrectly) {
151
+ const auto graph = R"IR(
152
+ graph(%0 : Tensor):
153
+ %1 : int = prim::Constant[value=-1]()
154
+ %2 : int[] = prim::ListConstruct(%1)
155
+ %3 : bool = prim::Constant[value=1]()
156
+ %4 : None = prim::Constant()
157
+ %5 : Tensor = aten::sum(%0, %2, %3, %4)
158
+ return (%5))IR" ;
159
+ auto in = at::randint (-5 , 5 , {4 , 4 , 4 }, at::kCUDA );
160
+ test_body (graph, in);
161
+ }
162
+
163
+ TEST (Converters, ATenSumDimNegIndexConvertsCorrectly) {
164
+ const auto graph = R"IR(
165
+ graph(%0 : Tensor):
166
+ %1 : int = prim::Constant[value=-2]()
167
+ %2 : int[] = prim::ListConstruct(%1)
168
+ %3 : bool = prim::Constant[value=0]()
169
+ %4 : None = prim::Constant()
170
+ %5 : Tensor = aten::sum(%0, %2, %3, %4)
171
+ return (%5))IR" ;
172
+ auto in = at::randint (-5 , 5 , {4 , 4 , 4 }, at::kCUDA );
173
+ test_body (graph, in);
174
+ }
175
+
176
+ TEST (Converters, ATenSumDimNegIndexKeepDimsConvertsCorrectly) {
177
+ const auto graph = R"IR(
178
+ graph(%0 : Tensor):
179
+ %1 : int = prim::Constant[value=-2]()
180
+ %2 : int[] = prim::ListConstruct(%1)
181
+ %3 : bool = prim::Constant[value=1]()
182
+ %4 : None = prim::Constant()
183
+ %5 : Tensor = aten::sum(%0, %2, %3, %4)
184
+ return (%5))IR" ;
185
+ auto in = at::randint (-5 , 5 , {4 , 4 , 4 }, at::kCUDA );
186
+ test_body (graph, in);
187
+ }
188
+
137
189
TEST (Converters, ATenProdDimConvertsCorrectly) {
138
190
const auto graph = R"IR(
139
191
graph(%0 : Tensor):
0 commit comments