@@ -566,6 +566,58 @@ TEST(Converters, ATenAdaptiveAvgPool1DUsingPluginConvertsCorrectly) {
566
566
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
567
567
}
568
568
569
+ TEST (Converters, ATenAdaptiveMaxPool1DGlobalPoolingConvertsCorrectly) {
570
+ const auto graph =
571
+ R"IR(
572
+ graph(%0 : Tensor):
573
+ %2 : int = prim::Constant[value=1]()
574
+ %6 : int[] = prim::ListConstruct(%2)
575
+ %10 : Tensor, %11 : Tensor = aten::adaptive_max_pool1d(%0, %6)
576
+ return (%10, %11))IR" ;
577
+
578
+ auto g = std::make_shared<torch::jit::Graph>();
579
+ torch::jit::parseIR (graph, g.get ());
580
+
581
+ // PyTorch adaptive_max_pool1d needs a 3D input or a 2D input
582
+ auto in = at::randint (-5 , 5 , {1 , 3 , 16 }, at::kCUDA );
583
+
584
+ auto jit_in = at::clone (in);
585
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
586
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
587
+
588
+ auto trt_in = at::clone (in);
589
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
590
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in});
591
+
592
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
593
+ }
594
+
595
+ TEST (Converters, ATenAdaptiveMaxPool1DUsingPluginConvertsCorrectly) {
596
+ const auto graph =
597
+ R"IR(
598
+ graph(%0 : Tensor):
599
+ %2 : int = prim::Constant[value=3]()
600
+ %6 : int[] = prim::ListConstruct(%2)
601
+ %10 : Tensor, %11 : Tensor = aten::adaptive_max_pool1d(%0, %6)
602
+ return (%10, %11))IR" ;
603
+
604
+ auto g = std::make_shared<torch::jit::Graph>();
605
+ torch::jit::parseIR (graph, g.get ());
606
+
607
+ // PyTorch adaptive_max_pool1d needs a 3D input or a 2D input
608
+ auto in = at::randint (-5 , 5 , {1 , 3 , 16 }, at::kCUDA );
609
+
610
+ auto jit_in = at::clone (in);
611
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
612
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
613
+
614
+ auto trt_in = at::clone (in);
615
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
616
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in});
617
+
618
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
619
+ }
620
+
569
621
TEST (Converters, ATenAdaptiveMaxPool2DConvertsCorrectly) {
570
622
const auto graph = R"IR(
571
623
graph(%0 : Tensor):
@@ -617,3 +669,115 @@ TEST(Converters, ATenAdaptiveMaxPool2DConvertsCorrectlyWithDynamicInput) {
617
669
618
670
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
619
671
}
672
+
673
+ TEST (Converters, ATenAdaptiveAvgPool3DGlobalPoolingConvertsCorrectly) {
674
+ const auto graph =
675
+ R"IR(
676
+ graph(%0 : Tensor):
677
+ %2 : int = prim::Constant[value=1]()
678
+ %3 : int = prim::Constant[value=1]()
679
+ %4 : int = prim::Constant[value=1]()
680
+ %6 : int[] = prim::ListConstruct(%2, %3, %4)
681
+ %10 : Tensor = aten::adaptive_avg_pool3d(%0, %6)
682
+ return (%10))IR" ;
683
+
684
+ auto g = std::make_shared<torch::jit::Graph>();
685
+ torch::jit::parseIR (graph, g.get ());
686
+
687
+ // PyTorch adaptive_avg_pool3d needs a 5D input or a 4D input
688
+ auto in = at::randint (-5 , 5 , {4 , 5 , 3 , 15 , 16 }, at::kCUDA );
689
+
690
+ auto jit_in = at::clone (in);
691
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
692
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
693
+
694
+ auto trt_in = at::clone (in);
695
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
696
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in});
697
+
698
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
699
+ }
700
+
701
+ TEST (Converters, ATenAdaptiveAvgPool3DUsingPluginConvertsCorrectly) {
702
+ const auto graph =
703
+ R"IR(
704
+ graph(%0 : Tensor):
705
+ %2 : int = prim::Constant[value=7]()
706
+ %3 : int = prim::Constant[value=6]()
707
+ %4 : int = prim::Constant[value=5]()
708
+ %6 : int[] = prim::ListConstruct(%2, %3, %4)
709
+ %10 : Tensor = aten::adaptive_avg_pool3d(%0, %6)
710
+ return (%10))IR" ;
711
+
712
+ auto g = std::make_shared<torch::jit::Graph>();
713
+ torch::jit::parseIR (graph, g.get ());
714
+
715
+ // PyTorch adaptive_avg_pool3d needs a 5D input or a 4D input
716
+ auto in = at::randint (-5 , 5 , {4 , 5 , 3 , 15 , 16 }, at::kCUDA );
717
+
718
+ auto jit_in = at::clone (in);
719
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
720
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
721
+
722
+ auto trt_in = at::clone (in);
723
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
724
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in});
725
+
726
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
727
+ }
728
+
729
+ TEST (Converters, ATenAdaptiveMaxPool3DGlobalPoolingConvertsCorrectly) {
730
+ const auto graph =
731
+ R"IR(
732
+ graph(%0 : Tensor):
733
+ %2 : int = prim::Constant[value=1]()
734
+ %3 : int = prim::Constant[value=1]()
735
+ %4 : int = prim::Constant[value=1]()
736
+ %6 : int[] = prim::ListConstruct(%2, %3, %4)
737
+ %10 : Tensor, %11 : Tensor = aten::adaptive_max_pool3d(%0, %6)
738
+ return (%10, %11))IR" ;
739
+
740
+ auto g = std::make_shared<torch::jit::Graph>();
741
+ torch::jit::parseIR (graph, g.get ());
742
+
743
+ // PyTorch adaptive_max_pool3d needs a 5D input or a 4D input
744
+ auto in = at::randint (-5 , 5 , {5 , 3 , 15 , 16 }, at::kCUDA );
745
+
746
+ auto jit_in = at::clone (in);
747
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
748
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
749
+
750
+ auto trt_in = at::clone (in);
751
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
752
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in});
753
+
754
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
755
+ }
756
+
757
+ TEST (Converters, ATenAdaptiveMaxPool3DUsingPluginConvertsCorrectly) {
758
+ const auto graph =
759
+ R"IR(
760
+ graph(%0 : Tensor):
761
+ %2 : int = prim::Constant[value=7]()
762
+ %3 : int = prim::Constant[value=8]()
763
+ %4 : int = prim::Constant[value=9]()
764
+ %6 : int[] = prim::ListConstruct(%2, %3, %4)
765
+ %10 : Tensor, %11 : Tensor = aten::adaptive_max_pool3d(%0, %6)
766
+ return (%10, %11))IR" ;
767
+
768
+ auto g = std::make_shared<torch::jit::Graph>();
769
+ torch::jit::parseIR (graph, g.get ());
770
+
771
+ // PyTorch adaptive_max_pool3d needs a 5D input or a 4D input
772
+ auto in = at::randint (-5 , 5 , {4 , 5 , 3 , 15 , 16 }, at::kCUDA );
773
+
774
+ auto jit_in = at::clone (in);
775
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
776
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
777
+
778
+ auto trt_in = at::clone (in);
779
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
780
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in});
781
+
782
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
783
+ }
0 commit comments