Skip to content

Commit

Permalink
Add support to dump unsupported ops. Add lite_interpreter_load test. (#…
Browse files Browse the repository at this point in the history
…34072)

Summary:
Pull Request resolved: pytorch/pytorch#34072

This diff checks for all the ops not supported by lite_interpreter.
It is mainly helpful for finding, in one pass, every op that still needs to be
added, instead of discovering and adding them one by one.

Test Plan:
buck run caffe2/binaries:lite_interpreter_model_load --
--model=<bytecode-model-path>

Reviewed By: iseeyuan

Differential Revision: D20194092

fbshipit-source-id: 0d596cd0204308027194af7ed738551d0c32a374
  • Loading branch information
kimishpatel authored and facebook-github-bot committed Mar 4, 2020
1 parent 385067e commit 17a5c67
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 4 deletions.
30 changes: 30 additions & 0 deletions binaries/lite_interpreter_model_load.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include "ATen/ATen.h"
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/serialization/import.h>
#include "torch/script.h"

C10_DEFINE_string(model, "", "The given bytecode model to check if it is supported by lite_interpreter.");

// Entry point: loads a bytecode model with the lite interpreter loader.
// Loading will surface (and throw on) any operators the lite interpreter
// does not support, which is the whole purpose of this binary.
// Exit code: 0 on success, 1 on bad/missing arguments.
int main(int argc, char** argv) {
  c10::SetUsageMessage(
    "Check if exported bytecode model is runnable by lite_interpreter.\n"
    "Example usage:\n"
    "./lite_interpreter_model_load"
    " --model=<model_file>");

  if (!c10::ParseCommandLineFlags(&argc, &argv)) {
    std::cerr << "Failed to parse command line flags!" << std::endl;
    return 1;
  }

  if (FLAGS_model.empty()) {
    // The flag is empty here, so printing its value (as the original code
    // did) added nothing; print a clear message plus the usage text instead.
    std::cerr << "Model file is not provided.\n"
              << c10::UsageMessage() << std::endl;
    return 1;  // was -1; keep a consistent conventional non-zero exit code
  }

  // _load_for_mobile throws if the model contains ops the lite
  // interpreter cannot resolve, listing all of them at once.
  torch::jit::mobile::Module bc = torch::jit::_load_for_mobile(FLAGS_model);
  return 0;
}
7 changes: 5 additions & 2 deletions torch/csrc/jit/mobile/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void Function::append_instruction(OpCode op, int X, int N) {
code_->instructions_.emplace_back(op, X, N);
}

void Function::append_operator(const std::string& name,
bool Function::append_operator(const std::string& name,
const std::string& overload_name) {
// Keep the original opname in code_
code_->op_names_.emplace_back(name, overload_name);
Expand All @@ -29,13 +29,16 @@ void Function::append_operator(const std::string& name,
opname.name = "_" + opname.name;
}
auto op = c10::Dispatcher::singleton().findSchema(opname);
TORCH_CHECK(op.has_value(), opname.name, ".", opname.overload_name, " cannot be found.");
if (not op.has_value()) {
return false;
}
// TODO: operator.h now does not depend on Node* so we can also look up operators from
// that registry for use in mobile as a way to share implementations.
auto fn = [op](Stack& stack) {
c10::Dispatcher::singleton().callBoxed(*op, &stack);
};
code_->operators_.emplace_back(fn);
return true;
}

void Function::append_constant(const c10::IValue& constant) {
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/mobile/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class Function{
const std::string& name() const;
const c10::QualifiedName& qualname() const;
void append_instruction(OpCode op, int X, int N);
void append_operator(const std::string& name,
bool append_operator(const std::string& name,
const std::string& overload_name);
void append_constant(const c10::IValue& constant);
void append_type(const c10::TypePtr& type);
Expand Down
19 changes: 18 additions & 1 deletion torch/csrc/jit/mobile/import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ IValue expect_field(IValue tup, const std::string& expected_name, size_t entry){
return row->elements().at(1);
}

// Formats all unsupported operator names as a single "{op1, op2}" list and
// raises an error via TORCH_CHECK, so the user sees every missing op at once
// instead of discovering them one by one.
//
// Fix: the original unconditionally appended ", " after each name, producing
// a dangling separator before the closing brace, e.g. "{aten::foo, }".
void print_unsupported_ops_and_throw(const std::unordered_set<std::string>& unsupported_ops) {
  std::string error_message("{");
  bool first = true;
  for (const auto& op_name : unsupported_ops) {
    if (!first) {
      error_message += ", ";
    }
    error_message += op_name;
    first = false;
  }
  error_message += "}";
  TORCH_CHECK(false, "Following ops cannot be found:", error_message);
}

void parseMethods(const std::vector<IValue>& vals, mobile::CompilationUnit& mcu) {
for (const auto& element : vals) {
const auto& m_tuple = element.toTuple()->elements();
Expand All @@ -72,14 +81,22 @@ void parseMethods(const std::vector<IValue>& vals, mobile::CompilationUnit& mcu)
function->append_instruction(op_code, X, N);
}

std::unordered_set<std::string> unsupported_op_names;
for (const auto& op : ops_list) {
auto op_item = op.toTuple()->elements();
TORCH_CHECK(op_item.size() == 2,
"There should be two parts in an operator name.");
function->append_operator(op_item[0].toString()->string(),
auto op_found = function->append_operator(op_item[0].toString()->string(),
op_item[1].toString()->string());
if (not op_found) {
unsupported_op_names.emplace(op_item[0].toString()->string() + "." + op_item[1].toString()->string());
}
}

if (not unsupported_op_names.empty()) {
print_unsupported_ops_and_throw(unsupported_op_names);
};

for (const auto& constant : consts_list) {
function->append_constant(constant);
}
Expand Down

0 comments on commit 17a5c67

Please # to comment.