diff --git a/tensorflow_addons/image/dense_image_warp.py b/tensorflow_addons/image/dense_image_warp.py index a426a77503..d534d14c95 100644 --- a/tensorflow_addons/image/dense_image_warp.py +++ b/tensorflow_addons/image/dense_image_warp.py @@ -46,33 +46,54 @@ def interpolate_bilinear( ValueError: if the indexing mode is invalid, or if the shape of the inputs invalid. """ + return _interpolate_bilinear_with_checks(grid, query_points, indexing, name) + + +def _interpolate_bilinear_with_checks( + grid: types.TensorLike, + query_points: types.TensorLike, + indexing: str, + name: Optional[str], +) -> tf.Tensor: + """Perform checks on inputs without tf.function decorator to avoid flakiness.""" if indexing != "ij" and indexing != "xy": raise ValueError("Indexing mode must be 'ij' or 'xy'") + grid = tf.convert_to_tensor(grid) + query_points = tf.convert_to_tensor(query_points) + grid_shape = tf.shape(grid) + query_shape = tf.shape(query_points) + + with tf.control_dependencies( + [ + tf.debugging.assert_equal(tf.rank(grid), 4, "Grid must be 4D Tensor"), + tf.debugging.assert_greater_equal( + grid_shape[1], 2, "Grid height must be at least 2." + ), + tf.debugging.assert_greater_equal( + grid_shape[2], 2, "Grid width must be at least 2." + ), + tf.debugging.assert_equal( + tf.rank(query_points), 3, "Query points must be 3 dimensional." + ), + tf.debugging.assert_equal( + query_shape[2], 2, "Query points last dimension must be 2." + ), + ] + ): + return _interpolate_bilinear_impl(grid, query_points, indexing, name) + + +def _interpolate_bilinear_impl( + grid: types.TensorLike, + query_points: types.TensorLike, + indexing: str, + name: Optional[str], +) -> tf.Tensor: + """tf.function implementation of interpolate_bilinear.""" with tf.name_scope(name or "interpolate_bilinear"): - grid = tf.convert_to_tensor(grid) - query_points = tf.convert_to_tensor(query_points) - - # grid shape checks - grid_static_shape = grid.shape grid_shape = tf.shape(grid) - if grid_static_shape.dims is not None: - if len(grid_static_shape) != 4: - raise ValueError("Grid must be 4D Tensor") - if grid_static_shape[1] is not None and grid_static_shape[1] < 2: - raise ValueError("Grid height must be at least 2.") - if grid_static_shape[2] is not None and grid_static_shape[2] < 2: - raise ValueError("Grid width must be at least 2.") - - # query_points shape checks - query_static_shape = query_points.shape query_shape = tf.shape(query_points) - if query_static_shape.dims is not None: - if len(query_static_shape) != 3: - raise ValueError("Query points must be 3 dimensional.") - query_hw = query_static_shape[2] - if query_hw is not None and query_hw != 2: - raise ValueError("Query points last dimension must be 2.") batch_size, height, width, channels = ( grid_shape[0], diff --git a/tensorflow_addons/image/tests/dense_image_warp_test.py b/tensorflow_addons/image/tests/dense_image_warp_test.py index 0118929ceb..495532235f 100644 --- a/tensorflow_addons/image/tests/dense_image_warp_test.py +++ b/tensorflow_addons/image/tests/dense_image_warp_test.py @@ -239,7 +239,9 @@ def test_interpolation(): def test_size_exception(): """Make sure it throws an exception for images that are too small.""" shape = [1, 2, 1, 1] - with pytest.raises(ValueError, match="Grid width must be at least 2."): + with pytest.raises( + tf.errors.InvalidArgumentError, match="Grid width must be at least 2." + ): _check_interpolation_correctness(shape, "float32", "float32") @@ -250,11 +252,3 @@ def test_unknown_shapes(): shapes_to_try = [[3, 4, 5, 6], [1, 2, 2, 1]] for shape in shapes_to_try: _check_interpolation_correctness(shape, "float32", "float32", True) - - -@pytest.mark.usefixtures("only_run_functions_eagerly") -def test_symbolic_tensor_shape(): - image = tf.keras.layers.Input(shape=(7, 7, 192)) - flow = tf.ones((1, 7, 7, 2)) - interp = dense_image_warp(image, flow) - np.testing.assert_array_equal(interp.shape.as_list(), [None, 7, 7, 192]) diff --git a/tools/testing/source_code_test.py b/tools/testing/source_code_test.py index c54bf73ea2..4d309a90d7 100644 --- a/tools/testing/source_code_test.py +++ b/tools/testing/source_code_test.py @@ -178,6 +178,7 @@ def test_no_tf_control_dependencies(): allowlist = [ "tensorflow_addons/layers/wrappers.py", "tensorflow_addons/image/utils.py", + "tensorflow_addons/image/dense_image_warp.py", "tensorflow_addons/optimizers/average_wrapper.py", "tensorflow_addons/optimizers/yogi.py", "tensorflow_addons/optimizers/lookahead.py",