Skip to content

Remove tf function where unecessary #823

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tensorflow_addons/activations/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#### Standard API
In order to conform with the current API standard, all activations
must:
* Be a `tf.function`.
* Be a `tf.function` unless it is a straightforward call to a custom op or likely to be retraced.
* Register as a keras global object so it can be serialized properly: `@tf.keras.utils.register_keras_serializable(package='Addons')`
* Add the addon to the `py_library` in this sub-package's BUILD file.

Expand Down
1 change: 0 additions & 1 deletion tensorflow_addons/activations/gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@


@tf.keras.utils.register_keras_serializable(package='Addons')
@tf.function
def gelu(x, approximate=True):
"""Gaussian Error Linear Unit.

Expand Down
8 changes: 0 additions & 8 deletions tensorflow_addons/activations/gelu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,6 @@ def test_theoretical_gradients(self, dtype):
self.assertAllCloseAccordingToType(
theoretical, numerical, atol=1e-4)

def test_unknown_shape(self):
fn = gelu.get_concrete_function(
tf.TensorSpec(shape=None, dtype=tf.float32))

for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
x = tf.ones(shape=shape, dtype=tf.float32)
self.assertAllClose(fn(x), gelu(x))


if __name__ == "__main__":
tf.test.main()
1 change: 0 additions & 1 deletion tensorflow_addons/activations/hardshrink.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@


@tf.keras.utils.register_keras_serializable(package='Addons')
@tf.function
def hardshrink(x, lower=-0.5, upper=0.5):
"""Hard shrink function.

Expand Down
8 changes: 0 additions & 8 deletions tensorflow_addons/activations/hardshrink_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,6 @@ def test_theoretical_gradients(self, dtype):
theoretical, numerical = tf.test.compute_gradient(hardshrink, [x])
self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)

def test_unknown_shape(self):
fn = hardshrink.get_concrete_function(
tf.TensorSpec(shape=None, dtype=tf.float32))

for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
x = tf.ones(shape=shape, dtype=tf.float32)
self.assertAllClose(fn(x), hardshrink(x))


if __name__ == "__main__":
tf.test.main()
1 change: 0 additions & 1 deletion tensorflow_addons/activations/lisht.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@


@tf.keras.utils.register_keras_serializable(package='Addons')
@tf.function
def lisht(x):
"""LiSHT: Non-Parameteric Linearly Scaled Hyperbolic Tangent Activation Function.

Expand Down
8 changes: 0 additions & 8 deletions tensorflow_addons/activations/lisht_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,6 @@ def test_theoretical_gradients(self, dtype):
self.assertAllCloseAccordingToType(
theoretical, numerical, rtol=5e-4, atol=5e-4)

def test_unknown_shape(self):
fn = lisht.get_concrete_function(
tf.TensorSpec(shape=None, dtype=tf.float32))

for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
x = tf.ones(shape=shape, dtype=tf.float32)
self.assertAllClose(fn(x), lisht(x))


if __name__ == "__main__":
tf.test.main()
1 change: 0 additions & 1 deletion tensorflow_addons/activations/mish.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@


@tf.keras.utils.register_keras_serializable(package='Addons')
@tf.function
def mish(x):
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function.

Expand Down
8 changes: 0 additions & 8 deletions tensorflow_addons/activations/mish_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,6 @@ def test_theoretical_gradients(self, dtype):
theoretical, numerical = tf.test.compute_gradient(mish, [x])
self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)

def test_unknown_shape(self):
fn = mish.get_concrete_function(
tf.TensorSpec(shape=None, dtype=tf.float32))

for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
x = tf.ones(shape=shape, dtype=tf.float32)
self.assertAllClose(fn(x), mish(x))


if __name__ == "__main__":
tf.test.main()
1 change: 0 additions & 1 deletion tensorflow_addons/activations/softshrink.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@


@tf.keras.utils.register_keras_serializable(package='Addons')
@tf.function
def softshrink(x, lower=-0.5, upper=0.5):
"""Soft shrink function.

Expand Down
8 changes: 0 additions & 8 deletions tensorflow_addons/activations/softshrink_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,6 @@ def test_theoretical_gradients(self, dtype):
theoretical, numerical = tf.test.compute_gradient(softshrink, [x])
self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)

def test_unknown_shape(self):
fn = softshrink.get_concrete_function(
tf.TensorSpec(shape=None, dtype=tf.float32))

for shape in [(1,), (1, 2), (1, 2, 3), (1, 2, 3, 4)]:
x = tf.ones(shape=shape, dtype=tf.float32)
self.assertAllClose(fn(x), softshrink(x))


if __name__ == "__main__":
tf.test.main()
1 change: 0 additions & 1 deletion tensorflow_addons/activations/tanhshrink.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@


@tf.keras.utils.register_keras_serializable(package='Addons')
@tf.function
def tanhshrink(x):
"""Applies the element-wise function: x - tanh(x)

Expand Down
1 change: 0 additions & 1 deletion tensorflow_addons/image/distance_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
tf.no_gradient("Addons>EuclideanDistanceTransform")


@tf.function
def euclidean_dist_transform(images, dtype=tf.float32, name=None):
"""Applies euclidean distance transform(s) to the image(s).

Expand Down
8 changes: 0 additions & 8 deletions tensorflow_addons/image/distance_transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def test_batch_binary_images(self):
0, 1, 1, 1, 0,
0, 0, 0, 0, 0
] * batch_size)
# yapf: enable
images = tf.constant([image] * batch_size, dtype=tf.uint8)
for output_dtype in [tf.float16, tf.float32, tf.float64]:
output = dist_ops.euclidean_dist_transform(
Expand Down Expand Up @@ -121,13 +120,6 @@ def test_all_ones(self):
expected_output = np.full([10, 10, 1], tf.float32.max)
self.assertAllClose(output, expected_output)

def test_unknown_shape(self):
fn = dist_ops.euclidean_dist_transform.get_concrete_function(
tf.TensorSpec(None, tf.uint8))
for shape in [[5, 10], [10, 7, 1], [4, 10, 10, 1]]:
image = tf.zeros(shape, dtype=tf.uint8)
self.assertAllClose(image, fn(image))


if __name__ == "__main__":
tf.test.main()
2 changes: 0 additions & 2 deletions tensorflow_addons/image/distort_image_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@


# pylint: disable=invalid-name
@tf.function
def random_hsv_in_yiq(image,
max_delta_hue=0,
lower_saturation=1,
Expand Down Expand Up @@ -106,7 +105,6 @@ def random_hsv_in_yiq(image,
image, delta_hue, scale_saturation, scale_value, name=scope)


@tf.function
def adjust_hsv_in_yiq(image,
delta_hue=0,
scale_saturation=1,
Expand Down
76 changes: 54 additions & 22 deletions tensorflow_addons/image/distort_image_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,22 +101,33 @@ def test_adjust_random_hue_in_yiq(self):
self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4)

def test_invalid_rank(self):
msg = "Shape must be at least rank 3 but is rank 2"
x_np = np.random.rand(2, 3) * 255.
delta_h = np.random.rand() * 2.0 - 1.0
with self.assertRaisesRegex(ValueError, msg):
self.evaluate(self._adjust_hue_in_yiq_tf(x_np, delta_h))
if tf.executing_eagerly():
msg = "input must be at least 3-D"
with self.assertRaisesRegex(tf.errors.InvalidArgumentError, msg):
self.evaluate(self._adjust_hue_in_yiq_tf(x_np, delta_h))
else:
msg = "Shape must be at least rank 3 but is rank 2"
with self.assertRaisesRegex(ValueError, msg):
self.evaluate(self._adjust_hue_in_yiq_tf(x_np, delta_h))

def test_invalid_channels(self):
msg = "Dimension must be 3 but is 4"
x_np = np.random.rand(4, 2, 4) * 255.
delta_h = np.random.rand() * 2.0 - 1.0
with self.assertRaisesRegex(ValueError, msg):
self.evaluate(self._adjust_hue_in_yiq_tf(x_np, delta_h))
if tf.executing_eagerly():
msg = "input must have 3 channels but instead has 4"
with self.assertRaisesRegex(tf.errors.InvalidArgumentError, msg):
self.evaluate(self._adjust_hue_in_yiq_tf(x_np, delta_h))
else:
msg = "Dimension must be 3 but is 4"
with self.assertRaisesRegex(ValueError, msg):
self.evaluate(self._adjust_hue_in_yiq_tf(x_np, delta_h))

def test_adjust_hsv_in_yiq_unknown_shape(self):
fn = distort_image_ops.adjust_hsv_in_yiq.get_concrete_function(
tf.TensorSpec(shape=None, dtype=tf.float64))
fn = tf.function(
distort_image_ops.adjust_hsv_in_yiq).get_concrete_function(
tf.TensorSpec(shape=None, dtype=tf.float64))
for shape in (2, 3, 3), (4, 2, 3, 3):
image_np = np.random.rand(*shape) * 255.
image_tf = tf.constant(image_np)
Expand All @@ -127,8 +138,9 @@ def test_adjust_hsv_in_yiq_unknown_shape(self):
atol=1e-4)

def test_random_hsv_in_yiq_unknown_shape(self):
fn = distort_image_ops.random_hsv_in_yiq.get_concrete_function(
tf.TensorSpec(shape=None, dtype=tf.float32))
fn = tf.function(
distort_image_ops.random_hsv_in_yiq).get_concrete_function(
tf.TensorSpec(shape=None, dtype=tf.float32))
for shape in (2, 3, 3), (4, 2, 3, 3):
image_tf = tf.ones(shape)
self.assertAllEqual(fn(image_tf), fn(image_tf))
Expand Down Expand Up @@ -182,18 +194,28 @@ def test_adjust_random_value_in_yiq(self):
self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4)

def test_invalid_rank(self):
msg = "Shape must be at least rank 3 but is rank 2"
x_np = np.random.rand(2, 3) * 255.
scale = np.random.rand() * 2.0 - 1.0
with self.assertRaisesRegex(ValueError, msg):
self.evaluate(self._adjust_value_in_yiq_tf(x_np, scale))
if tf.executing_eagerly():
msg = "input must be at least 3-D"
with self.assertRaisesRegex(tf.errors.InvalidArgumentError, msg):
self.evaluate(self._adjust_value_in_yiq_tf(x_np, scale))
else:
msg = "Shape must be at least rank 3 but is rank 2"
with self.assertRaisesRegex(ValueError, msg):
self.evaluate(self._adjust_value_in_yiq_tf(x_np, scale))

def test_invalid_channels(self):
msg = "Dimension must be 3 but is 4"
x_np = np.random.rand(4, 2, 4) * 255.
scale = np.random.rand() * 2.0 - 1.0
with self.assertRaisesRegex(ValueError, msg):
self.evaluate(self._adjust_value_in_yiq_tf(x_np, scale))
if tf.executing_eagerly():
msg = "input must have 3 channels but instead has 4"
with self.assertRaisesRegex(tf.errors.InvalidArgumentError, msg):
self.evaluate(self._adjust_value_in_yiq_tf(x_np, scale))
else:
msg = "Dimension must be 3 but is 4"
with self.assertRaisesRegex(ValueError, msg):
self.evaluate(self._adjust_value_in_yiq_tf(x_np, scale))


@test_utils.run_all_in_graph_and_eager_modes
Expand Down Expand Up @@ -248,18 +270,28 @@ def test_adjust_random_saturation_in_yiq(self):
self.assertAllClose(y_tf, y_baseline, rtol=2e-4, atol=1e-4)

def test_invalid_rank(self):
msg = "Shape must be at least rank 3 but is rank 2"
x_np = np.random.rand(2, 3) * 255.
scale = np.random.rand() * 2.0 - 1.0
with self.assertRaisesRegex(ValueError, msg):
self.evaluate(self._adjust_saturation_in_yiq_tf(x_np, scale))
if tf.executing_eagerly():
msg = "input must be at least 3-D"
with self.assertRaisesRegex(tf.errors.InvalidArgumentError, msg):
self.evaluate(self._adjust_saturation_in_yiq_tf(x_np, scale))
else:
msg = "Shape must be at least rank 3 but is rank 2"
with self.assertRaisesRegex(ValueError, msg):
self.evaluate(self._adjust_saturation_in_yiq_tf(x_np, scale))

def test_invalid_channels(self):
msg = "Dimension must be 3 but is 4"
x_np = np.random.rand(4, 2, 4) * 255.
scale = np.random.rand() * 2.0 - 1.0
with self.assertRaisesRegex(ValueError, msg):
self.evaluate(self._adjust_saturation_in_yiq_tf(x_np, scale))
if tf.executing_eagerly():
msg = "input must have 3 channels but instead has 4 "
with self.assertRaisesRegex(tf.errors.InvalidArgumentError, msg):
self.evaluate(self._adjust_saturation_in_yiq_tf(x_np, scale))
else:
msg = "Dimension must be 3 but is 4"
with self.assertRaisesRegex(ValueError, msg):
self.evaluate(self._adjust_saturation_in_yiq_tf(x_np, scale))


# TODO: get rid of sessions
Expand Down
5 changes: 0 additions & 5 deletions tensorflow_addons/image/transform_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ def transform(images,
return img_utils.from_4D_image(output, original_ndims)


@tf.function
def compose_transforms(transforms, name=None):
"""Composes the transforms tensors.

Expand All @@ -131,7 +130,6 @@ def compose_transforms(transforms, name=None):
return matrices_to_flat_transforms(composed)


@tf.function
def flat_transforms_to_matrices(transforms, name=None):
"""Converts projective transforms to affine matrices.

Expand Down Expand Up @@ -165,7 +163,6 @@ def flat_transforms_to_matrices(transforms, name=None):
tf.constant([-1, 3, 3]))


@tf.function
def matrices_to_flat_transforms(transform_matrices, name=None):
"""Converts affine matrices to projective transforms.

Expand Down Expand Up @@ -199,7 +196,6 @@ def matrices_to_flat_transforms(transform_matrices, name=None):
return transforms[:, :8]


@tf.function
def angles_to_projective_transforms(angles,
image_height,
image_width,
Expand Down Expand Up @@ -282,7 +278,6 @@ def _image_projective_transform_grad(op, grad):
return [output, None, None]


@tf.function
def rotate(images, angles, interpolation="NEAREST", name=None):
"""Rotate image(s) counterclockwise by the passed angle(s) in radians.

Expand Down
4 changes: 2 additions & 2 deletions tensorflow_addons/image/transform_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_transform_static_output_shape(self):
self.assertAllEqual([3, 5], result.shape)

def test_transform_unknown_shape(self):
fn = transform_ops.transform.get_concrete_function(
fn = tf.function(transform_ops.transform).get_concrete_function(
tf.TensorSpec(shape=None, dtype=tf.float32),
[1, 0, 0, 0, 1, 0, 0, 0])
for shape in (2, 4), (2, 4, 3), (1, 2, 4, 3):
Expand Down Expand Up @@ -276,7 +276,7 @@ def test_rotate_static_shape(self):
self.assertEqual(image.get_shape(), result.get_shape())

def test_unknown_shape(self):
fn = transform_ops.rotate.get_concrete_function(
fn = tf.function(transform_ops.rotate).get_concrete_function(
tf.TensorSpec(shape=None, dtype=tf.float32), 0)
for shape in (2, 4), (2, 4, 3), (1, 2, 4, 3):
image = tf.ones(shape=shape)
Expand Down
1 change: 0 additions & 1 deletion tensorflow_addons/image/translate_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from tensorflow_addons.image.transform_ops import transform


@tf.function
def translations_to_projective_transforms(translations, name=None):
"""Returns projective transform(s) for the given translation(s).

Expand Down
2 changes: 0 additions & 2 deletions tensorflow_addons/image/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def get_ndims(image):
return image.get_shape().ndims or tf.rank(image)


@tf.function
def to_4D_image(image):
"""Convert 2/3/4D image to 4D image.

Expand Down Expand Up @@ -69,7 +68,6 @@ def _dynamic_to_4D_image(image):
return tf.reshape(image, new_shape)


@tf.function
def from_4D_image(image, ndims):
"""Convert back to an image with `ndims` rank.

Expand Down
Loading