@@ -52,27 +52,22 @@ def interpolate_bilinear(
52
52
with tf .name_scope (name or "interpolate_bilinear" ):
53
53
grid = tf .convert_to_tensor (grid )
54
54
query_points = tf .convert_to_tensor (query_points )
55
-
56
- # grid shape checks
57
- grid_static_shape = grid .shape
58
55
grid_shape = tf .shape (grid )
59
- if grid_static_shape .dims is not None :
60
- if len (grid_static_shape ) != 4 :
61
- raise ValueError ("Grid must be 4D Tensor" )
62
- if grid_static_shape [1 ] is not None and grid_static_shape [1 ] < 2 :
63
- raise ValueError ("Grid height must be at least 2." )
64
- if grid_static_shape [2 ] is not None and grid_static_shape [2 ] < 2 :
65
- raise ValueError ("Grid width must be at least 2." )
66
-
67
- # query_points shape checks
68
- query_static_shape = query_points .shape
69
56
query_shape = tf .shape (query_points )
70
- if query_static_shape .dims is not None :
71
- if len (query_static_shape ) != 3 :
72
- raise ValueError ("Query points must be 3 dimensional." )
73
- query_hw = query_static_shape [2 ]
74
- if query_hw is not None and query_hw != 2 :
75
- raise ValueError ("Query points last dimension must be 2." )
57
+
58
+ tf .Assert (tf .equal (tf .rank (grid ), 4 ), ["Grid must be 4D Tensor" ])
59
+ tf .Assert (
60
+ tf .greater_equal (grid_shape [1 ], 2 ), ["Grid height must be at least 2." ]
61
+ )
62
+ tf .Assert (
63
+ tf .greater_equal (grid_shape [2 ], 2 ), ["Grid width must be at least 2." ]
64
+ )
65
+ tf .Assert (
66
+ tf .equal (tf .rank (query_points ), 3 ), ["Query points must be 3 dimensional." ]
67
+ )
68
+ tf .Assert (
69
+ tf .equal (query_shape [2 ], 2 ), ["Query points last dimension must be 2." ]
70
+ )
76
71
77
72
batch_size , height , width , channels = (
78
73
grid_shape [0 ],
0 commit comments