Skip to content

Commit 3361a9c

Browse files
committed
Update dense_image_warp.py
1 parent 483938d commit 3361a9c

File tree

1 file changed

+11
-19
lines changed

1 file changed

+11
-19
lines changed

tensorflow_addons/image/dense_image_warp.py

+11-19
Original file line numberDiff line numberDiff line change
@@ -64,28 +64,20 @@ def _interpolate_bilinear_with_checks(
6464
grid_shape = tf.shape(grid)
6565
query_shape = tf.shape(query_points)
6666

67-
with tf.control_dependencies(
68-
[
69-
tf.Assert(tf.equal(tf.rank(grid), 4), ["Grid must be 4D Tensor"]),
70-
tf.Assert(
71-
tf.greater_equal(grid_shape[1], 2), ["Grid height must be at least 2."]
72-
),
73-
tf.Assert(
74-
tf.greater_equal(grid_shape[2], 2), ["Grid width must be at least 2."]
75-
),
76-
tf.Assert(
77-
tf.equal(tf.rank(query_points), 3),
78-
["Query points must be 3 dimensional."],
79-
),
80-
tf.Assert(
81-
tf.equal(query_shape[2], 2), ["Query points last dimension must be 2."]
82-
),
83-
]
84-
):
67+
with tf.control_dependencies([
68+
tf.debugging.assert_equal(tf.rank(grid), 4, "Grid must be 4D Tensor"),
69+
tf.debugging.assert_greater_equal(grid_shape[1], 2,
70+
"Grid height must be at least 2."),
71+
tf.debugging.assert_greater_equal(grid_shape[2], 2,
72+
"Grid width must be at least 2."),
73+
tf.debugging.assert_equal(
74+
tf.rank(query_points), 3, "Query points must be 3 dimensional."),
75+
tf.debugging.assert_equal(query_shape[2], 2,
76+
"Query points last dimension must be 2.")
77+
]):
8578
return _interpolate_bilinear_impl(grid, query_points, indexing, name)
8679

8780

88-
@tf.function
8981
def _interpolate_bilinear_impl(
9082
grid: types.TensorLike,
9183
query_points: types.TensorLike,

0 commit comments

Comments
 (0)