From 024a6b270c73ce1059e182e91b8cf11ae2578d18 Mon Sep 17 00:00:00 2001 From: Abhiram Iyer Date: Thu, 25 Jun 2020 12:38:59 -0700 Subject: [PATCH] fix(): need to fix gather converter Signed-off-by: Abhiram Iyer Signed-off-by: Abhiram Iyer --- core/conversion/converters/BUILD | 3 +- core/conversion/converters/impl/select.cpp | 67 ++++++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) create mode 100755 core/conversion/converters/impl/select.cpp diff --git a/core/conversion/converters/BUILD b/core/conversion/converters/BUILD index 263b04b1df..cf4bcc9446 100755 --- a/core/conversion/converters/BUILD +++ b/core/conversion/converters/BUILD @@ -28,7 +28,8 @@ cc_library( "impl/shuffle.cpp", "impl/softmax.cpp", "impl/unary.cpp", - "impl/interpolate.cpp" + "impl/interpolate.cpp", + "impl/select.cpp" ], deps = [ "@tensorrt//:nvinfer", diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp new file mode 100755 index 0000000000..bfdeeb9f65 --- /dev/null +++ b/core/conversion/converters/impl/select.cpp @@ -0,0 +1,67 @@ +#include "torch/torch.h" +#include "core/util/prelude.h" +#include "core/conversion/converters/converters.h" +#include "NvInfer.h" +#include "torch/csrc/autograd/generated/variable_factories.h" + +#include +#include + +#include + +namespace trtorch { +namespace core { +namespace conversion { +namespace converters { +namespace impl { +namespace { + +auto select_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns() + .pattern({ + "aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + std::cout << "select.int converter recognized" << std::endl; + + auto in = args[0].ITensor(); + auto axis = args[1].unwrapToInt(); + auto ind = (int32_t) args[2].unwrapToInt(); + + // tried: vector for input + //std::vector indices_input = {ind}; + + auto options = torch::TensorOptions().device(torch::kCUDA, 1).dtype(torch::kInt32); + at::Tensor indices = torch::tensor(torch::detail::TensorDataContainer(ind), options); + + auto weights = Weights(ctx, indices); + // manually setting weights + // weights.data.type = nvinfer1::DataType::kINT32; + + auto const_layer = ctx->net->addConstant(weights.shape, weights.data); + const_layer->setName(util::node_info(n).c_str()); + // manually setting output type + // const_layer->setOutputType(0, nvinfer1::DataType::kINT32); + + auto const_out = ctx->AssociateValueAndTensor(n->outputs()[0], const_layer->getOutput(0)); + + auto gather_layer = ctx->net->addGather(*in, *const_out, axis); + gather_layer->setName(util::node_info(n).c_str()); + // manually setting output type + // gather_layer->setOutputType(0, nvinfer1::DataType::kINT32); + + auto gather_output = ctx->AssociateValueAndTensor(n->outputs()[0], gather_layer->getOutput(0)); + + LOG_DEBUG("Output tensor shape: " << gather_output->getDimensions()); + + // for debugging + // std::raise(SIGTRAP); + + return true; + } + }); + +} // namespace +} // namespace impl +} // namespace converters +} // namespace conversion +} // namespace core +} // namespace trtorch \ No newline at end of file