From c7d6b498503bc6b6f5eb6c9c8150ab01ff6c4a53 Mon Sep 17 00:00:00 2001
From: Naren Dasan
Date: Sun, 14 Jun 2020 15:54:04 -0700
Subject: [PATCH] feat(aten::permute): Implement permute support

Signed-off-by: Naren Dasan
Signed-off-by: Naren Dasan
---
 core/conversion/converters/impl/shuffle.cpp |  21 ++++
 tests/core/converters/test_shuffle.cpp      | 122 +++++++++++++++-----
 2 files changed, 115 insertions(+), 28 deletions(-)

diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp
index 951635a8fc..463cd4edd2 100644
--- a/core/conversion/converters/impl/shuffle.cpp
+++ b/core/conversion/converters/impl/shuffle.cpp
@@ -59,6 +59,27 @@ static auto shuffle_registrations TRTORCH_UNUSED = RegisterNodeConversionPattern
       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
       LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
 
+      return true;
+    }
+  }).pattern({
+    "aten::permute(Tensor(a) self, int[] dims) -> (Tensor(a))",
+    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+      auto in = args[0].ITensor();
+      auto in_shape = util::toVec(in->getDimensions());
+      auto new_order = args[1].unwrapToIntList().vec();
+
+      LOG_DEBUG("Shuffle to: " << util::toDims(new_order));
+
+      auto shuffle = ctx->net->addShuffle(*in);
+      TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
+      nvinfer1::Permutation permute;
+      std::copy(new_order.begin(), new_order.end(), permute.order);
+      shuffle->setSecondTranspose(permute);
+      shuffle->setName(util::node_info(n).c_str());
+
+      auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
+      LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+
       return true;
     }
   });
diff --git a/tests/core/converters/test_shuffle.cpp b/tests/core/converters/test_shuffle.cpp
index fc9326da94..01b75378aa 100644
--- a/tests/core/converters/test_shuffle.cpp
+++ b/tests/core/converters/test_shuffle.cpp
@@ -30,17 +30,87 @@ TEST(Converters, ATenFlattenConvertsCorrectly) {
 
 // TODO: IR Parser doesnt work well with neg numbers
 TEST(Converters, ATenFlattenOtherDimsConvertsCorrectly) {
-  const auto graph = R"IR(
-    graph(%0 : Tensor):
-      %1 : int = prim::Constant[value=1]()
-      %2 : int = prim::Constant[value=2]()
-      %3 : Tensor = aten::flatten(%0, %1, %2)
-      return (%3))IR";
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=1]()
+        %2 : int = prim::Constant[value=2]()
+        %3 : Tensor = aten::flatten(%0, %1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(0, 5, {2, 3, 3}, {at::kCUDA});
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+  auto trt = trt_results[0].reshape_as(jit_results[0]);
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenReshapeConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=3]()
+        %2 : int = prim::Constant[value=2]()
+        %3 : int[] = prim::ListConstruct(%1, %2)
+        %4 : Tensor = aten::reshape(%0, %3)
+        return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+  auto trt = trt_results[0].reshape_as(jit_results[0]);
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenViewConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=3]()
+        %2 : int = prim::Constant[value=2]()
+        %3 : int[] = prim::ListConstruct(%1, %2)
+        %4 : Tensor = aten::view(%0, %3)
+        return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+  auto trt = trt_results[0].reshape_as(jit_results[0]);
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenPermuteConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int[] = prim::Constant[value=[3, 0, 1, 2]]()
+        %3 : Tensor = aten::permute(%x.1, %2)
+        return (%3))IR";
 
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, &*g);
 
-  auto in = at::randint(0, 5, {2, 3, 3}, {at::kCUDA});
+  auto in = at::randint(0, 5, {2, 3, 2, 3}, {at::kCUDA});
   auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
   auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
 
@@ -52,19 +122,17 @@ TEST(Converters, ATenFlattenOtherDimsConvertsCorrectly) {
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
-TEST(Converters, ATenReshapeConvertsCorrectly) {
-  const auto graph = R"IR(
-    graph(%0 : Tensor):
-      %1 : int = prim::Constant[value=3]()
-      %2 : int = prim::Constant[value=2]()
-      %3 : int[] = prim::ListConstruct(%1, %2)
-      %4 : Tensor = aten::reshape(%0, %3)
-      return (%4))IR";
+TEST(Converters, ATenPermute3DConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int[] = prim::Constant[value=[0, 2, 1]]()
+        %3 : Tensor = aten::permute(%x.1, %2)
+        return (%3))IR";
 
-  auto g = std::make_shared<torch::jit::Graph>();
+  auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, &*g);
 
-  auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
+  auto in = at::randint(0, 5, {2, 2, 3}, {at::kCUDA});
   auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
   auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
 
@@ -76,19 +144,17 @@ TEST(Converters, ATenReshapeConvertsCorrectly) {
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
-TEST(Converters, ATenViewConvertsCorrectly) {
-  const auto graph = R"IR(
-    graph(%0 : Tensor):
-      %1 : int = prim::Constant[value=3]()
-      %2 : int = prim::Constant[value=2]()
-      %3 : int[] = prim::ListConstruct(%1, %2)
-      %4 : Tensor = aten::view(%0, %3)
-      return (%4))IR";
+TEST(Converters, ATenPermute5DConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int[] = prim::Constant[value=[3, 4, 0, 2, 1]]()
+        %3 : Tensor = aten::permute(%x.1, %2)
+        return (%3))IR";
 
-  auto g = std::make_shared<torch::jit::Graph>();
+  auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, &*g);
 
-  auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
+  auto in = at::randint(0, 5, {2, 2, 1, 2, 3}, {at::kCUDA});
   auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
   auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
 
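
Note (not part of the patch): the converter implements aten::permute by reusing
TensorRT's IShuffleLayer rather than a dedicated transpose layer. The sketch
below is a minimal standalone version of that mechanism, assuming a live
nvinfer1::INetworkDefinition* net and an input nvinfer1::ITensor* in; the
helper name add_permute is illustrative, not from the patch.

// Sketch only: how aten::permute's dims map onto a TensorRT shuffle.
#include <algorithm>
#include <vector>
#include "NvInfer.h"

nvinfer1::ITensor* add_permute(nvinfer1::INetworkDefinition* net,
                               nvinfer1::ITensor* in,
                               const std::vector<int64_t>& new_order) {
  auto* shuffle = net->addShuffle(*in);
  // Permutation::order[i] names the input axis that lands at output axis i,
  // which is exactly aten::permute's dims convention, so the list can be
  // copied over verbatim. As in the patch, only the first rank entries of
  // the fixed-size order array are filled in.
  nvinfer1::Permutation permute;
  std::copy(new_order.begin(), new_order.end(), permute.order);
  // Only the second transpose is set; with no reshape dimensions configured,
  // the shuffle acts as a pure transpose.
  shuffle->setSecondTranspose(permute);
  return shuffle->getOutput(0);
}

The shuffle layer applies first transpose, then reshape, then second
transpose, with unset stages defaulting to identity, so setting either
transpose alone would express the permutation; the patch uses the second.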