Commit 977baa2

Fix custom loss for MXNet backend. Fix bug in Concat layer (#110)
* Fix custom loss usage in MXNet backend. Issue - 25

* Fix CR comments

* Fix CR comments
sandeep-krishnamurthy committed Jun 15, 2018
1 parent d8c30a4 commit 977baa2
Showing 6 changed files with 16 additions and 13 deletions.
12 changes: 9 additions & 3 deletions keras/backend/mxnet_backend.py
@@ -378,6 +378,7 @@ def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None):
         shape = tuple([0 if dim is None else dim for dim in shape])
     else:
         shape = tuple([0 for _ in range(ndim)])
+
     sym = _keras_variable(name, shape=shape, dtype=dtype)
     sym._keras_shape = tuple([d if d != 0 else None for d in shape])
     sym._mxnet_placeholder = True
@@ -1949,9 +1950,14 @@ def concatenate(tensors, axis=-1):
         A tensor.
     """
     if axis < 0:
-        axis += ndim(tensors[0])
+        rank = ndim(tensors[0])
+        if rank:
+            axis %= rank
+        else:
+            axis = 0
+
     tensors = [t.symbol for t in tensors]
-    return KerasSymbol(mx.sym.Concat(*tensors, dim=axis))
+    return KerasSymbol(mx.sym.concat(*tensors, dim=axis))


 @keras_mxnet_symbol
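
For context: axis %= rank folds any negative axis into the valid range [0, rank), since Python's modulo by a positive number is non-negative, while the else branch covers rank-0 inputs, where the old axis += ndim(tensors[0]) added zero and left the axis negative. A minimal sketch of the same normalization in plain Python (normalize_axis is an illustrative helper, not part of the backend):

def normalize_axis(axis, rank):
    # Map a possibly negative axis into [0, rank).
    if rank:
        # Python's modulo is non-negative for a positive modulus,
        # so -1 % 3 == 2: the last axis of a rank-3 tensor.
        return axis % rank
    # A rank-0 (scalar) input has no axes; default to 0 rather than
    # attempting a modulo by zero.
    return 0

assert normalize_axis(-1, 3) == 2
assert normalize_axis(-2, 4) == 2
assert normalize_axis(-1, 0) == 0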
@@ -2809,7 +2815,7 @@ def softmax(x):
     # Returns
         A tensor.
     """
-    return KerasSymbol(mx.sym.SoftmaxActivation(data=x.symbol))
+    return KerasSymbol(mx.sym.softmax(data=x.symbol))


 @keras_mxnet_symbol
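
For context: MXNet deprecated SoftmaxActivation in favour of the newer softmax operator, which is why the backend call is swapped here. A quick sketch of the replacement call, assuming a local MXNet installation (the variable names are illustrative):

import mxnet as mx

x = mx.sym.Variable('x')
y = mx.sym.softmax(data=x)  # replaces the deprecated mx.sym.SoftmaxActivation

# Evaluate on a small batch to confirm each row sums to 1.
ex = y.bind(mx.cpu(), {'x': mx.nd.array([[1., 2., 3.]])})
print(ex.forward()[0].asnumpy())  # approx. [[0.09, 0.24, 0.67]]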
5 changes: 5 additions & 0 deletions keras/engine/training_utils.py
@@ -417,6 +417,11 @@ def weighted(y_true, y_pred, weights, mask=None):
             weight_ndim = K.ndim(weights)
             score_array = K.mean(score_array,
                                  axis=list(range(weight_ndim, ndim)))
+            # If sample_weights shape is like (100, ), we convert it to
+            # (100, 1), because MXNet treats the shape (100, ) as (100),
+            # leading to broadcast operator failures in the operations below.
+            if K.backend() == 'mxnet' and weight_ndim == 1:
+                weights = K.reshape(weights, shape=(weights.shape[0], 1))
             score_array *= weights
             score_array /= K.mean(K.cast(K.not_equal(weights, 0), K.floatx()))
         return K.mean(score_array)
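
For context: the reshape added above works around the shape-alignment pitfall described in the new comment. A NumPy illustration of the same pitfall (NumPy is used here only because its broadcasting rules are easy to reproduce; the actual failure occurs in MXNet's broadcast operators):

import numpy as np

score_array = np.ones((100, 5))        # e.g. per-timestep losses
weights = np.arange(100, dtype=float)  # sample_weight of shape (100,)

# Multiplying (100, 5) by (100,) fails: broadcasting aligns trailing
# dimensions, so the 5 is compared against the 100.
try:
    score_array * weights
except ValueError as err:
    print('shape mismatch:', err)

# Reshaping the weights to (100, 1) lets them broadcast along the
# second axis instead, mirroring what the patch does with K.reshape.
weighted = score_array * weights.reshape(-1, 1)
print(weighted.shape)  # (100, 5)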
4 changes: 2 additions & 2 deletions tests/keras/layers/normalization_test.py
@@ -82,7 +82,7 @@ def test_batchnorm_correctness_2d():


 @pytest.mark.skipif((K.backend() == 'mxnet'),
-                    reason='MXNet backend does not allow predict() before compile()')
+                    reason='MXNet backend uses native BatchNorm operator. Do not do updates in the model.')
 @keras_test
 def test_batchnorm_training_argument():
     bn1 = normalization.BatchNormalization(input_shape=(10,))
@@ -160,7 +160,7 @@ def test_batchnorm_convnet_no_center_no_scale():


 @pytest.mark.skipif((K.backend() == 'mxnet'),
-                    reason='MXNet backend uses native BatchNorm operator. Do do updates in the model.')
+                    reason='MXNet backend uses native BatchNorm operator. Do not do updates in the model.')
 @keras_test
 def test_shared_batchnorm():
     '''Test that a BN layer can be shared
3 changes: 0 additions & 3 deletions tests/keras/losses_test.py
@@ -139,9 +139,6 @@ def test_serializing_loss_class():
     assert deserialized.mse_fraction == 0.3


-# https://github.com/deep-learning-tools/keras/issues/25
-@pytest.mark.skipif(K.backend() == 'mxnet',
-                    reason='MXNet backend does not fully support custom loss yet.')
 def test_serializing_model_with_loss_class(tmpdir):
     model_filename = str(tmpdir / 'custom_loss.hdf')

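
For context: with the skip removed, this serialization test now runs on the MXNet backend as well. A rough sketch of the pattern the test exercises, saving and reloading a model compiled with a loss class; the class, its fraction parameter, and the serialize-by-__name__ detail are assumptions for illustration, not the test's own code:

import numpy as np
import keras
from keras.utils import CustomObjectScope


class ScaledMSE(object):
    # Illustrative loss class; exposing __name__ so Keras can
    # serialize the loss by name (an assumption about Keras 2.x).
    __name__ = 'ScaledMSE'

    def __init__(self, fraction=0.3):
        self.fraction = fraction

    def __call__(self, y_true, y_pred):
        return self.fraction * keras.losses.mean_squared_error(y_true, y_pred)


model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(2,))])
model.compile(optimizer='sgd', loss=ScaledMSE(0.3))
model.fit(np.random.rand(16, 2), np.random.rand(16, 1), verbose=0)
model.save('custom_loss.hdf')

# Reloading needs the custom object in scope to resolve the saved name.
with CustomObjectScope({'ScaledMSE': ScaledMSE(0.3)}):
    restored = keras.models.load_model('custom_loss.hdf')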
3 changes: 0 additions & 3 deletions tests/keras/utils/vis_utils_test.py
@@ -9,11 +9,8 @@
 from keras.layers import TimeDistributed
 from keras.models import Sequential
 from keras.utils import vis_utils
-from keras import backend as K


-@pytest.mark.skipif(K.backend() == 'mxnet',
-                    reason='MXNet backend does not support LSTM yet')
 def test_plot_model():
     model = Sequential()
     model.add(Conv2D(filters=2, kernel_size=(2, 3), input_shape=(3, 5, 5), name='conv'))
2 changes: 0 additions & 2 deletions tests/test_loss_weighting.py
@@ -121,8 +121,6 @@ def test_sequential_sample_weights():
     assert(score < standard_score_sequential)


-@pytest.mark.skipif(K.backend() == 'mxnet',
-                    reason='MXNet backend does not support GRU yet.')
 @keras_test
 def test_sequential_temporal_sample_weights():
     (x_train, y_train), (x_test, y_test), (sample_weight, class_weight, test_ids) = _get_test_data()
