@@ -64,28 +64,20 @@ def _interpolate_bilinear_with_checks(
64
64
grid_shape = tf .shape (grid )
65
65
query_shape = tf .shape (query_points )
66
66
67
- with tf .control_dependencies (
68
- [
69
- tf .Assert (tf .equal (tf .rank (grid ), 4 ), ["Grid must be 4D Tensor" ]),
70
- tf .Assert (
71
- tf .greater_equal (grid_shape [1 ], 2 ), ["Grid height must be at least 2." ]
72
- ),
73
- tf .Assert (
74
- tf .greater_equal (grid_shape [2 ], 2 ), ["Grid width must be at least 2." ]
75
- ),
76
- tf .Assert (
77
- tf .equal (tf .rank (query_points ), 3 ),
78
- ["Query points must be 3 dimensional." ],
79
- ),
80
- tf .Assert (
81
- tf .equal (query_shape [2 ], 2 ), ["Query points last dimension must be 2." ]
82
- ),
83
- ]
84
- ):
67
+ with tf .control_dependencies ([
68
+ tf .debugging .assert_equal (tf .rank (grid ), 4 , "Grid must be 4D Tensor" ),
69
+ tf .debugging .assert_greater_equal (grid_shape [1 ], 2 ,
70
+ "Grid height must be at least 2." ),
71
+ tf .debugging .assert_greater_equal (grid_shape [2 ], 2 ,
72
+ "Grid width must be at least 2." ),
73
+ tf .debugging .assert_equal (
74
+ tf .rank (query_points ), 3 , "Query points must be 3 dimensional." ),
75
+ tf .debugging .assert_equal (query_shape [2 ], 2 ,
76
+ "Query points last dimension must be 2." )
77
+ ]):
85
78
return _interpolate_bilinear_impl (grid , query_points , indexing , name )
86
79
87
80
88
- @tf .function
89
81
def _interpolate_bilinear_impl (
90
82
grid : types .TensorLike ,
91
83
query_points : types .TensorLike ,
0 commit comments