Skip to content

Commit 229344f

Browse files
authored
FIX: compile triplet loss within keras model (#298)
* FIX: compile triplet loss within keras model * remove dummy assertion * add a testcase when the shape of y_true is invalid * remove testcase of invalid shape
1 parent 25cf38e commit 229344f

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

tensorflow_addons/losses/triplet.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,8 @@ def triplet_semihard_loss(y_true, y_pred, margin=1.0):
7373
margin: Float, margin term in the loss definition.
7474
"""
7575
labels, embeddings = y_true, y_pred
76-
# Reshape [batch_size] label tensor to a [batch_size, 1] label tensor.
76+
# Reshape label tensor to [batch_size, 1].
7777
lshape = tf.shape(labels)
78-
assert lshape.shape == 1
7978
labels = tf.reshape(labels, [lshape[0], 1])
8079

8180
# Build pairwise squared distance matrix.

tensorflow_addons/losses/triplet_test.py

+7
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,13 @@ def test_unweighted(self):
104104
loss = cce_obj(y_true, y_pred)
105105
self.assertAlmostEqual(self.evaluate(loss), loss_np, 3)
106106

107+
def test_keras_model_compile(self):
108+
model = tf.keras.models.Sequential([
109+
tf.keras.layers.Input(shape=(784,)),
110+
tf.keras.layers.Dense(10),
111+
])
112+
model.compile(loss="triplet_semihard_loss", optimizer="adam")
113+
107114

108115
if __name__ == '__main__':
109116
tf.test.main()

0 commit comments

Comments
 (0)