From 17a5c677963dc3ecb7ff505585ed15eadaaf74ef Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Wed, 4 Mar 2020 13:14:35 -0800
Subject: [PATCH] Add support to dump unsupported ops. Add
 lite_interpreter_load test. (#34072)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34072

This diff checks for all the ops that are not supported by the lite_interpreter.
It is mainly helpful for finding every missing op up front, instead of
discovering and adding them one by one.

Test Plan: buck run caffe2/binaries:lite_interpreter_model_load -- --model=<model_file>

Reviewed By: iseeyuan

Differential Revision: D20194092

fbshipit-source-id: 0d596cd0204308027194af7ed738551d0c32a374
---
 binaries/lite_interpreter_model_load.cc | 30 +++++++++++++++++++++++++
 torch/csrc/jit/mobile/function.cpp      |  7 ++++--
 torch/csrc/jit/mobile/function.h        |  2 +-
 torch/csrc/jit/mobile/import.cpp        | 19 +++++++++++++++-
 4 files changed, 54 insertions(+), 4 deletions(-)
 create mode 100644 binaries/lite_interpreter_model_load.cc

diff --git a/binaries/lite_interpreter_model_load.cc b/binaries/lite_interpreter_model_load.cc
new file mode 100644
index 00000000000..5467d4dc939
--- /dev/null
+++ b/binaries/lite_interpreter_model_load.cc
@@ -0,0 +1,30 @@
+#include "ATen/ATen.h"
+#include <c10/util/Flags.h>
+#include <iostream>
+#include <string>
+#include <torch/csrc/jit/mobile/import.h>
+#include <torch/csrc/jit/mobile/module.h>
+#include "torch/script.h"
+
+C10_DEFINE_string(model, "", "The given bytecode model to check if it is supported by lite_interpreter.");
+
+int main(int argc, char** argv) {
+  c10::SetUsageMessage(
+      "Check if the exported bytecode model is runnable by lite_interpreter.\n"
+      "Example usage:\n"
+      "./lite_interpreter_model_load"
+      " --model=<model_file>");
+
+  if (!c10::ParseCommandLineFlags(&argc, &argv)) {
+    std::cerr << "Failed to parse command line flags!" << std::endl;
+    return 1;
+  }
+
+  if (FLAGS_model.empty()) {
+    std::cerr << "Model file is not provided.\n";
+    return -1;
+  }
+
+  torch::jit::mobile::Module bc = torch::jit::_load_for_mobile(FLAGS_model);
+  return 0;
+}
diff --git a/torch/csrc/jit/mobile/function.cpp b/torch/csrc/jit/mobile/function.cpp
index f1b879d36db..c44fe90808d 100644
--- a/torch/csrc/jit/mobile/function.cpp
+++ b/torch/csrc/jit/mobile/function.cpp
@@ -18,7 +18,7 @@ void Function::append_instruction(OpCode op, int X, int N) {
   code_->instructions_.emplace_back(op, X, N);
 }
 
-void Function::append_operator(const std::string& name,
+bool Function::append_operator(const std::string& name,
     const std::string& overload_name) {
   // Keep the original opname in code_
   code_->op_names_.emplace_back(name, overload_name);
@@ -29,13 +29,16 @@ void Function::append_operator(const std::string& name,
     opname.name = "_" + opname.name;
   }
   auto op = c10::Dispatcher::singleton().findSchema(opname);
-  TORCH_CHECK(op.has_value(), opname.name, ".", opname.overload_name, " cannot be found.");
+  if (!op.has_value()) {
+    return false;
+  }
   // TODO: operator.h now does not depend on Node* so we can also look up operators from
   // that registry for use in mobile as a way to share implementations.
   auto fn = [op](Stack& stack) {
     c10::Dispatcher::singleton().callBoxed(*op, &stack);
   };
   code_->operators_.emplace_back(fn);
+  return true;
 }
 
 void Function::append_constant(const c10::IValue& constant) {
diff --git a/torch/csrc/jit/mobile/function.h b/torch/csrc/jit/mobile/function.h
index 1c34f27e1f9..decb9d49798 100644
--- a/torch/csrc/jit/mobile/function.h
+++ b/torch/csrc/jit/mobile/function.h
@@ -18,7 +18,7 @@ class Function{
   const std::string& name() const;
   const c10::QualifiedName& qualname() const;
   void append_instruction(OpCode op, int X, int N);
-  void append_operator(const std::string& name,
+  bool append_operator(const std::string& name,
       const std::string& overload_name);
   void append_constant(const c10::IValue& constant);
   void append_type(const c10::TypePtr& type);
diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp
index bc99e345c2c..797717195de 100644
--- a/torch/csrc/jit/mobile/import.cpp
+++ b/torch/csrc/jit/mobile/import.cpp
@@ -47,6 +47,15 @@ IValue expect_field(IValue tup, const std::string& expected_name, size_t entry){
   return row->elements().at(1);
 }
 
+void print_unsupported_ops_and_throw(const std::unordered_set<std::string>& unsupported_ops) {
+  std::string error_message("{");
+  for (const auto& op_name : unsupported_ops) {
+    error_message += op_name + ", ";
+  }
+  error_message += "}";
+  TORCH_CHECK(false, "Following ops cannot be found:", error_message);
+}
+
 void parseMethods(const std::vector<IValue>& vals, mobile::CompilationUnit& mcu) {
   for (const auto& element : vals) {
     const auto& m_tuple = element.toTuple()->elements();
@@ -72,14 +81,22 @@ void parseMethods(const std::vector<IValue>& vals, mobile::CompilationUnit& mcu)
       function->append_instruction(op_code, X, N);
     }
 
+    std::unordered_set<std::string> unsupported_op_names;
     for (const auto& op : ops_list) {
       auto op_item = op.toTuple()->elements();
       TORCH_CHECK(op_item.size() == 2, "There should be two parts in an operator name.");
-      function->append_operator(op_item[0].toString()->string(),
+      auto op_found = function->append_operator(op_item[0].toString()->string(),
           op_item[1].toString()->string());
+      if (!op_found) {
+        unsupported_op_names.emplace(op_item[0].toString()->string() + "." + op_item[1].toString()->string());
+      }
     }
 
+    if (!unsupported_op_names.empty()) {
+      print_unsupported_ops_and_throw(unsupported_op_names);
+    }
+
     for (const auto& constant : consts_list) {
       function->append_constant(constant);
     }
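
Note: with this change, loading a model that uses several unsupported operators
fails with one error listing all of them, instead of aborting on the first
missing op as the old TORCH_CHECK did. Below is a minimal sketch of how a
caller might surface that aggregated error; the path "model.ptl" is a
placeholder, and the message text comes from print_unsupported_ops_and_throw
above.

    #include <torch/csrc/jit/mobile/import.h>
    #include <c10/util/Exception.h>
    #include <iostream>

    int main() {
      try {
        // Loading registers every operator in the bytecode; if any are
        // missing, parseMethods throws after collecting the full set.
        auto module = torch::jit::_load_for_mobile("model.ptl"); // placeholder path
      } catch (const c10::Error& e) {
        // e.what() contains "Following ops cannot be found:{op1, op2, }"
        std::cerr << e.what() << std::endl;
        return 1;
      }
      return 0;
    }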