
Commit

fix quantize_graph pass error when there're multiple outputs from a single node (#13000)

* fix quantize_graph pass error when there are multiple outputs from a single node that need 'contrib_quantize', 'min' and 'max' nodes inserted for these outputs.

* fix lint

* Make the single output align with multiple outputs when inserting contrib_quantize

* Change op comparison from its name string to the op itself

* skip unsupported quantize_concat

* retrigger ci
ciyongch authored and TaoLv committed Nov 30, 2018
1 parent c72a38b commit 07a4319
Showing 3 changed files with 160 additions and 111 deletions.
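
The heart of the fix is a naming change: when a floating-point node feeds quantized consumers, the inserted _contrib_quantize node and its companion min/max nodes are now named after the specific output entry rather than only after the producing node, so two outputs of the same node no longer collide, and the names line up with the entries recorded during calibration. The short sketch below is plain Python, not the MXNet API; the "sliced" node and its output names are invented to illustrate the suffix rule the pass now applies.

# Illustrative sketch of the new naming rule (not MXNet code).
# The pass derives a suffix from the producing op's registered output names
# (FListOutputNames) when available, falls back to the raw output index
# otherwise, and leaves variables (weights/bias) without a suffix.
def quantize_names(node_name, output_names, index, is_variable=False):
    if is_variable:
        suffix = ""                         # weights/bias keep their plain name
    elif output_names is not None:
        suffix = "_" + output_names[index]  # e.g. "_output0", "_output1"
    else:
        suffix = "_" + str(index)
    base = node_name + suffix
    return base + "_quantize", base + "_min", base + "_max"

# A hypothetical two-output node "sliced": each output now gets its own
# quantize/min/max triple instead of both mapping onto "sliced_quantize".
print(quantize_names("sliced", ["output0", "output1"], 0))
print(quantize_names("sliced", ["output0", "output1"], 1))
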
6 changes: 2 additions & 4 deletions python/mxnet/contrib/quantization.py
@@ -71,11 +71,11 @@ def _quantize_params(qsym, params, th_dict):
         elif name in params:
             quantized_params[name] = params[name]
         elif name.endswith(('_min')):
-            output = name[: - len('_min')] + "_output"
+            output = name[: - len('_min')]
             if output in th_dict:
                 quantized_params[name] = ndarray.array([th_dict[output][0]])
         elif name.endswith(('_max')):
-            output = name[: - len('_min')] + "_output"
+            output = name[: - len('_min')]
             if output in th_dict:
                 quantized_params[name] = ndarray.array([th_dict[output][1]])
     return quantized_params
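
Because the quantize/min/max nodes are now named directly after the calibrated entry (for example "sliced_output0"), _quantize_params can recover the threshold key by simply stripping the "_min" or "_max" suffix; appending "_output" as before would no longer match multi-output entries. A rough, standalone Python illustration of that lookup (the th_dict keys and names here are invented for the example; this is not the MXNet function itself):

# th_dict maps a calibrated entry name to its (min, max) thresholds.
th_dict = {"sliced_output0": (-1.0, 1.0), "sliced_output1": (-2.0, 2.0)}

def threshold_param(name, th_dict):
    # e.g. name == "sliced_output1_min" -> key "sliced_output1" -> -2.0
    if name.endswith("_min"):
        key = name[:-len("_min")]
        return th_dict[key][0] if key in th_dict else None
    if name.endswith("_max"):
        key = name[:-len("_max")]
        return th_dict[key][1] if key in th_dict else None
    return None

print(threshold_param("sliced_output1_min", th_dict))  # -2.0
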
@@ -513,8 +513,6 @@ def quantize_model(sym, arg_params, aux_params,
         if not isinstance(calib_data, DataIter):
             raise ValueError('calib_data must be of DataIter type when calib_mode=%s,'
                              ' while received type %s' % (calib_mode, str(type(calib_data))))
-        if calib_layer is None:
-            calib_layer = lambda name: name.endswith('_output')
 
         mod = Module(symbol=sym, data_names=data_names, label_names=label_names, context=ctx)
         if len(calib_data.provide_label) > 0:
95 changes: 57 additions & 38 deletions src/operator/quantization/quantize_graph_pass.cc
@@ -68,8 +68,7 @@ std::vector<NodeEntry> OfflineParams(std::vector<NodeEntry>&& outputs,
   std::unordered_map<Node*, NodePtr> mirror_map;
   nnvm::NodeEntryMap<NodePtr> entry_var;
   auto need_offline = [&](NodePtr n) {
-    return n->op() &&
-           (n->op()->name == "_contrib_quantize") &&
+    return (n->op() == Op::Get("_contrib_quantize")) &&
            n->inputs[0].node->is_variable() &&
            offline_params.count(n->inputs[0].node->attrs.name);
   };
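
Several hunks in this file replace string comparisons of the form n->op()->name == "..." with a pointer comparison against the registered op singleton, n->op() == Op::Get("..."). Besides avoiding a string compare, this also covers variable nodes (where op() is nullptr) without a separate null check, since nullptr can never equal a registered op. A small Python analogue of the idea, using a stand-in registry rather than nnvm's:

# Stand-in op registry: get() always returns the same object per name,
# mirroring how Op::Get returns a pointer to a registered singleton.
class OpRegistry:
    def __init__(self):
        self._ops = {}

    def get(self, name):
        return self._ops.setdefault(name, object())

registry = OpRegistry()
quantize_op = registry.get("_contrib_quantize")

def is_quantize(node_op):
    # node_op may be None for variable nodes; identity comparison handles
    # that case without an explicit "is not None" check.
    return node_op is registry.get("_contrib_quantize")

print(is_quantize(quantize_op))          # True
print(is_quantize(None))                 # False (variable node)
print(is_quantize(registry.get("min")))  # False
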
@@ -117,9 +116,10 @@ inline bool NeedQuantize(NodePtr node, const std::unordered_set<std::string>& ex
 }
 
 Graph QuantizeGraph(Graph &&src) {
-  static auto& quantized_op_map = Op::GetAttr<mxnet::FQuantizedOp>("FQuantizedOp");
-  static auto& need_requantize_map = Op::GetAttr<mxnet::FNeedRequantize>("FNeedRequantize");
-  static auto& avoid_quantize_input_map =
+  static const auto& flist_outputs = nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
+  static const auto& quantized_op_map = Op::GetAttr<mxnet::FQuantizedOp>("FQuantizedOp");
+  static const auto& need_requantize_map = Op::GetAttr<mxnet::FNeedRequantize>("FNeedRequantize");
+  static const auto& avoid_quantize_input_map =
       Op::GetAttr<mxnet::FAvoidQuantizeInput>("FAvoidQuantizeInput");
   auto offline_params = src.GetAttr<std::unordered_set<std::string>>("offline_params");
   auto excluded_nodes = src.GetAttr<std::unordered_set<std::string>>("excluded_nodes");
@@ -130,6 +130,7 @@ Graph QuantizeGraph(Graph &&src) {
   // graph. Key is the currently visited graph's node pointer, and value is a copied node of the key
   // node. The existing key's value may be updated with the newly created quantize/dequantize op.
   std::unordered_map<Node*, NodePtr> mirror_map;
+  nnvm::NodeEntryMap<NodeEntry> mirror_entry_map;
   DFSVisit(src.outputs, [&](const NodePtr& node) {
     NodePtr new_node = Node::Create();
     // If the currently visited node needs quantization, insert a quantize op node before the
@@ -154,30 +155,46 @@ Graph QuantizeGraph(Graph &&src) {
         if (avoid_quantize_input_map.count(node->op()) &&
             avoid_quantize_input_map[node->op()](node->attrs, i)) {
           new_node->inputs.emplace_back(mirror_entry);
-        } else if (!NeedQuantize(e.node, excluded_nodes) &&
-                   (mirror_node->op() == nullptr ||
-                    mirror_node->op()->name != "_contrib_quantize")) {
-          NodePtr quantize_node = InsertNode("_contrib_quantize",
-            e.node->attrs.name + "_quantize", new_node, mirror_entry);
-          quantize_node->attrs.dict["out_type"] = quantized_dtype;
-          quantize_node->op()->attr_parser(&(quantize_node->attrs));
-          if (calib_quantize) {
-            NodePtr min_var = CreateNode("nullptr", e.node->attrs.name + "_min");
-            quantize_node->inputs.emplace_back(NodeEntry{min_var, 0, 0});
-            NodePtr max_var = CreateNode("nullptr", e.node->attrs.name + "_max");
-            quantize_node->inputs.emplace_back(NodeEntry{max_var, 0, 0});
-          } else {
-            NodePtr min_node = InsertNode("min",
-              e.node->attrs.name + "_min", quantize_node, mirror_entry);
-            min_node->op()->attr_parser(&(min_node->attrs));
-
-            NodePtr max_node = InsertNode("max",
-              e.node->attrs.name + "_max", quantize_node, mirror_entry);
-            max_node->op()->attr_parser(&(max_node->attrs));
-          }
-          mirror_map[e.node.get()] = std::move(quantize_node);
-        } else if (mirror_node->op() != nullptr
-                   && mirror_node->op()->name == "_contrib_dequantize") {
+        } else if (!NeedQuantize(e.node, excluded_nodes)) {
+          if (mirror_entry_map.count(e)) {
+            new_node->inputs.emplace_back(mirror_entry_map[e]);
+          } else {
+            // When there're multiple entrys outgoing from a single node, need to add entry
+            // index (or output name) into quantize/min/max node to distinguish them.
+            // Or the output name is not ending with 'output', just put the output name here
+            // to better align with calibration phase. No need to change name to weights/bias.
+            std::string suffix = "";
+            if (mirror_node->op() != nullptr) {
+              auto list_output_names_func = flist_outputs.get(e.node->op(), nullptr);
+              if (list_output_names_func != nullptr) {
+                std::vector<std::string> names = list_output_names_func(e.node->attrs);
+                suffix = "_" + names[e.index];
+              } else {
+                suffix = "_" + std::to_string(e.index);
+              }
+            }
+
+            NodePtr quantize_node = InsertNode("_contrib_quantize",
+              e.node->attrs.name + suffix + "_quantize", new_node, mirror_entry);
+            quantize_node->attrs.dict["out_type"] = quantized_dtype;
+            quantize_node->op()->attr_parser(&(quantize_node->attrs));
+            if (calib_quantize) {
+              NodePtr min_var = CreateNode("nullptr", e.node->attrs.name + suffix + "_min");
+              quantize_node->inputs.emplace_back(NodeEntry{min_var, 0, 0});
+              NodePtr max_var = CreateNode("nullptr", e.node->attrs.name + suffix + "_max");
+              quantize_node->inputs.emplace_back(NodeEntry{max_var, 0, 0});
+            } else {
+              NodePtr min_node = InsertNode("min",
+                e.node->attrs.name + suffix + "_min", quantize_node, mirror_entry);
+              min_node->op()->attr_parser(&(min_node->attrs));
+
+              NodePtr max_node = InsertNode("max",
+                e.node->attrs.name + suffix + "_max", quantize_node, mirror_entry);
+              max_node->op()->attr_parser(&(max_node->attrs));
+            }
+            mirror_entry_map[e] = NodeEntry{quantize_node, 0, e.version};
+          }
+        } else if (mirror_node->op() == Op::Get("_contrib_dequantize")) {
           new_node->inputs.emplace_back(NodeEntry{mirror_node->inputs[0].node, e.index, e.version});
         } else {
           // If the entry e's node needs quantization, or mirror_entry is from a quantize op,
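
Two things happen in the hunk above. First, the suffix is computed from FListOutputNames (or, failing that, the raw entry index) before the quantize/min/max nodes are created, which is what gives each output of a multi-output node its own distinct names. Second, the freshly created quantize node is cached in mirror_entry_map keyed by the exact entry (source node plus output index), so any later consumer of the same output reuses the existing quantize node instead of inserting a duplicate. A rough Python sketch of that caching pattern (the names and the insert_fn callback are illustrative, not the nnvm API):

# Cache inserted "quantize" nodes per (producer, output_index) entry so that
# every consumer of the same FP32 output shares one quantize node.
mirror_entry_map = {}

def get_or_insert_quantize(entry, insert_fn):
    # entry is a (node_name, output_index) pair standing in for nnvm::NodeEntry
    if entry in mirror_entry_map:
        return mirror_entry_map[entry]   # reuse, no duplicate insertion
    quantize_node = insert_fn(entry)     # actually create and wire the node
    mirror_entry_map[entry] = quantize_node
    return quantize_node

inserted = []
make = lambda e: inserted.append(e) or "%s_output%d_quantize" % e

a = get_or_insert_quantize(("sliced", 0), make)
b = get_or_insert_quantize(("sliced", 0), make)  # same entry -> cached
c = get_or_insert_quantize(("sliced", 1), make)  # other output -> new node
print(a == b, len(inserted))  # True 2
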
@@ -192,8 +209,7 @@ Graph QuantizeGraph(Graph &&src) {
       for (size_t i = 0; i < node->inputs.size(); ++i) {
         const auto& e = node->inputs[i];
         NodePtr mirror_node = mirror_map.at(e.node.get());
-        if (mirror_node->op() != nullptr
-            && mirror_node->op()->name == "_contrib_dequantize") {
+        if (mirror_node->op() == Op::Get("_contrib_dequantize")) {
           mirror_node = mirror_node->inputs[0].node;
         }
         NodeEntry mirror_entry = NodeEntry{
@@ -215,12 +231,17 @@ Graph QuantizeGraph(Graph &&src) {
           min_index = num_outputs + 2 * e.index;
           max_index = num_outputs + 2 * e.index + 1;
         } else {
-          CHECK(mirror_node->op() != nullptr &&
-                mirror_node->op()->name == "_contrib_quantize")
+          CHECK(mirror_entry_map.count(e))
              << "The input is not quantize or quantized_op";
         }
-        new_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0});
-        new_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0});
+        if (mirror_entry_map.count(e)) {
+          auto quantize_entry = mirror_entry_map[e];
+          new_node->inputs.emplace_back(NodeEntry{quantize_entry.node, min_index, 0});
+          new_node->inputs.emplace_back(NodeEntry{quantize_entry.node, max_index, 0});
+        } else {
+          new_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0});
+          new_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0});
+        }
       }
 
       // If the new_node op registered attr FNeedRequantize, insert requantize node after it.
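
The min/max indices in the hunk above follow the output layout that quantized operators expose: the first num_outputs entries are data outputs, followed by one (min, max) pair per data output, hence min_index = num_outputs + 2 * i and max_index = num_outputs + 2 * i + 1. For an entry that was routed through a plain _contrib_quantize node, mirror_entry_map points at its data output, and the min/max come from outputs 1 and 2 of that node. A tiny index-only sanity check of the assumed layout:

# Output layout assumed here for a quantized op with N data outputs:
#   [out_0, ..., out_{N-1}, min_0, max_0, min_1, max_1, ...]
def min_max_index(num_outputs, i):
    return num_outputs + 2 * i, num_outputs + 2 * i + 1

print(min_max_index(1, 0))  # (1, 2): single-output quantized op, or quantize itself
print(min_max_index(2, 1))  # (4, 5): second output of a two-output quantized op
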
@@ -261,8 +282,7 @@ Graph QuantizeGraph(Graph &&src) {
           mirror_node, e.index, e.version};
         // if input node is quantized operator, add dequantize node
         if (NeedQuantize(e.node, excluded_nodes) &&
-            (mirror_node->op() == nullptr ||
-             mirror_node->op()->name != "_contrib_dequantize")) {
+            (mirror_node->op() != Op::Get("_contrib_dequantize"))) {
           // here we calculate the output number (exclude min/max, in order to
           // calculate min/max index from mirror node) based on assumption that
           // there is only 1min and 1max output from mirror node (which is
@@ -279,10 +299,9 @@ Graph QuantizeGraph(Graph &&src) {
 
           new_node->inputs.emplace_back(NodeEntry{dequantize_node, 0, 0});
           mirror_map[e.node.get()] = std::move(dequantize_node);
-        } else if (mirror_node->op() != nullptr
-                   && mirror_node->op()->name == "_contrib_quantize") {
+        } else if (mirror_entry_map.count(e)) {
           new_node->inputs.emplace_back(
-            NodeEntry{mirror_node->inputs[0].node, e.index, e.version});
+            NodeEntry{mirror_entry_map[e].node->inputs[0].node, e.index, e.version});
         } else {
           new_node->inputs.emplace_back(
             NodeEntry{mirror_node, e.index, e.version});
@@ -333,7 +352,7 @@ Graph SetCalibTableToQuantizedGraph(Graph&& g) {
     // If the current op is requantize
     // find the thresholds from the calibration table with the key equal
     // to the current op's input node name, e.g. a quantized_conv2d node.
-    if (node->op() != nullptr && node->op()->name == "_contrib_requantize") {
+    if (node->op() == Op::Get("_contrib_requantize")) {
       NodePtr quantized_op_node = node->inputs[0].node;
       CHECK(quantized_op_node->op() != nullptr) << quantized_op_node->attrs.name
                                                 << " must be an quantized op node";

