diff --git a/tensorflow_addons/losses/triplet.py b/tensorflow_addons/losses/triplet.py index 0b3f275f35..6717ea722c 100644 --- a/tensorflow_addons/losses/triplet.py +++ b/tensorflow_addons/losses/triplet.py @@ -73,9 +73,8 @@ def triplet_semihard_loss(y_true, y_pred, margin=1.0): margin: Float, margin term in the loss definition. """ labels, embeddings = y_true, y_pred - # Reshape [batch_size] label tensor to a [batch_size, 1] label tensor. + # Reshape label tensor to [batch_size, 1]. lshape = tf.shape(labels) - assert lshape.shape == 1 labels = tf.reshape(labels, [lshape[0], 1]) # Build pairwise squared distance matrix. diff --git a/tensorflow_addons/losses/triplet_test.py b/tensorflow_addons/losses/triplet_test.py index 91d496825f..817ec6d50d 100644 --- a/tensorflow_addons/losses/triplet_test.py +++ b/tensorflow_addons/losses/triplet_test.py @@ -104,6 +104,13 @@ def test_unweighted(self): loss = cce_obj(y_true, y_pred) self.assertAlmostEqual(self.evaluate(loss), loss_np, 3) + def test_keras_model_compile(self): + model = tf.keras.models.Sequential([ + tf.keras.layers.Input(shape=(784,)), + tf.keras.layers.Dense(10), + ]) + model.compile(loss="triplet_semihard_loss", optimizer="adam") + if __name__ == '__main__': tf.test.main()