Skip to content

Commit

Permalink
Exposed the pure python implementation of mish and softshrink. (#1252)
Browse files Browse the repository at this point in the history
* Exposed the pure python implementation of mish.

* Exposed the pure implementation of softshrink publicly.

* Fix name.

* The same error is not raise in python and in the custom op.
  • Loading branch information
gabrieldemarmiesse authored Mar 9, 2020
1 parent 9808b3e commit b1079c2
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 2 deletions.
12 changes: 12 additions & 0 deletions tensorflow_addons/activations/mish.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from tensorflow_addons.utils import types
from tensorflow_addons.utils.resource_loader import LazySO
from tensorflow_addons import options

_activation_so = LazySO("custom_ops/activations/_activation_ops.so")

Expand All @@ -36,6 +37,17 @@ def mish(x: types.TensorLike) -> tf.Tensor:
A `Tensor`. Has the same type as `x`.
"""
x = tf.convert_to_tensor(x)

if not options.TF_ADDONS_PY_OPS:
try:
return _mish_custom_op(x)
except tf.errors.NotFoundError:
options.warn_fallback("mish")

return _mish_custom_op(x)


def _mish_custom_op(x):
return _activation_so.ops.addons_mish(x)


Expand Down
12 changes: 12 additions & 0 deletions tensorflow_addons/activations/softshrink.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from tensorflow_addons.utils import types
from tensorflow_addons.utils.resource_loader import LazySO
from tensorflow_addons import options

_activation_so = LazySO("custom_ops/activations/_activation_ops.so")

Expand All @@ -40,6 +41,17 @@ def softshrink(
A `Tensor`. Has the same type as `x`.
"""
x = tf.convert_to_tensor(x)

if not options.TF_ADDONS_PY_OPS:
try:
return _softshrink_custom_op(x, lower, upper)
except tf.errors.NotFoundError:
options.warn_fallback("softshrink")

return _softshrink_py(x, lower, upper)


def _softshrink_custom_op(x, lower, upper):
return _activation_so.ops.addons_softshrink(x, lower, upper)


Expand Down
7 changes: 5 additions & 2 deletions tensorflow_addons/activations/softshrink_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@
import numpy as np
import tensorflow as tf
from tensorflow_addons.activations import softshrink
from tensorflow_addons.activations.softshrink import _softshrink_py
from tensorflow_addons.activations.softshrink import (
_softshrink_py,
_softshrink_custom_op,
)
from tensorflow_addons.utils import test_utils


@test_utils.run_all_in_graph_and_eager_modes
class SoftshrinkTest(tf.test.TestCase, parameterized.TestCase):
def test_invalid(self):
with self.assertRaisesOpError("lower must be less than or equal to upper."):
y = softshrink(tf.ones(shape=(1, 2, 3)), lower=2.0, upper=-2.0)
y = _softshrink_custom_op(tf.ones(shape=(1, 2, 3)), lower=2.0, upper=-2.0)
self.evaluate(y)

@parameterized.named_parameters(
Expand Down

0 comments on commit b1079c2

Please # to comment.