Skip to content

Commit 1b7c53d

Browse files
nkovela1tensorflower-gardener
authored andcommitted
Adds Keras v3 saving testing coverage to Keras layers tests.
PiperOrigin-RevId: 527921888
1 parent e7c4d09 commit 1b7c53d

8 files changed

+359
-143
lines changed

keras/layers/attention/multi_head_attention_test.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from absl.testing import parameterized
2020

2121
import keras
22+
from keras.saving import object_registration
2223
from keras.testing_infra import test_combinations
2324
from keras.testing_infra import test_utils
2425

@@ -515,6 +516,7 @@ def test_initializer(self):
515516
self.assertEqual(output.shape.as_list(), [None, 40, 80])
516517

517518

519+
@object_registration.register_keras_serializable()
518520
class TestModel(keras.Model):
519521
def __init__(self):
520522
super().__init__()
@@ -540,12 +542,19 @@ def call(self, x, training=False):
540542

541543
@test_combinations.run_all_keras_modes(always_skip_v1=True)
542544
class KerasModelSavingTest(test_combinations.TestCase):
543-
def test_keras_saving_subclass(self):
545+
@parameterized.parameters("tf", "keras_v3")
546+
def test_keras_saving_subclass(self, save_format):
544547
model = TestModel()
545548
query = keras.Input(shape=(40, 80))
546549
_ = model(query)
547550
model_path = self.get_temp_dir() + "/tmp_model"
548-
keras.models.save_model(model, model_path, save_format="tf")
551+
if save_format == "keras_v3":
552+
if not tf.__internal__.tf2.enabled():
553+
self.skipTest(
554+
"TF2 must be enabled to use the new `.keras` saving."
555+
)
556+
model_path += ".keras"
557+
keras.models.save_model(model, model_path, save_format=save_format)
549558
reloaded_model = keras.models.load_model(model_path)
550559
self.assertEqual(
551560
len(model.trainable_variables),
@@ -556,7 +565,7 @@ def test_keras_saving_subclass(self):
556565
):
557566
self.assertAllEqual(src_v, loaded_v)
558567

559-
@parameterized.parameters("h5", "tf")
568+
@parameterized.parameters("h5", "tf", "keras_v3")
560569
def test_keras_saving_functional(self, save_format):
561570
model = TestModel()
562571
query = keras.Input(shape=(40, 80))
@@ -565,6 +574,12 @@ def test_keras_saving_functional(self, save_format):
565574
)(query, query)
566575
model = keras.Model(inputs=query, outputs=output)
567576
model_path = self.get_temp_dir() + "/tmp_model"
577+
if save_format == "keras_v3":
578+
if not tf.__internal__.tf2.enabled():
579+
self.skipTest(
580+
"TF2 must be enabled to use the new `.keras` saving."
581+
)
582+
model_path += ".keras"
568583
keras.models.save_model(model, model_path, save_format=save_format)
569584
reloaded_model = keras.models.load_model(model_path)
570585
self.assertEqual(

keras/layers/normalization/spectral_normalization_test.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,27 @@ def test_save_load_model(self):
5151
# initialize model
5252
model.predict(tf.random.uniform((2, 1)))
5353

54-
model.save("test.h5")
55-
new_model = keras.models.load_model("test.h5")
54+
with self.subTest("h5"):
55+
model.save("test.h5")
56+
new_model = keras.models.load_model("test.h5")
5657

57-
self.assertEqual(
58-
model.layers[0].get_config(), new_model.layers[0].get_config()
59-
)
58+
self.assertEqual(
59+
model.layers[0].get_config(), new_model.layers[0].get_config()
60+
)
61+
with self.subTest("savedmodel"):
62+
model.save("test")
63+
new_model = keras.models.load_model("test")
64+
65+
self.assertEqual(
66+
model.layers[0].get_config(), new_model.layers[0].get_config()
67+
)
68+
with self.subTest("keras_v3"):
69+
model.save("test.keras")
70+
new_model = keras.models.load_model("test.keras")
71+
72+
self.assertEqual(
73+
model.layers[0].get_config(), new_model.layers[0].get_config()
74+
)
6075

6176
@test_combinations.run_all_keras_modes
6277
def test_normalization(self):

keras/layers/preprocessing/hashed_crossing_test.py

+33-11
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def test_from_config(self):
154154
tf.sparse.to_dense(original_outputs),
155155
)
156156

157-
def test_saved_model_keras(self):
157+
def test_saving_keras(self):
158158
string_in = keras.Input(shape=(1,), dtype=tf.string)
159159
int_in = keras.Input(shape=(1,), dtype=tf.int64)
160160
out = hashed_crossing.HashedCrossing(num_bins=10)((string_in, int_in))
@@ -167,17 +167,39 @@ def test_saved_model_keras(self):
167167
output_data = model((string_data, int_data))
168168
self.assertAllClose(output_data, expected_output)
169169

170-
# Save the model to disk.
171-
output_path = os.path.join(self.get_temp_dir(), "saved_model")
172-
model.save(output_path, save_format="tf")
173-
loaded_model = keras.models.load_model(
174-
output_path,
175-
custom_objects={"HashedCrossing": hashed_crossing.HashedCrossing},
176-
)
170+
with self.subTest("savedmodel"):
171+
# Save the model to disk.
172+
output_path = os.path.join(self.get_temp_dir(), "saved_model")
173+
model.save(output_path, save_format="tf")
174+
loaded_model = keras.models.load_model(
175+
output_path,
176+
custom_objects={
177+
"HashedCrossing": hashed_crossing.HashedCrossing
178+
},
179+
)
180+
181+
# Validate correctness of the new model.
182+
new_output_data = loaded_model((string_data, int_data))
183+
self.assertAllClose(new_output_data, expected_output)
184+
185+
with self.subTest("keras_v3"):
186+
if not tf.__internal__.tf2.enabled():
187+
self.skipTest(
188+
"TF2 must be enabled to use the new `.keras` saving."
189+
)
190+
# Save the model to disk.
191+
output_path = os.path.join(self.get_temp_dir(), "model.keras")
192+
model.save(output_path, save_format="keras_v3")
193+
loaded_model = keras.models.load_model(
194+
output_path,
195+
custom_objects={
196+
"HashedCrossing": hashed_crossing.HashedCrossing
197+
},
198+
)
177199

178-
# Validate correctness of the new model.
179-
new_output_data = loaded_model((string_data, int_data))
180-
self.assertAllClose(new_output_data, expected_output)
200+
# Validate correctness of the new model.
201+
new_output_data = loaded_model((string_data, int_data))
202+
self.assertAllClose(new_output_data, expected_output)
181203

182204

183205
if __name__ == "__main__":

keras/layers/preprocessing/hashing_test.py

+24
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,30 @@ def test_saved_model(self):
414414
new_output_data = loaded_model(input_data)
415415
self.assertAllClose(new_output_data, original_output_data)
416416

417+
@test_utils.run_v2_only
418+
def test_save_keras_v3(self):
419+
input_data = np.array(
420+
["omar", "stringer", "marlo", "wire", "skywalker"]
421+
)
422+
423+
inputs = keras.Input(shape=(None,), dtype=tf.string)
424+
outputs = hashing.Hashing(num_bins=100)(inputs)
425+
model = keras.Model(inputs=inputs, outputs=outputs)
426+
427+
original_output_data = model(input_data)
428+
429+
# Save the model to disk.
430+
output_path = os.path.join(self.get_temp_dir(), "tf_keras_model.keras")
431+
model.save(output_path, save_format="keras_v3")
432+
loaded_model = keras.models.load_model(output_path)
433+
434+
# Ensure that the loaded model is unique (so that the save/load is real)
435+
self.assertIsNot(model, loaded_model)
436+
437+
# Validate correctness of the new model.
438+
new_output_data = loaded_model(input_data)
439+
self.assertAllClose(new_output_data, original_output_data)
440+
417441
@parameterized.named_parameters(
418442
(
419443
"list_input",

0 commit comments

Comments
 (0)