From 54e407e3e2c405b0f922035d3c69ef8d6c627055 Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Thu, 25 Mar 2021 12:47:01 -0500
Subject: [PATCH] feat: support Int/Bool and other constants' inputs/outputs for TensorRT segments

Signed-off-by: Bo Wang
---
Note: this copy of the patch appears to have lost its angle-bracketed text
(e-mail addresses and C++ template argument lists, e.g. the target type of
the static_cast in aten.cpp and the element types of std::vector /
std::unordered_map / c10::ArrayRef). Restore them from the original commit
before applying or compiling; indentation below is reconstructed.

 core/conversion/evaluators/aten.cpp |  2 +-
 core/partitioning/partitioning.cpp  | 41 +++++++++++++++++++++++++++--
 core/partitioning/partitioning.h    | 14 ++++++++++
 3 files changed, 54 insertions(+), 3 deletions(-)

diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp
index 587bfe5b6e..c9b1df55fd 100644
--- a/core/conversion/evaluators/aten.cpp
+++ b/core/conversion/evaluators/aten.cpp
@@ -436,7 +436,7 @@ auto aten_registrations TRTORCH_UNUSED =
                     if (args.at(n->input(0)).IValue()->isInt()) {
                       auto a = args.at(n->input(0)).unwrapToInt();
                       auto b = args.at(n->input(1)).unwrapToInt();
-                      return std::floor(a / b);
+                      return static_cast(std::floor(a / b));
                     } else if (args.at(n->input(0)).IValue()->isDouble()) {
                       auto a = args.at(n->input(0)).unwrapToDouble();
                       auto b = args.at(n->input(1)).unwrapToDouble();
diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp
index d817c46a24..8ac161dff6 100644
--- a/core/partitioning/partitioning.cpp
+++ b/core/partitioning/partitioning.cpp
@@ -3,6 +3,7 @@
 #include "core/lowering/passes/passes.h"
 #include "core/util/prelude.h"
 #include "torch/csrc/jit/api/module.h"
+#include "torch/csrc/jit/ir/constants.h"
 
 namespace trtorch {
 namespace core {
@@ -67,6 +68,7 @@ void registerSegmentInOutIValues(
   // create a module to run the graph
   auto g = seg_block.g();
   auto copy_g = g->copy();
+// LOG_INFO(*copy_g << "(copy graph)\n");
 
   // create tuple for multiple outputs
   if (seg_block.raw_outputs().size() > 1) {
@@ -163,19 +165,53 @@ void registerSegmentsInputsOutputs(
     input_values.insert(graph_output);
   }
 
-  for (auto& mini_graph_input : input_values) {
-    for (auto& seg_block : segmented_blocks) {
+  // should be careful here because some in-place operations don't return any values
+  for (auto& seg_block : segmented_blocks) {
+    for (auto& mini_graph_input : input_values) {
       if (std::find(seg_block.raw_inputs().begin(), seg_block.raw_inputs().end(), mini_graph_input) ==
               seg_block.raw_inputs().end() &&
           seg_block.contain_raw_input(mini_graph_input)) {
         seg_block.registerOutput(mini_graph_input);
       }
     }
+    if (seg_block.raw_outputs().empty()) {
+      seg_block.registerOutput(seg_block.raw_inputs()[0]);
+    }
   }
 
   return;
 }
 
+void eraseNonTensorInputsOutputs(
+    SegmentedBlock& seg_block,
+    std::unordered_map& ivalues_maps) {
+  if (seg_block.target() == SegmentedBlock::kTorch)
+    return;
+  auto mini_graph = seg_block.g();
+
+  for (int i = seg_block.raw_inputs().size() - 1; i >= 0; --i) {
+    // erase this input and prepend a prim::Constant if it's not Tensor
+    if (!seg_block.raw_inputs()[i]->type()->isSubtypeOf(torch::jit::TensorType::get()) &&
+        !seg_block.raw_inputs()[i]->type()->isSubtypeOf(c10::ListType::ofTensors())) {
+      auto new_val = torch::jit::insertConstant(*mini_graph, ivalues_maps[seg_block.raw_inputs()[i]]);
+      seg_block.inputs()[i]->replaceAllUsesWith(new_val);
+      seg_block.eraseInput(i);
+    }
+  }
+
+  for (int i = seg_block.raw_outputs().size() - 1; i >= 0; --i) {
+    if (!seg_block.raw_outputs()[i]->type()->isSubtypeOf(torch::jit::TensorType::get()) &&
+        !seg_block.raw_outputs()[i]->type()->isSubtypeOf(c10::ListType::ofTensors())) {
+      seg_block.eraseOutput(i);
+    }
+  }
+
+  // not sure to delete this block or just fallback to pytorch
+  if (seg_block.raw_outputs().empty()) {
+    seg_block.update_target(SegmentedBlock::kTorch);
+  }
+}
+
 void construct_segments(
     std::vector& pytorch_nodes,
     std::vector& tensorrt_nodes,
@@ -240,6 +276,7 @@ std::vector segment_graph(
   // register every segment's input shape, and it's running output Ivalues
   for (auto& seg_block : segmented_blocks) {
     registerSegmentInOutIValues(seg_block, ivalues_maps);
+    eraseNonTensorInputsOutputs(seg_block, ivalues_maps);
   }
 
   return segmented_blocks;
diff --git a/core/partitioning/partitioning.h b/core/partitioning/partitioning.h
index 101addc242..4d4b5f07d7 100644
--- a/core/partitioning/partitioning.h
+++ b/core/partitioning/partitioning.h
@@ -66,10 +66,20 @@ struct SegmentedBlock {
     return g_->inputs();
   }
 
+  void eraseInput(size_t i) {
+    inputs_.erase(inputs_.begin() + i);
+    g_->eraseInput(i);
+  }
+
   c10::ArrayRef outputs() {
     return g_->outputs();
   }
 
+  void eraseOutput(size_t i) {
+    outputs_.erase(outputs_.begin() + i);
+    g_->eraseOutput(i);
+  }
+
   const std::vector& raw_inputs() const {
     return inputs_;
   }
@@ -102,6 +112,10 @@ struct SegmentedBlock {
     g_ = new_g;
   }
 
+  void update_target(SegmentedBlockTarget new_target) {
+    target_ = new_target;
+  }
+
  private:
   SegmentedBlockTarget target_;
   std::vector in_shape_;