This repository has been archived by the owner on Jul 1, 2024. It is now read-only.
forked from keras-team/keras
-
Notifications
You must be signed in to change notification settings - Fork 65
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix axis error in normalization layer when loading model from tf back…
…end saved h5 (#258) * Fix axis error in normalization layer when loading model from tf backend saved h5 * Update normalization.py * Update normalization.py * Fix axis error in normalization layer when loading model from tf backend saved h5
- Loading branch information
1 parent
7d076c0
commit 5e5e74c
Showing
2 changed files
with
50 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import pytest | ||
|
||
import tempfile | ||
from keras import backend as K | ||
from keras.layers import Dense, BatchNormalization | ||
from keras.models import load_model, Sequential | ||
from keras.backend import tensorflow_backend as KTF | ||
import warnings | ||
|
||
pytestmark = pytest.mark.skipif(K.backend() != "mxnet", | ||
reason="Testing MXNet context supports only for MXNet backend") | ||
|
||
|
||
class TestMXNetTfModel(object): | ||
def test_batchnorm_layer_reload(self): | ||
# Save a tf backend keras h5 model | ||
tf_model = KTF.tf.keras.models.Sequential([ | ||
KTF.tf.keras.layers.Dense(10, kernel_initializer="zeros"), | ||
KTF.tf.keras.layers.BatchNormalization(), | ||
]) | ||
tf_model.build(input_shape=(1, 10)) | ||
_, fname = tempfile.mkstemp(".h5") | ||
tf_model.save(fname) | ||
|
||
# Load from MXNet backend keras | ||
try: | ||
mx_model = load_model(fname, compile=False) | ||
except TypeError: | ||
warnings.warn("Could not reload from tensorflow backend saved model.") | ||
assert False | ||
|
||
# Retest with mxnet backend keras save + load | ||
mx_model_2 = Sequential([ | ||
Dense(10, kernel_initializer="zeros"), | ||
BatchNormalization(), | ||
]) | ||
mx_model_2.build(input_shape=(1, 10)) | ||
_, fname = tempfile.mkstemp(".h5") | ||
mx_model_2.save(fname) | ||
|
||
try: | ||
mx_model_3 = load_model(fname, compile=False) | ||
except TypeError: | ||
warnings.warn("Could not reload from MXNet backend saved model.") | ||
assert False | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main([__file__]) |