Skip to content

Commit 769bbc9

Browse files
committed
feat(aten::sum): Allow for negative indices less than -1
Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 3e7cf8e commit 769bbc9

File tree

2 files changed

+54
-2
lines changed

2 files changed

+54
-2
lines changed

core/conversion/converters/impl/reduce.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ auto reduce_registrations TRTORCH_UNUSED =
8585
LOG_DEBUG("InDims " << in_dims); // Some abuse of toDim but just for debug info
8686
LOG_DEBUG(
8787
"Dim to reduce(original):" << util::toDims(dims)); // Some abuse of toDim but just for debug info
88-
for (int i = 0; i < dims.size(); i++) {
89-
auto dim_val = dims[i] == -1 ? (in_dims.size() - 1) : dims[i];
88+
for (size_t i = 0; i < dims.size(); i++) {
89+
auto dim_val = dims[i] < 0 ? (in_dims.size() + dims[i]) : dims[i];
9090
calculated_dims.push_back(dim_val);
9191
}
9292

tests/core/conversion/converters/test_reduce.cpp

+52
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,58 @@ converts_keepdims_correctly(mean, Mean);
134134

135135
#undef converts_keepdims_correctly
136136

137+
TEST(Converters, ATenSumDimNegOneIndexConvertsCorrectly) {
138+
const auto graph = R"IR(
139+
graph(%0 : Tensor):
140+
%1 : int = prim::Constant[value=-1]()
141+
%2 : int[] = prim::ListConstruct(%1)
142+
%3 : bool = prim::Constant[value=0]()
143+
%4 : None = prim::Constant()
144+
%5 : Tensor = aten::sum(%0, %2, %3, %4)
145+
return (%5))IR";
146+
auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);
147+
test_body(graph, in);
148+
}
149+
150+
TEST(Converters, ATenSumDimNegOneIndexKeepDimsConvertsCorrectly) {
151+
const auto graph = R"IR(
152+
graph(%0 : Tensor):
153+
%1 : int = prim::Constant[value=-1]()
154+
%2 : int[] = prim::ListConstruct(%1)
155+
%3 : bool = prim::Constant[value=1]()
156+
%4 : None = prim::Constant()
157+
%5 : Tensor = aten::sum(%0, %2, %3, %4)
158+
return (%5))IR";
159+
auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);
160+
test_body(graph, in);
161+
}
162+
163+
TEST(Converters, ATenSumDimNegIndexConvertsCorrectly) {
164+
const auto graph = R"IR(
165+
graph(%0 : Tensor):
166+
%1 : int = prim::Constant[value=-2]()
167+
%2 : int[] = prim::ListConstruct(%1)
168+
%3 : bool = prim::Constant[value=0]()
169+
%4 : None = prim::Constant()
170+
%5 : Tensor = aten::sum(%0, %2, %3, %4)
171+
return (%5))IR";
172+
auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);
173+
test_body(graph, in);
174+
}
175+
176+
TEST(Converters, ATenSumDimNegIndexKeepDimsConvertsCorrectly) {
177+
const auto graph = R"IR(
178+
graph(%0 : Tensor):
179+
%1 : int = prim::Constant[value=-2]()
180+
%2 : int[] = prim::ListConstruct(%1)
181+
%3 : bool = prim::Constant[value=1]()
182+
%4 : None = prim::Constant()
183+
%5 : Tensor = aten::sum(%0, %2, %3, %4)
184+
return (%5))IR";
185+
auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);
186+
test_body(graph, in);
187+
}
188+
137189
TEST(Converters, ATenProdDimConvertsCorrectly) {
138190
const auto graph = R"IR(
139191
graph(%0 : Tensor):

0 commit comments

Comments
 (0)