Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/ops/opset3.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ declared in `namespace opset3`.
* [ROIAlign](detection/ROIAlign_3.md)
* [ROIPooling](detection/ROIPooling_1.md)
* [ScatterElementsUpdate](movement/ScatterElementsUpdate_3.md)
* [ScatterNDUpdate](movement/ScatterNDUpdate_3.md)
* [ScatterUpdate](movement/ScatterUpdate_3.md)
* [Select](condition/Select_1.md)
* [Selu](arithmetic/Selu_1.md)
Expand Down
4 changes: 2 additions & 2 deletions model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ extensions/front/tf/sparse_fill_empty_rows_ext.py
extensions/front/tf/sparse_segment_mean_ext.py
extensions/front/tf/sparse_segment_sqrtn_ext.py
extensions/front/tf/sparse_segment_sum_ext.py
extensions/front/tf/sparse_to_dense_ext.py
extensions/front/tf/sparse_to_dense_replacer.py
extensions/front/tf/split_ext.py
extensions/front/tf/ssd_support.json
extensions/front/tf/ssd_support_api_v1.14.json
Expand Down Expand Up @@ -654,6 +654,7 @@ extensions/ops/RNNCell.py
extensions/ops/roialign.py
extensions/ops/roifeatureextractor_onnx.py
extensions/ops/scatter.py
extensions/ops/scatternd.py
extensions/ops/select.py
extensions/ops/shufflechannel.py
extensions/ops/simplernms.py
Expand All @@ -665,7 +666,6 @@ extensions/ops/sparse_reshape.py
extensions/ops/sparse_segment_mean.py
extensions/ops/sparse_segment_sqrtn.py
extensions/ops/sparse_segment_sum.py
extensions/ops/sparse_to_dense.py
extensions/ops/spatial_transformer.py
extensions/ops/splice.py
extensions/ops/split.py
Expand Down
28 changes: 0 additions & 28 deletions model-optimizer/extensions/front/tf/sparse_to_dense_ext.py

This file was deleted.

66 changes: 66 additions & 0 deletions model-optimizer/extensions/front/tf/sparse_to_dense_replacer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""
Copyright (C) 2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy as np

from extensions.ops.Cast import Cast
from extensions.ops.scatternd import ScatterNDUpdate
from mo.front.common.replacement import FrontReplacementOp
from mo.graph.graph import Node, Graph, rename_nodes
from mo.ops.broadcast import Broadcast
from mo.ops.const import Const


class SparseToDenseReplacer(FrontReplacementOp):
"""
This replacer substitutes TensorFlow SparseToDense operation with Broadcast -> ScatterND chain.
The Broadcast operation creates a tensor filled with default value and of required shape.
The ScatterND operation updates the created tensor with required values at required locations.
"""
op = "SparseToDense"
enabled = True

def run_after(self):
from extensions.front.tf.CTCGreedyDecoder import CTCGreedyDecoderReplacement
return [CTCGreedyDecoderReplacement]

def replace_op(self, graph: Graph, node: Node):
node_name = node.soft_get('name', node.id)

# broadcast default value to required shape
broadcast_node = Broadcast(graph, {'name': node_name + '/Broadcast_'}).create_node()
node.in_port(1).get_connection().set_destination(broadcast_node.in_port(1))
if not node.in_port(3).disconnected():
# TODO: remove casting once we start to support I64 model input
# cast default value to I32 due limitation about I64 input support
# so that input parameter and default value will be of the same I32 type as required ScatterNDUpdate
cast_default_value = Cast(graph, {'name': node_name + '/CastDefaultValue', 'dst_type': np.int32}).create_node()
node.in_port(3).get_connection().set_destination(cast_default_value.in_port(0))
broadcast_node.in_port(0).connect(cast_default_value.out_port(0))
else:
broadcast_node.in_port(0).connect(Const(graph, {'name': broadcast_node.name + '/FillValue_',
'value': np.float32(0)}
).create_node().out_port(0))

# update broadcasted tensor with required values at required locations
scatternd_node = ScatterNDUpdate(graph, {'name': node_name + '/ScatterNDUpdate_'}).create_node()
scatternd_node.in_port(0).connect(broadcast_node.out_port(0))
node.in_port(0).get_connection().set_destination(scatternd_node.in_port(1))
node.in_port(2).get_connection().set_destination(scatternd_node.in_port(2))

rename_nodes([(node, node_name + "/AbandonedName"), (scatternd_node, node_name)])

return [scatternd_node.id]
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""
Copyright (C) 2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

from extensions.front.tf.sparse_to_dense_replacer import SparseToDenseReplacer
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph
from mo.utils.unittest.graph import build_graph, const


class SparseToDenseFrontReplacersTest(unittest.TestCase):
def test1(self):
nodes_attributes = {
'input_indices': {'shape': int64_array([5, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_values' : {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},

'sparse_to_dense' : {'kind': 'op', 'op': 'SparseToDense'},
'broadcast' : {'kind': 'op', 'op': 'Broadcast'},
'scatternd' : {'kind': 'op', 'op': 'ScatterNDUpdate'},
'cast_default_value': {'kind': 'op', 'op': 'Cast'},

'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},

**const('input_dense_shape', int64_array([50, 40])),
**const('input_default_value', int64_array(0))}

graph = build_graph(nodes_attributes,
[('input_indices', 'sparse_to_dense', {'out': 0, 'in': 0}),
('input_dense_shape', 'sparse_to_dense', {'out': 0, 'in': 1}),
('input_values', 'sparse_to_dense', {'out': 0, 'in': 2}),
('input_default_value', 'sparse_to_dense', {'out': 0, 'in': 3}),
('sparse_to_dense', 'last', {'out': 0, 'in': 0})],
nodes_with_edges_only=True)
graph.stage = 'front'
SparseToDenseReplacer().find_and_replace_pattern(graph)

graph_ref = build_graph(nodes_attributes,
[('input_default_value', 'cast_default_value', {'in': 0}),
('cast_default_value', 'broadcast', {'in': 0}),
('input_dense_shape', 'broadcast', {'in': 1}),
('broadcast', 'scatternd', {'in': 0}),
('input_indices', 'scatternd', {'in': 1}),
('input_values', 'scatternd', {'in': 2}),
('scatternd', 'last', {'in': 0})],
nodes_with_edges_only=True)

(flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
self.assertTrue(flag, resp)
103 changes: 103 additions & 0 deletions model-optimizer/extensions/ops/scatternd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""
Copyright (C) 2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy as np

from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node, Graph
from mo.ops.op import Op


class ScatterNDBase(Op):
enabled = False

op = op_type = None
version = None

def __init__(self, graph: Graph, attrs: dict):
assert self.op is not None and self.op_type is not None and self.version is not None, \
'Please use specialized ScatterNDBase operation class, ScatterNDBase is base class'

mandatory_props = {
'op': self.op,
'type': self.op_type,
'version': self.version,

'infer': self.infer,

'in_ports_count': 3,
'out_ports_count': 1,
}
super().__init__(graph, mandatory_props, attrs)

@staticmethod
def infer(node: Node):
node_name = node.soft_get('name', node.id)

input_shape = node.in_port(0).data.get_shape()
indices_shape = node.in_port(1).data.get_shape()
updates_shape = node.in_port(2).data.get_shape()
assert input_shape is not None and updates_shape is not None and indices_shape is not None, \
'The node "{}" input shape is None'.format(node_name)

# check that shapes are correct
# 1. ranks of both input and indices must be at least 1
assert len(input_shape) >= 1 and len(indices_shape) >= 1, \
'The node "{}" input and indices ranks must be at least 1'.format(node_name)

# 2. the last dimension of indices shape must be at most a rank of input
assert indices_shape[-1] <= len(input_shape), \
'The last dimension of indices shape must be at most a rank of input for the node "{}"'.format(node_name)

# 3. updates is a tensor of shape indices_shape[:-1] + input_shape[indices_shape[-1]:]
expected_updates_shape = np.concatenate((indices_shape[:-1], input_shape[indices_shape[-1]:]), axis=0)
assert np.array_equal(updates_shape, expected_updates_shape), \
'The updates shape must be equal to indices_shape[:-1] + input_shape[indices_shape[-1]:] for the node "{}"'.format(node_name)

node.out_port(0).data.set_shape(input_shape)

@staticmethod
def type_infer(node: Node):
assert node.in_port(0).get_source().get_data_type() == node.in_port(2).get_source().get_data_type(), \
'The data type of the first and the third inputs must be equal for the node {}'.format(node.name)
node.out_port(0).set_data_type(node.in_port(0).get_data_type())


class ScatterNDUpdate(ScatterNDBase):
op = op_type = 'ScatterNDUpdate'
version = 'opset4'

@staticmethod
def infer(node: Node):
ScatterNDBase.infer(node)

input_value = node.in_port(0).data.get_value()
indices_shape = node.in_port(1).data.get_shape()
indices_value = node.in_port(1).data.get_value()
updates_value = node.in_port(2).data.get_value()

# compute output value if all inputs are constant
if input_value is not None and indices_value is not None and updates_value is not None:
output_value = input_value.copy()
indx_range = int64_array(indices_shape[:-1])
for indx in np.ndindex(tuple(indx_range)):
if indx == ():
# a case when updates is a scalar
indx = 0
updates_value = [updates_value]
output_value[indices_value[indx]] = updates_value[indx]

node.out_port(0).data.set_value(output_value)
Loading