Skip to content

Commit fc81e28

Browse files
committed
Generate Assertion Ops for interpolate_bilinear
1 parent a4113f6 commit fc81e28

File tree

2 files changed

+17
-28
lines changed

2 files changed

+17
-28
lines changed

tensorflow_addons/image/dense_image_warp.py

+14-19
Original file line numberDiff line numberDiff line change
@@ -52,27 +52,22 @@ def interpolate_bilinear(
5252
with tf.name_scope(name or "interpolate_bilinear"):
5353
grid = tf.convert_to_tensor(grid)
5454
query_points = tf.convert_to_tensor(query_points)
55-
56-
# grid shape checks
57-
grid_static_shape = grid.shape
5855
grid_shape = tf.shape(grid)
59-
if grid_static_shape.dims is not None:
60-
if len(grid_static_shape) != 4:
61-
raise ValueError("Grid must be 4D Tensor")
62-
if grid_static_shape[1] is not None and grid_static_shape[1] < 2:
63-
raise ValueError("Grid height must be at least 2.")
64-
if grid_static_shape[2] is not None and grid_static_shape[2] < 2:
65-
raise ValueError("Grid width must be at least 2.")
66-
67-
# query_points shape checks
68-
query_static_shape = query_points.shape
6956
query_shape = tf.shape(query_points)
70-
if query_static_shape.dims is not None:
71-
if len(query_static_shape) != 3:
72-
raise ValueError("Query points must be 3 dimensional.")
73-
query_hw = query_static_shape[2]
74-
if query_hw is not None and query_hw != 2:
75-
raise ValueError("Query points last dimension must be 2.")
57+
58+
tf.Assert(tf.equal(tf.rank(grid), 4), ["Grid must be 4D Tensor"])
59+
tf.Assert(
60+
tf.greater_equal(grid_shape[1], 2), ["Grid height must be at least 2."]
61+
)
62+
tf.Assert(
63+
tf.greater_equal(grid_shape[2], 2), ["Grid width must be at least 2."]
64+
)
65+
tf.Assert(
66+
tf.equal(tf.rank(query_points), 3), ["Query points must be 3 dimensional."]
67+
)
68+
tf.Assert(
69+
tf.equal(query_shape[2], 2), ["Query points last dimension must be 2."]
70+
)
7671

7772
batch_size, height, width, channels = (
7873
grid_shape[0],

tensorflow_addons/image/tests/dense_image_warp_test.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,9 @@ def test_interpolation():
239239
def test_size_exception():
240240
"""Make sure it throws an exception for images that are too small."""
241241
shape = [1, 2, 1, 1]
242-
with pytest.raises(ValueError, match="Grid width must be at least 2."):
242+
with pytest.raises(
243+
tf.errors.InvalidArgumentError, match="Grid width must be at least 2."
244+
):
243245
_check_interpolation_correctness(shape, "float32", "float32")
244246

245247

@@ -250,11 +252,3 @@ def test_unknown_shapes():
250252
shapes_to_try = [[3, 4, 5, 6], [1, 2, 2, 1]]
251253
for shape in shapes_to_try:
252254
_check_interpolation_correctness(shape, "float32", "float32", True)
253-
254-
255-
@pytest.mark.usefixtures("only_run_functions_eagerly")
256-
def test_symbolic_tensor_shape():
257-
image = tf.keras.layers.Input(shape=(7, 7, 192))
258-
flow = tf.ones((1, 7, 7, 2))
259-
interp = dense_image_warp(image, flow)
260-
np.testing.assert_array_equal(interp.shape.as_list(), [None, 7, 7, 192])

0 commit comments

Comments
 (0)