From 58eeab732d308635c01020e8aa643c76221785bb Mon Sep 17 00:00:00 2001 From: vandanavk Date: Fri, 21 Sep 2018 16:24:24 -0700 Subject: [PATCH] ONNX export: Square and sum operators --- .../contrib/onnx/mx2onnx/_op_translations.py | 74 +++++++++++++++++++ .../onnx/export/onnx_backend_test.py | 2 + 2 files changed, 76 insertions(+) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 8eb2fda10f17..d42518f5b2dc 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2110,3 +2110,77 @@ def convert_sqrt(node, **kwargs): name=name, ) return [node] + +@mx_op.register("square") +def convert_square(node, **kwargs): + """Map MXNet's square operator attributes to onnx's Pow operator + and return the created node. + """ + onnx = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_a_id = kwargs["index_lookup"][inputs[0][0]] + input_node_a = proc_nodes[input_node_a_id].name + + initializer = kwargs["initializer"] + data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')] + + power2_name = "square_tensor" + str(kwargs["idx"]) + tensor_node = onnx.helper.make_tensor_value_info(power2_name, data_type, (1,)) + initializer.append( + onnx.helper.make_tensor( + name=power2_name, + data_type=data_type, + dims=(1,), + vals=[2], + raw=False, + ) + ) + + node = onnx.helper.make_node( + "Pow", + [input_node_a, power2_name], + [name], + name=None + ) + return [tensor_node, node] + +@mx_op.register("sum") +def convert_sum(node, **kwargs): + """Map MXNet's sum operator attributes to onnx's ReduceSum operator + and return the created node. + """ + onnx = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + attrs = node["attrs"] + + mx_axis = attrs.get("axis", None) + axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None + + keepdims = get_boolean_attribute_value(attrs, "keepdims") + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + + if axes: + node = onnx.helper.make_node( + 'ReduceSum', + inputs=[input_node], + outputs=[name], + axes=axes, + keepdims=keepdims, + name=name + ) + else: + node = onnx.helper.make_node( + 'ReduceSum', + inputs=[input_node], + outputs=[name], + keepdims=keepdims, + name=name + ) + return [node] diff --git a/tests/python-pytest/onnx/export/onnx_backend_test.py b/tests/python-pytest/onnx/export/onnx_backend_test.py index 2216a8f407f7..f5ebabae5f36 100644 --- a/tests/python-pytest/onnx/export/onnx_backend_test.py +++ b/tests/python-pytest/onnx/export/onnx_backend_test.py @@ -64,6 +64,8 @@ 'test_reduce_max', 'test_reduce_mean', 'test_reduce_prod', + 'test_reduce_sum_d', + 'test_reduce_sum_keepdims_random', 'test_squeeze', 'test_softmax_example', 'test_softmax_large_number',