Skip to content

Commit ed3c185

Browse files
committed
fix(aten::max_pool2d): Supressing error due to not filling in stride in
the default case Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 4ee6c20 commit ed3c185

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

core/conversion/converters/impl/pooling.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ bool MaxPoolingConverter(ConversionCtx* ctx, const torch::jit::Node* n, args& ar
3030
auto padding = util::toDims(args[3].unwrapToIntList());
3131
LOG_DEBUG("padding: " << padding);
3232
auto stride = util::toDims(args[2].unwrapToIntList());
33+
if (args[2].unwrapToIntList().size() == 0) {
34+
LOG_DEBUG("Stride not providied, using kernel_size as stride");
35+
stride = util::toDims(args[1].unwrapToIntList());
36+
}
3337
LOG_DEBUG("stride: " << stride);
3438

3539
auto dilation = util::toDims(args[4].unwrapToIntList());
@@ -88,6 +92,10 @@ bool AvgPoolingConverter(ConversionCtx* ctx, const torch::jit::Node* n, args& ar
8892
auto padding = util::toDims(args[3].unwrapToIntList());
8993
LOG_DEBUG("padding: " << padding);
9094
auto stride = util::toDims(args[2].unwrapToIntList());
95+
if (args[2].unwrapToIntList().size() == 0) {
96+
LOG_DEBUG("Stride not providied, using kernel_size as stride");
97+
stride = util::toDims(args[1].unwrapToIntList());
98+
}
9199
LOG_DEBUG("stride: " << stride);
92100

93101
bool ceil_mode = args[4].unwrapToBool();

0 commit comments

Comments
 (0)