Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Commit

Permalink
Fix axis error in normalization layer when loading model from tf back…
Browse files Browse the repository at this point in the history
…end saved h5 (#258)

* Fix axis error in normalization layer when loading model from tf backend saved h5

* Update normalization.py

* Update normalization.py

* Fix axis error in normalization layer when loading model from tf backend saved h5
  • Loading branch information
leondgarse authored Apr 7, 2020
1 parent 7d076c0 commit 5e5e74c
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
2 changes: 1 addition & 1 deletion keras/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self,
if axis == -1 and K.image_data_format() == 'channels_first' and K.backend() == 'mxnet':
self.axis = 1
else:
self.axis = axis
self.axis = axis[0] if isinstance(axis, list) and len(axis) == 1 else axis

def build(self, input_shape):
dim = input_shape[self.axis]
Expand Down
49 changes: 49 additions & 0 deletions tests/keras/backend/mxnet_tf_model_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest

import tempfile
from keras import backend as K
from keras.layers import Dense, BatchNormalization
from keras.models import load_model, Sequential
from keras.backend import tensorflow_backend as KTF
import warnings

pytestmark = pytest.mark.skipif(K.backend() != "mxnet",
reason="Testing MXNet context supports only for MXNet backend")


class TestMXNetTfModel(object):
def test_batchnorm_layer_reload(self):
# Save a tf backend keras h5 model
tf_model = KTF.tf.keras.models.Sequential([
KTF.tf.keras.layers.Dense(10, kernel_initializer="zeros"),
KTF.tf.keras.layers.BatchNormalization(),
])
tf_model.build(input_shape=(1, 10))
_, fname = tempfile.mkstemp(".h5")
tf_model.save(fname)

# Load from MXNet backend keras
try:
mx_model = load_model(fname, compile=False)
except TypeError:
warnings.warn("Could not reload from tensorflow backend saved model.")
assert False

# Retest with mxnet backend keras save + load
mx_model_2 = Sequential([
Dense(10, kernel_initializer="zeros"),
BatchNormalization(),
])
mx_model_2.build(input_shape=(1, 10))
_, fname = tempfile.mkstemp(".h5")
mx_model_2.save(fname)

try:
mx_model_3 = load_model(fname, compile=False)
except TypeError:
warnings.warn("Could not reload from MXNet backend saved model.")
assert False


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 5e5e74c

Please # to comment.