From f26f6177e7b0866794809bb6ebb960c44842d051 Mon Sep 17 00:00:00 2001 From: 11happy Date: Sat, 1 Feb 2025 09:01:06 +0530 Subject: [PATCH 1/5] feat: implement complex type support for selectv2 Signed-off-by: 11happy --- .../tensorflow_common/src/op/select.cpp | 58 ++++++++++++------- .../tensorflow_tests/test_tf_SelectV2.py | 52 +++++++++++++++++ 2 files changed, 88 insertions(+), 22 deletions(-) diff --git a/src/frontends/tensorflow_common/src/op/select.cpp b/src/frontends/tensorflow_common/src/op/select.cpp index f19e01f5a021e6..6b5d79b30d9784 100644 --- a/src/frontends/tensorflow_common/src/op/select.cpp +++ b/src/frontends/tensorflow_common/src/op/select.cpp @@ -31,7 +31,19 @@ OutputVector translate_select_base_op(const NodeContext& node, set_node_name(node.get_name(), select); return {select}; } - +bool has_complex_inputs(Output& x, Output& y, element::Type& complex_part_type) { + auto complex_type_mark_x = as_type_ptr(x.get_node_shared_ptr()); + auto complex_type_mark_y = as_type_ptr(y.get_node_shared_ptr()); + if (complex_type_mark_x) { + x = complex_type_mark_x->input_value(0); + complex_part_type = complex_type_mark_x->get_complex_part_type(); + } + if (complex_type_mark_y) { + y = complex_type_mark_y->input_value(0); + complex_part_type = complex_type_mark_y->get_complex_part_type(); + } + return (complex_type_mark_x || complex_type_mark_y); +} OutputVector translate_select_v2_op(const NodeContext& node) { // according to the TensorFlow documentation. See in the code: // https://github.com/tensorflow/tensorflow/blob/v2.4.1/tensorflow/lite/kernels/select.cc#L188-L211 @@ -40,10 +52,25 @@ OutputVector translate_select_v2_op(const NodeContext& node) { // is true or the value of 'y' if false. There are valid condition input sizes: // 1. Either the same shape (in which case the select is elementwise), or // 2. Broadcastable shapes between 'condition', 'x' and 'y'. - default_op_checks(node, 3, {"SelectV2", "SELECT_V2"}); - // no preparation for inputs are needed - // inputs are already NumPy broadcastable - return translate_select_base_op(node, node.get_input(0), node.get_input(1), node.get_input(2)); + default_op_checks(node, 3, {"SelectV2", "SELECT_V2"}, true); + auto condition = node.get_input(0); + auto x = node.get_input(1); + auto y = node.get_input(2); + + element::Type complex_part_type; + auto is_complex = has_complex_inputs(x, y, complex_part_type); + + if (is_complex) { + auto cur_cond_shape = make_shared(condition, element::i32); + auto const_one = make_shared(element::i32, Shape{1}, 1); + auto new_cond_shape = make_shared(OutputVector{cur_cond_shape, const_one}, 0); + auto new_condition = make_shared(condition, new_cond_shape, false); + auto result = translate_select_base_op(node, new_condition, x, y); + auto complex_result = make_shared(result[0].get_node_shared_ptr(), complex_part_type); + return {complex_result->output(0)}; + } else { + return translate_select_base_op(node, condition, x, y); + } } OutputVector translate_select_op(const NodeContext& node) { @@ -59,21 +86,9 @@ OutputVector translate_select_op(const NodeContext& node) { auto condition = node.get_input(0); auto x = node.get_input(1); auto y = node.get_input(2); - auto complex_type_mark_x = as_type_ptr(x.get_node_shared_ptr()); - auto complex_type_mark_y = as_type_ptr(y.get_node_shared_ptr()); - auto is_complex = (complex_type_mark_x || complex_type_mark_y); element::Type complex_part_type; - - if (complex_type_mark_x) { - x = complex_type_mark_x->input_value(0); - complex_part_type = complex_type_mark_x->get_complex_part_type(); - } - - if (complex_type_mark_y) { - y = complex_type_mark_y->input_value(0); - complex_part_type = complex_type_mark_y->get_complex_part_type(); - } + auto is_complex = has_complex_inputs(x, y, complex_part_type); // compute number of dimensions to unsqueeze the condition auto cond_rank = compute_subgraph_scalar_rank(condition, element::i32); @@ -85,14 +100,13 @@ OutputVector translate_select_op(const NodeContext& node) { auto new_subshape = make_shared(const_one, num_new_axes); auto cond_shape = make_shared(condition, element::i32); // use extra dimensions in the begin to avoid concatenation of empty tensors that is not supported by Concat - auto const_1 = make_shared(element::i32, Shape{1}, 1); - auto new_cond_shape = make_shared(OutputVector{const_1, cond_shape, new_subshape}, 0); + auto new_cond_shape = make_shared(OutputVector{const_one, cond_shape, new_subshape}, 0); // prepare the condition to have the same rank as operands `x` and `y` auto prep_cond = make_shared(condition, new_cond_shape, false)->output(0); // squeeze prep_cond by one extra dimension specially added - auto const_0 = make_shared(element::i32, Shape{1}, 0); - prep_cond = make_shared(prep_cond, const_0); + auto const_zero = make_shared(element::i32, Shape{1}, 0); + prep_cond = make_shared(prep_cond, const_zero); auto result = translate_select_base_op(node, prep_cond, x, y); if (is_complex) { diff --git a/tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py b/tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py index 058f2e21a4a60b..c86ce282592850 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py @@ -51,3 +51,55 @@ def test_select_v2_basic(self, params, ie_device, precision, ir_version, temp_di self._test(*self.create_select_v2_net(**params), ie_device, precision, ir_version, temp_dir=temp_dir, use_legacy_frontend=use_legacy_frontend) + + +class TestComplexSelectV2(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + rng = np.random.default_rng() + assert 'cond:0' in inputs_info, "Test error: inputs_info must contain `cond`" + assert 'x_real:0' in inputs_info, "Test error: inputs_info must contain `x_real`" + assert 'x_imag:0' in inputs_info, "Test error: inputs_info must contain `x_imag`" + assert 'y_real:0' in inputs_info, "Test error: inputs_info must contain `y_real`" + assert 'y_imag:0' in inputs_info, "Test error: inputs_info must contain `y_imag`" + cond_shape = inputs_info['cond:0'] + inputs_data = {} + inputs_data['cond:0'] = np.random.randint(0, 2, cond_shape).astype(bool) + for part in ['x_real:0', 'x_imag:0', 'y_real:0', 'y_imag:0']: + inputs_data[part] = 4 * rng.random(inputs_info[part]).astype(np.float32) - 2 + return inputs_data + + def create_complex_select_v2_net(self, cond_shape, x_shape, y_shape): + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + cond = tf.compat.v1.placeholder(tf.bool, cond_shape, 'cond') + x_real = tf.compat.v1.placeholder(tf.float32, x_shape, 'x_real') + x_imag = tf.compat.v1.placeholder(tf.float32, x_shape, 'x_imag') + y_real = tf.compat.v1.placeholder(tf.float32, y_shape, 'y_real') + y_imag = tf.compat.v1.placeholder(tf.float32, y_shape, 'y_imag') + complex_x = tf.raw_ops.Complex(real=x_real, imag=x_imag) + complex_y = tf.raw_ops.Complex(real=y_real, imag=y_imag) + complex_select = tf.raw_ops.SelectV2(condition=cond, t=complex_x, e=complex_y) + tf.raw_ops.Real(input=complex_select) + tf.raw_ops.Imag(input=complex_select) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + return tf_net, None + + test_data_basic = [ + dict(cond_shape=[3, 1], x_shape=[3, 1], y_shape=[3, 1]), + dict(cond_shape=[], x_shape=[2], y_shape=[3, 2]), + dict(cond_shape=[4], x_shape=[3, 2, 1], y_shape=[2, 4]), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit + @pytest.mark.nightly + + def test_complex_select_v2(self, params, ie_device, precision, ir_version, temp_dir, + use_legacy_frontend): + if use_legacy_frontend: + pytest.skip("Select tests are not passing for the legacy frontend.") + self._test(*self.create_complex_select_v2_net(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_legacy_frontend=use_legacy_frontend) \ No newline at end of file From 2b1bfe7ed2965cb219c7d094d2cc5171018c23c7 Mon Sep 17 00:00:00 2001 From: Bhuminjay Soni Date: Sat, 1 Feb 2025 18:06:06 +0530 Subject: [PATCH 2/5] Update tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py Co-authored-by: Roman Kazantsev --- tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py b/tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py index c86ce282592850..d9d971282c4a65 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py @@ -95,7 +95,6 @@ def create_complex_select_v2_net(self, cond_shape, x_shape, y_shape): @pytest.mark.parametrize("params", test_data_basic) @pytest.mark.precommit @pytest.mark.nightly - def test_complex_select_v2(self, params, ie_device, precision, ir_version, temp_dir, use_legacy_frontend): if use_legacy_frontend: From 9b59f8a42dc7dcb44c99395c6a578cec85029c67 Mon Sep 17 00:00:00 2001 From: Bhuminjay Soni Date: Sat, 1 Feb 2025 18:06:18 +0530 Subject: [PATCH 3/5] Update tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py Co-authored-by: Roman Kazantsev --- tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py b/tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py index d9d971282c4a65..d199275bf34345 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py @@ -97,8 +97,6 @@ def create_complex_select_v2_net(self, cond_shape, x_shape, y_shape): @pytest.mark.nightly def test_complex_select_v2(self, params, ie_device, precision, ir_version, temp_dir, use_legacy_frontend): - if use_legacy_frontend: - pytest.skip("Select tests are not passing for the legacy frontend.") self._test(*self.create_complex_select_v2_net(**params), ie_device, precision, ir_version, temp_dir=temp_dir, use_legacy_frontend=use_legacy_frontend) \ No newline at end of file From faf332444006a9d9dcc57b3da02acce728d42074 Mon Sep 17 00:00:00 2001 From: 11happy Date: Sat, 1 Feb 2025 18:15:44 +0530 Subject: [PATCH 4/5] refactor: use unsqueeze Signed-off-by: 11happy --- src/frontends/tensorflow_common/src/op/select.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/frontends/tensorflow_common/src/op/select.cpp b/src/frontends/tensorflow_common/src/op/select.cpp index 6b5d79b30d9784..602e1259723a7c 100644 --- a/src/frontends/tensorflow_common/src/op/select.cpp +++ b/src/frontends/tensorflow_common/src/op/select.cpp @@ -61,10 +61,8 @@ OutputVector translate_select_v2_op(const NodeContext& node) { auto is_complex = has_complex_inputs(x, y, complex_part_type); if (is_complex) { - auto cur_cond_shape = make_shared(condition, element::i32); - auto const_one = make_shared(element::i32, Shape{1}, 1); - auto new_cond_shape = make_shared(OutputVector{cur_cond_shape, const_one}, 0); - auto new_condition = make_shared(condition, new_cond_shape, false); + auto const_negative_one = make_shared(element::i32, Shape{1}, -1); + auto new_condition = make_shared(condition, const_negative_one); auto result = translate_select_base_op(node, new_condition, x, y); auto complex_result = make_shared(result[0].get_node_shared_ptr(), complex_part_type); return {complex_result->output(0)}; From 7760eadd0c471868e0a334828e5a172e7ee3e646 Mon Sep 17 00:00:00 2001 From: 11happy Date: Sun, 2 Feb 2025 04:44:26 +0530 Subject: [PATCH 5/5] fix: include missing unsqueeze header Signed-off-by: 11happy --- src/frontends/tensorflow_common/src/op/select.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/frontends/tensorflow_common/src/op/select.cpp b/src/frontends/tensorflow_common/src/op/select.cpp index 602e1259723a7c..35c7e893e542e1 100644 --- a/src/frontends/tensorflow_common/src/op/select.cpp +++ b/src/frontends/tensorflow_common/src/op/select.cpp @@ -13,6 +13,7 @@ #include "openvino/op/shape_of.hpp" #include "openvino/op/squeeze.hpp" #include "openvino/op/subtract.hpp" +#include "openvino/op/unsqueeze.hpp" using namespace std; using namespace ov;