From 637d6e699b8e25515e096423600aedd835af7c19 Mon Sep 17 00:00:00 2001
From: Leonard Lausen
Date: Mon, 27 Aug 2018 11:25:12 +0000
Subject: [PATCH] Address comments

---
 python/mxnet/contrib/optimizer.py             | 28 ++++----
 src/operator/contrib/optimizer_op-inl.h       |  8 +--
 src/operator/contrib/optimizer_op.cc          |  5 +-
 .../python/unittest/test_contrib_optimizer.py | 65 +++++++++----------
 4 files changed, 50 insertions(+), 56 deletions(-)

diff --git a/python/mxnet/contrib/optimizer.py b/python/mxnet/contrib/optimizer.py
index 15dabc4026b0..f06870565057 100644
--- a/python/mxnet/contrib/optimizer.py
+++ b/python/mxnet/contrib/optimizer.py
@@ -18,8 +18,8 @@
 # pylint: disable=too-many-lines
 """Contrib optimizers."""
-from ..ndarray import (NDArray, clip, full, mean, norm,
-                       proximal_group_adagrad_update, sqrt, square, zeros)
+from ..ndarray import (NDArray, clip, contrib, full, mean, norm, sqrt, square,
+                       zeros)
 from ..optimizer import Optimizer
 
 # convenience wrapper for Optimizer.Register
@@ -40,7 +40,7 @@ class ProximalGroupAdaGrad(Optimizer):
         grad = clip(grad * rescale_grad, clip_gradient)
         history += mean(square(grad), axis=1, keepdims=True)
         div = grad / sqrt(history + float_stable_eps)
-        weight += (div + weight * wd) * -lr
+        weight -= div * lr
 
     If `l2_regularization_strength > 0` a proximal operator is used to optimize
     with group lasso objective. Weights are updated lazily if the gradient is
@@ -58,7 +58,7 @@ class ProximalGroupAdaGrad(Optimizer):
         trainer.step(batch_size=1)
 
     For details of the update algorithm see
-    :class:`~mxnet.ndarray.proximal_group_adagrad_update`.
+    :class:`~mxnet.ndarray.contrib.proximal_group_adagrad_update`.
 
     This optimizer accepts the following parameters in addition to those
     accepted by :class:`.Optimizer`. Weight decay is not supported.
@@ -81,15 +81,15 @@ def create_state(self, index, weight):
         assert len(weight.shape) == 2
         history = zeros(
             (weight.shape[0], 1), weight.context, stype=weight.stype)
-        last_update_buffer = None
+        last_update = None
         if self.l2_regularization_strength > 0:
-            last_update_buffer = full(
+            last_update = full(
                 shape=(weight.shape[0], ),
                 val=self.num_update,
                 ctx=weight.context)
         else:
-            last_update_buffer = zeros(1, ctx=weight.context)
-        return (history, last_update_buffer)
+            last_update = zeros(1, ctx=weight.context)
+        return (history, last_update)
 
     def update(self, index, weight, grad, state):
         assert (isinstance(weight, NDArray))
@@ -97,21 +97,21 @@ def update(self, index, weight, grad, state):
         self._update_count(index)
         lr = self._get_lr(index)
         wd = self._get_wd(index)
-        assert wd == 0
+        assert wd == 0, 'Weight decay is not supported for ProximalGroupAdaGrad'
         is_sparse = grad.stype == 'row_sparse'
         history = state[0]
-        last_update_buffer = state[1]
+        last_update = state[1]
         if self.l2_regularization_strength > 0 and is_sparse:
             kwargs = dict()
             if self.clip_gradient:
                 kwargs['clip_gradient'] = self.clip_gradient
-            proximal_group_adagrad_update(
+            contrib.proximal_group_adagrad_update(
                 weight,
                 grad,
                 history,
                 out=weight,
-                last_update_buffer=last_update_buffer,
+                last_update=last_update,
                 rescale_grad=self.rescale_grad,
                 epsilon=self.float_stable_eps,
                 lr=lr,
@@ -124,13 +124,13 @@ def update(self, index, weight, grad, state):
                 grad = clip(grad, -self.clip_gradient, self.clip_gradient)
             history[:] += mean(square(grad), axis=1, keepdims=True)
             div = lr * grad / sqrt(history + self.float_stable_eps)
-            num_skipped = (self.num_update - last_update_buffer).expand_dims(1)
+            num_skipped = (self.num_update - last_update).expand_dims(1)
             scaled_l2 = lr / sqrt(history + self.float_stable_eps) \
                 * self.l2_regularization_strength * num_skipped
             nrm = norm(weight - div, ord=2, axis=1, keepdims=True)
             weight[:] = (weight - div) * (1 - scaled_l2 / nrm)
             weight[:] *= nrm > scaled_l2
-            last_update_buffer[:] = self.num_update
+            last_update[:] = self.num_update
         else:
             grad = grad * self.rescale_grad
             if self.clip_gradient is not None:
diff --git a/src/operator/contrib/optimizer_op-inl.h b/src/operator/contrib/optimizer_op-inl.h
index 12d87635be34..4659ea1e534c 100644
--- a/src/operator/contrib/optimizer_op-inl.h
+++ b/src/operator/contrib/optimizer_op-inl.h
@@ -143,7 +143,7 @@ template <typename xpu> struct ProximalGroupAdagradDnsRspKernel {
       // Compute number of weight updates skipped due to lazy_update
      DType num_skipped = current_update - last_update_data[grad_idx[i]];
      last_update_data[grad_idx[i]] = current_update;
-      // Warn in case of erroneous last_update_buffer
+      // Warn in case of erroneous last_update
      if (num_skipped < 0) {
        num_skipped = 0;
        std::printf("Got invalid last_update in proximal_adagrad_update. "
@@ -200,13 +200,13 @@ template <typename xpu> struct ProximalGroupAdagradDnsRspKernel {
 };
 
 /*
- * \brief Adagrad update implementation for dense weight and row_sparse grad.
+ * \brief Proximal Group Adagrad update implementation for dense weight and row_sparse grad.
  */
 template <typename xpu>
 inline void ProximalGroupAdagradUpdateDnsRspDnsImpl(
     const ProximalGroupAdagradParam &param, const OpContext &ctx,
     const TBlob &weight, const NDArray &grad, const TBlob &state,
-    const TBlob &last_update_buffer, const OpReqType &req, TBlob *out) {
+    const TBlob &last_update, const OpReqType &req, TBlob *out) {
   using namespace mshadow;
   using namespace mshadow::expr;
   using namespace mshadow_op;
@@ -229,7 +229,7 @@ inline void ProximalGroupAdagradUpdateDnsRspDnsImpl(
       const IType *grad_idx = grad.aux_data(rowsparse::kIdx).dptr<IType>();
      const DType *grad_val = grad.data().dptr<DType>();
      DType *state_data = state.dptr<DType>();
-      DType *last_update_data = last_update_buffer.dptr<DType>();
+      DType *last_update_data = last_update.dptr<DType>();
      const nnvm::dim_t num_grad = grad.aux_shape(rowsparse::kIdx)[0];
      const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
diff --git a/src/operator/contrib/optimizer_op.cc b/src/operator/contrib/optimizer_op.cc
index c3fd6e52fae1..1cbc392e7233 100644
--- a/src/operator/contrib/optimizer_op.cc
+++ b/src/operator/contrib/optimizer_op.cc
@@ -50,8 +50,7 @@ inline bool ProximalGroupAdagradShape(const nnvm::NodeAttrs &attrs,
          (in_attrs->at(0)[0] == in_attrs->at(2)[0]);
 }
 
-NNVM_REGISTER_OP(proximal_group_adagrad_update)
-MXNET_ADD_SPARSE_OP_ALIAS(proximal_group_adagrad_update)
+NNVM_REGISTER_OP(_contrib_proximal_group_adagrad_update)
 .describe(R"code(Update function for Proximal Group AdaGrad optimizer.
 
 Referenced from *Adaptive Subgradient Methods for Online Learning and Stochastic Optimization*,
@@ -88,7 +87,7 @@ Note that non-zero values for the weight decay option are not supported.
 .add_argument("weight", "NDArray-or-Symbol", "Weight")
 .add_argument("grad", "NDArray-or-Symbol", "Gradient")
 .add_argument("history", "NDArray-or-Symbol", "History")
-.add_argument("last_update_buffer", "NDArray-or-Symbol", "Last update buffer")
+.add_argument("last_update", "NDArray-or-Symbol", "Array storing last update counter for each row.")
 .add_arguments(ProximalGroupAdagradParam::__FIELDS__());
 
 } // namespace op
diff --git a/tests/python/unittest/test_contrib_optimizer.py b/tests/python/unittest/test_contrib_optimizer.py
index 4836ac149339..dbc88f614938 100644
--- a/tests/python/unittest/test_contrib_optimizer.py
+++ b/tests/python/unittest/test_contrib_optimizer.py
@@ -83,41 +83,36 @@ def test_proximal_group_adagrad():
         'l2_regularization_strength': 0.05
     }]
     for dtype in [np.float32]:
-        for eps_option in eps_options:
-            for cg_option in cg_options:
-                for rg_option in rg_options:
-                    for l2_option in l2_options:
-                        kwarg = dict(wd=0.0)
-                        kwarg.update(eps_option)
-                        kwarg.update(cg_option)
-                        kwarg.update(rg_option)
-                        kwarg.update(l2_option)
-                        compare_optimizer(
-                            opt1(**kwarg),
-                            opt2(**kwarg),
-                            shape,
-                            dtype,
-                            compare_states=False)
-                        if l2_option.get('l2_regularization_strength',
-                                         0.0) == 0.0:
-                            # By design results for PyOp which always performs
-                            # dense update will differ if
-                            # l2_regularization_strength > 0
-                            compare_optimizer(
-                                opt1(**kwarg),
-                                opt2(**kwarg),
-                                shape,
-                                dtype,
-                                w_stype='row_sparse',
-                                g_stype='row_sparse',
-                                compare_states=False)
-                            compare_optimizer(
-                                opt1(**kwarg),
-                                opt2(**kwarg),
-                                shape,
-                                dtype,
-                                g_stype='row_sparse',
-                                compare_states=False)
+        for options in itertools.product(eps_options, cg_options, rg_options,
+                                         l2_options):
+            kwarg = dict(wd=0.0)
+            for option in options:
+                kwarg.update(option)
+            compare_optimizer(
+                opt1(**kwarg),
+                opt2(**kwarg),
+                shape,
+                dtype,
+                compare_states=False)
+            if kwarg.get('l2_regularization_strength', 0.0) == 0.0:
+                # By design results for PyOp which always performs
+                # dense update will differ if
+                # l2_regularization_strength > 0
+                compare_optimizer(
+                    opt1(**kwarg),
+                    opt2(**kwarg),
+                    shape,
+                    dtype,
+                    w_stype='row_sparse',
+                    g_stype='row_sparse',
+                    compare_states=False)
+                compare_optimizer(
+                    opt1(**kwarg),
+                    opt2(**kwarg),
+                    shape,
+                    dtype,
+                    g_stype='row_sparse',
+                    compare_states=False)
 
 
 if __name__ == '__main__':
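
For reference, the dense code path that applies the proximal group lasso step (the branch shown in the @@ -124,13 hunk of python/mxnet/contrib/optimizer.py above) amounts to the NumPy sketch below. The sketch is not part of the patch: the helper name, signature, and default values are invented for illustration, and it only mirrors the arithmetic of that branch (per-row AdaGrad history, shrinkage scaled by the number of lazily skipped updates); the sparse contrib.proximal_group_adagrad_update kernel is not involved.

import numpy as np


def proximal_group_adagrad_step(weight, grad, history, last_update, num_update,
                                lr=0.01, l2_strength=0.0, eps=1e-5,
                                rescale_grad=1.0, clip_gradient=None):
    """One dense ProximalGroupAdaGrad step; each row of `weight` is one group."""
    grad = grad * rescale_grad
    if clip_gradient is not None:
        grad = np.clip(grad, -clip_gradient, clip_gradient)
    # One shared AdaGrad history value per row (group).
    history = history + np.mean(np.square(grad), axis=1, keepdims=True)
    div = lr * grad / np.sqrt(history + eps)
    if l2_strength > 0:
        # Proximal operator for the group lasso penalty; num_skipped scales the
        # penalty for rows that were skipped by lazy updates.
        num_skipped = (num_update - last_update)[:, np.newaxis]
        scaled_l2 = lr / np.sqrt(history + eps) * l2_strength * num_skipped
        nrm = np.linalg.norm(weight - div, ord=2, axis=1, keepdims=True)
        weight = (weight - div) * (1 - scaled_l2 / nrm)
        # Rows whose norm falls below the scaled penalty are zeroed out.
        weight = weight * (nrm > scaled_l2)
        last_update = np.full_like(last_update, num_update)
    else:
        weight = weight - div
    return weight, history, last_update


# Example: five rows (groups), four columns, one step with group lasso enabled.
w = np.random.randn(5, 4)
g = np.random.randn(5, 4)
h = np.zeros((5, 1))
last = np.zeros(5)
w, h, last = proximal_group_adagrad_step(
    w, g, h, last, num_update=1, l2_strength=0.05)

The num_skipped factor is what keeps lazy updates consistent: a row that has not been touched for several steps receives a proportionally larger group-lasso shrinkage when it is finally updated, which is also what the C++ kernel computes from last_update_data in optimizer_op-inl.h.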