Skip to content

Commit e554dbd

Browse files
committed
feat: support aten::adaptive_max_pool1d, aten::adaptive_avg_pool3d and aten::adaptive_max_pool3d operators
Signed-off-by: Ruoqian Guo <ruoqiang@nvidia.com>
1 parent deb9f74 commit e554dbd

File tree

3 files changed

+190
-4
lines changed

3 files changed

+190
-4
lines changed

core/conversion/converters/impl/pooling.cpp

+19-3
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ bool AdaptivePoolingConverter(
3737
ConversionCtx* ctx,
3838
const torch::jit::Node* n,
3939
args& args,
40-
nvinfer1::PoolingType pool_type, const std::string& mode) {
40+
nvinfer1::PoolingType pool_type,
41+
const std::string& mode) {
4142
auto in = args[0].ITensorOrFreeze(ctx);
4243
auto out_size = util::toDims(args[1].unwrapToIntList());
4344

@@ -226,15 +227,30 @@ auto pooling_registrations TORCHTRT_UNUSED =
226227
}})
227228
.pattern({"aten::adaptive_avg_pool1d(Tensor self, int[1] output_size) -> (Tensor)",
228229
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
229-
return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kAVERAGE, "adaptive_avg_pool1d");
230+
return AdaptivePoolingConverter(
231+
ctx, n, args, nvinfer1::PoolingType::kAVERAGE, "adaptive_avg_pool1d");
232+
}})
233+
.pattern({"aten::adaptive_max_pool1d(Tensor self, int[2] output_size) -> (Tensor, Tensor)",
234+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
235+
return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kMAX, "adaptive_max_pool1d");
230236
}})
231237
.pattern({"aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> (Tensor)",
232238
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
233-
return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kAVERAGE, "adaptive_avg_pool2d");
239+
return AdaptivePoolingConverter(
240+
ctx, n, args, nvinfer1::PoolingType::kAVERAGE, "adaptive_avg_pool2d");
234241
}})
235242
.pattern({"aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)",
236243
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
237244
return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kMAX, "adaptive_max_pool2d");
245+
}})
246+
.pattern({"aten::adaptive_avg_pool3d(Tensor self, int[3] output_size) -> (Tensor)",
247+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
248+
return AdaptivePoolingConverter(
249+
ctx, n, args, nvinfer1::PoolingType::kAVERAGE, "adaptive_avg_pool3d");
250+
}})
251+
.pattern({"aten::adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor)",
252+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
253+
return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kMAX, "adaptive_max_pool3d");
238254
}});
239255
} // namespace
240256
} // namespace impl

core/plugins/impl/interpolate_plugin.cpp

+7-1
Original file line numberDiff line numberDiff line change
@@ -289,12 +289,18 @@ int InterpolatePlugin::enqueue(
289289
out = at::upsample_bilinear2d(input, {size_[0], size_[1]}, align_corners_);
290290
} else if (mode_ == "trilinear") {
291291
out = at::upsample_trilinear3d(input, {size_[0], size_[1], size_[2]}, align_corners_);
292-
} else if(mode_ == "adaptive_avg_pool1d"){
292+
} else if (mode_ == "adaptive_avg_pool1d") {
293293
out = at::adaptive_avg_pool1d(input, {size_[0]});
294+
} else if (mode_ == "adaptive_max_pool1d") {
295+
out = std::get<0>(at::adaptive_max_pool1d(input, {size_[0]}));
294296
} else if (mode_ == "adaptive_avg_pool2d") {
295297
out = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
296298
} else if (mode_ == "adaptive_max_pool2d") {
297299
out = std::get<0>(at::adaptive_max_pool2d(input, {size_[0], size_[1]}));
300+
} else if (mode_ == "adaptive_avg_pool3d") {
301+
out = at::adaptive_avg_pool3d(input, {size_[0], size_[1], size_[2]});
302+
} else if (mode_ == "adaptive_max_pool3d") {
303+
out = std::get<0>(at::adaptive_max_pool3d(input, {size_[0], size_[1], size_[2]}));
298304
}
299305
}
300306

tests/core/conversion/converters/test_pooling.cpp

+164
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,58 @@ TEST(Converters, ATenAdaptiveAvgPool1DUsingPluginConvertsCorrectly) {
566566
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
567567
}
568568

569+
TEST(Converters, ATenAdaptiveMaxPool1DGlobalPoolingConvertsCorrectly) {
  // Adaptive max pool down to a single output element per channel
  // (the "global pooling" case, per the test name).
  const auto graph =
      R"IR(
      graph(%0 : Tensor):
        %2 : int = prim::Constant[value=1]()
        %6 : int[] = prim::ListConstruct(%2)
        %10 : Tensor, %11 : Tensor = aten::adaptive_max_pool1d(%0, %6)
        return (%10, %11))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_graph.get());

  // aten::adaptive_max_pool1d requires a 2D or 3D input tensor.
  auto input = at::randint(-5, 5, {1, 3, 16}, at::kCUDA);

  auto jit_input = at::clone(input);
  auto jit_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(parsed_graph, jit_params, {jit_input});

  auto trt_input = at::clone(input);
  auto trt_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(parsed_graph, trt_params, {trt_input});

  // Only the pooled values (output 0) are compared; the index tensor is not checked.
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
594+
595+
TEST(Converters, ATenAdaptiveMaxPool1DUsingPluginConvertsCorrectly) {
  // Output size 3 does not evenly reduce the 16-wide input, exercising the
  // plugin-based path rather than a plain TensorRT pooling layer.
  const auto graph =
      R"IR(
      graph(%0 : Tensor):
        %2 : int = prim::Constant[value=3]()
        %6 : int[] = prim::ListConstruct(%2)
        %10 : Tensor, %11 : Tensor = aten::adaptive_max_pool1d(%0, %6)
        return (%10, %11))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_graph.get());

  // aten::adaptive_max_pool1d requires a 2D or 3D input tensor.
  auto input = at::randint(-5, 5, {1, 3, 16}, at::kCUDA);

  auto jit_input = at::clone(input);
  auto jit_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(parsed_graph, jit_params, {jit_input});

  auto trt_input = at::clone(input);
  auto trt_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(parsed_graph, trt_params, {trt_input});

  // Only the pooled values (output 0) are compared; the index tensor is not checked.
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
620+
569621
TEST(Converters, ATenAdaptiveMaxPool2DConvertsCorrectly) {
570622
const auto graph = R"IR(
571623
graph(%0 : Tensor):
@@ -617,3 +669,115 @@ TEST(Converters, ATenAdaptiveMaxPool2DConvertsCorrectlyWithDynamicInput) {
617669

618670
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
619671
}
672+
673+
TEST(Converters, ATenAdaptiveAvgPool3DGlobalPoolingConvertsCorrectly) {
  // Adaptive average pool with output size (1, 1, 1) — the "global pooling"
  // case, per the test name.
  const auto graph =
      R"IR(
      graph(%0 : Tensor):
        %2 : int = prim::Constant[value=1]()
        %3 : int = prim::Constant[value=1]()
        %4 : int = prim::Constant[value=1]()
        %6 : int[] = prim::ListConstruct(%2, %3, %4)
        %10 : Tensor = aten::adaptive_avg_pool3d(%0, %6)
        return (%10))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_graph.get());

  // aten::adaptive_avg_pool3d requires a 4D or 5D input tensor.
  auto input = at::randint(-5, 5, {4, 5, 3, 15, 16}, at::kCUDA);

  auto jit_input = at::clone(input);
  auto jit_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(parsed_graph, jit_params, {jit_input});

  auto trt_input = at::clone(input);
  auto trt_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(parsed_graph, trt_params, {trt_input});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
700+
701+
TEST(Converters, ATenAdaptiveAvgPool3DUsingPluginConvertsCorrectly) {
  // Output size (7, 6, 5) does not evenly divide the (3, 15, 16) spatial
  // dims, exercising the plugin-based path rather than a plain pooling layer.
  const auto graph =
      R"IR(
      graph(%0 : Tensor):
        %2 : int = prim::Constant[value=7]()
        %3 : int = prim::Constant[value=6]()
        %4 : int = prim::Constant[value=5]()
        %6 : int[] = prim::ListConstruct(%2, %3, %4)
        %10 : Tensor = aten::adaptive_avg_pool3d(%0, %6)
        return (%10))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_graph.get());

  // aten::adaptive_avg_pool3d requires a 4D or 5D input tensor.
  auto input = at::randint(-5, 5, {4, 5, 3, 15, 16}, at::kCUDA);

  auto jit_input = at::clone(input);
  auto jit_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(parsed_graph, jit_params, {jit_input});

  auto trt_input = at::clone(input);
  auto trt_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(parsed_graph, trt_params, {trt_input});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
728+
729+
TEST(Converters, ATenAdaptiveMaxPool3DGlobalPoolingConvertsCorrectly) {
  // Adaptive max pool with output size (1, 1, 1) — the "global pooling"
  // case, per the test name. Uses a 4D (unbatched) input.
  const auto graph =
      R"IR(
      graph(%0 : Tensor):
        %2 : int = prim::Constant[value=1]()
        %3 : int = prim::Constant[value=1]()
        %4 : int = prim::Constant[value=1]()
        %6 : int[] = prim::ListConstruct(%2, %3, %4)
        %10 : Tensor, %11 : Tensor = aten::adaptive_max_pool3d(%0, %6)
        return (%10, %11))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_graph.get());

  // aten::adaptive_max_pool3d requires a 4D or 5D input tensor.
  auto input = at::randint(-5, 5, {5, 3, 15, 16}, at::kCUDA);

  auto jit_input = at::clone(input);
  auto jit_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(parsed_graph, jit_params, {jit_input});

  auto trt_input = at::clone(input);
  auto trt_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(parsed_graph, trt_params, {trt_input});

  // Only the pooled values (output 0) are compared; the index tensor is not checked.
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
756+
757+
TEST(Converters, ATenAdaptiveMaxPool3DUsingPluginConvertsCorrectly) {
  // Output size (7, 8, 9) does not evenly divide the (3, 15, 16) spatial
  // dims, exercising the plugin-based path rather than a plain pooling layer.
  const auto graph =
      R"IR(
      graph(%0 : Tensor):
        %2 : int = prim::Constant[value=7]()
        %3 : int = prim::Constant[value=8]()
        %4 : int = prim::Constant[value=9]()
        %6 : int[] = prim::ListConstruct(%2, %3, %4)
        %10 : Tensor, %11 : Tensor = aten::adaptive_max_pool3d(%0, %6)
        return (%10, %11))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_graph.get());

  // aten::adaptive_max_pool3d requires a 4D or 5D input tensor.
  auto input = at::randint(-5, 5, {4, 5, 3, 15, 16}, at::kCUDA);

  auto jit_input = at::clone(input);
  auto jit_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(parsed_graph, jit_params, {jit_input});

  auto trt_input = at::clone(input);
  auto trt_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(parsed_graph, trt_params, {trt_input});

  // Only the pooled values (output 0) are compared; the index tensor is not checked.
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

0 commit comments

Comments
 (0)