Proximal Group Adagrad optimizer
leezu committed Oct 2, 2018
1 parent 6fb81af commit 6f57fd2
Showing 11 changed files with 850 additions and 44 deletions.
1 change: 1 addition & 0 deletions docs/api/python/index.md
@@ -136,6 +136,7 @@ Code examples are placed throughout the API documentation and these can be run a
:maxdepth: 1
optimization/optimization.md
optimization/contrib.md
```

## Profiler API
52 changes: 52 additions & 0 deletions docs/api/python/optimization/contrib.md
@@ -0,0 +1,52 @@
# Contrib Optimization API

```eval_rst
.. currentmodule:: mxnet.optimizer.contrib
```

## Overview

This document summarizes the contrib APIs used to initialize and update the model
weights during training.

```eval_rst
.. autosummary::
   :nosignatures:

   mxnet.optimizer.contrib
```

The `Contrib Optimization` API, defined in the `optimizer.contrib` package, provides
many useful experimental APIs for new features.
This is a place for the community to try out the new features,
so that feature contributors can receive feedback.

```eval_rst
.. warning:: This package contains experimental APIs and may change in the near future.
```

In the rest of this document, we list routines provided by the `optimizer.contrib` package.

## Contrib

```eval_rst
.. currentmodule:: mxnet.optimizer.contrib
.. autosummary::
   :nosignatures:

   ProximalGroupAdaGrad
```
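
The optimizer can be used like any other MXNet optimizer, for example with a
Gluon `Trainer`. The snippet below is a minimal sketch (not part of the
generated API reference); the embedding layer and hyper-parameter values are
illustrative only:

```python
import mxnet as mx
from mxnet import gluon

# An embedding layer is a typical use case: a 2-D weight array whose rows
# can be regularized as groups by the group lasso term.
embedding = gluon.nn.Embedding(input_dim=1000, output_dim=16)
embedding.initialize()

optimizer = mx.optimizer.contrib.ProximalGroupAdaGrad(
    learning_rate=0.1, l2_regularization_strength=0.001)
trainer = gluon.Trainer(embedding.collect_params(), optimizer)
```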

## API Reference

<script type="text/javascript" src='../../../_static/js/auto_module_index.js'></script>

```eval_rst
.. automodule:: mxnet.optimizer.contrib
   :members:
```

<script>auto_index("api-reference");</script>
23 changes: 23 additions & 0 deletions python/mxnet/optimizer/__init__.py
@@ -0,0 +1,23 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Optimizer API of MXNet."""

from . import optimizer, contrib
# pylint: disable=wildcard-import
from .optimizer import *
# pylint: enable=wildcard-import

__all__ = optimizer.__all__ + ['contrib']
145 changes: 145 additions & 0 deletions python/mxnet/optimizer/contrib.py
@@ -0,0 +1,145 @@
# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=too-many-lines
"""Contrib optimizers."""
from ..ndarray import (NDArray, clip, contrib, full, mean, norm, sparse, sqrt,
                       square, zeros)
from .optimizer import Optimizer

# convenience wrapper for Optimizer.Register
register = Optimizer.register # pylint: disable=invalid-name

__all__ = ['ProximalGroupAdaGrad']


@register
class ProximalGroupAdaGrad(Optimizer):
"""Proximal Adagrad optimizer with row-wise learning rates.
This class implements the AdaGrad optimizer described in *Adaptive
Subgradient Methods for Online Learning and Stochastic Optimization*, and
available at http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf but
uses only a single learning rate for every row of the parameter array.
This optimizer updates each weight by::
grad = clip(grad * rescale_grad, clip_gradient)
history += mean(square(grad), axis=1, keepdims=True)
div = grad / sqrt(history + float_stable_eps)
weight -= div * lr
If `l2_regularization_strength > 0` a proximal operator is used to optimize
with group lasso objective. Weights are updated lazily if the gradient is
sparse. In particular, before using a set of weights for a forward pass,
you may want to ensure that the lazily accumulated group lasso
regularization is applied. This can be achieved by creating a sparse
gradient array that contains explicit 0 data for the indices to be updated:
fake_grad = mx.nd.sparse.row_sparse_array(
(mx.nd.zeros((len(indices), dim)), indices))
weight.grad()[:] = fake_grad
weight.data()._fresh_grad = True
trainer._optimizer._index_update_count[0] -= 1
trainer._optimizer.num_update -= 1
trainer.step(batch_size=1)
For details of the update algorithm see
:class:`~mxnet.ndarray.contrib.proximal_group_adagrad_update`.
This optimizer accepts the following parameters in addition to those
accepted by :class:`.Optimizer`. Weight decay is not supported.
Parameters
----------
l2_regularization_strength : float
Strength of group lasso L2 regularization.
eps: float, optional
Initial value of the history accumulator. Avoids division by 0.
"""

    def __init__(self, l2_regularization_strength=0.0, eps=1e-5, **kwargs):
        super(ProximalGroupAdaGrad, self).__init__(**kwargs)
        self.l2_regularization_strength = l2_regularization_strength
        self.float_stable_eps = eps

    def create_state(self, index, weight):
        assert len(weight.shape) == 2
        history = zeros(
            (weight.shape[0], 1), weight.context, stype=weight.stype)
        last_update = None
        if self.l2_regularization_strength > 0:
            last_update = full(
                shape=(weight.shape[0], ),
                val=self.num_update,
                ctx=weight.context)
        else:
            last_update = zeros(1, ctx=weight.context)
        return (history, last_update)

    def update(self, index, weight, grad, state):
        assert (isinstance(weight, NDArray))
        assert (isinstance(grad, NDArray))
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)
        assert wd == 0, 'Weight decay is not supported for ProximalGroupAdaGrad'

        is_sparse = grad.stype == 'row_sparse'
        history = state[0]
        last_update = state[1]
        if is_sparse:
            kwargs = {
                'epsilon': self.float_stable_eps,
                'rescale_grad': self.rescale_grad
            }
            if self.clip_gradient:
                kwargs['clip_gradient'] = self.clip_gradient
            if self.l2_regularization_strength:
                kwargs['l2_regularization_strength'] = \
                    self.l2_regularization_strength
            contrib.proximal_group_adagrad_update(
                weight,
                grad,
                history,
                out=weight,
                last_update=last_update,
                lr=lr,
                current_update=self.num_update,
                **kwargs)
        elif self.l2_regularization_strength > 0:
            grad = grad * self.rescale_grad
            if self.clip_gradient is not None:
                grad = clip(grad, -self.clip_gradient, self.clip_gradient)
            history[:] += mean(square(grad), axis=1, keepdims=True)
            div = lr * grad / sqrt(history + self.float_stable_eps)
            num_skipped = (self.num_update - last_update).expand_dims(1)
            scaled_l2 = lr / sqrt(history + self.float_stable_eps) \
                * self.l2_regularization_strength * num_skipped
            nrm = norm(weight - div, ord=2, axis=1, keepdims=True)
            weight[:] = (weight - div) * (1 - scaled_l2 / nrm)
            weight[:] *= nrm > scaled_l2
            last_update[:] = self.num_update
        else:
            grad = grad * self.rescale_grad
            if self.clip_gradient is not None:
                grad = clip(grad, -self.clip_gradient, self.clip_gradient)
            history[:] += mean(square(grad), axis=1, keepdims=True)
            div = lr * grad / sqrt(history + self.float_stable_eps)
            weight[:] -= div
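
The dense branch above applies a group-lasso proximal operator row by row:
each row is first moved by the AdaGrad step, its norm is then shrunk by
`scaled_l2`, and rows whose norm does not exceed the threshold are zeroed.
A minimal NumPy sketch of this step (assuming a single pending update per
row, i.e. `num_skipped == 1`; the function name is illustrative only):

```python
import numpy as np

def group_adagrad_prox_step(weight, grad, history, lr,
                            l2_strength, eps=1e-5):
    """One dense ProximalGroupAdaGrad-style step on a 2-D weight.

    weight, grad: arrays of shape (num_rows, dim);
    history: accumulator of shape (num_rows, 1).
    """
    # Row-wise AdaGrad: accumulate the mean squared gradient per row.
    history += np.mean(np.square(grad), axis=1, keepdims=True)
    step = lr * grad / np.sqrt(history + eps)

    # Group-lasso proximal operator: shrink each row's norm by scaled_l2
    # and zero out rows whose norm falls below the threshold.
    scaled_l2 = lr / np.sqrt(history + eps) * l2_strength
    candidate = weight - step
    nrm = np.linalg.norm(candidate, ord=2, axis=1, keepdims=True)
    shrink = np.maximum(1.0 - scaled_l2 / np.maximum(nrm, 1e-12), 0.0)
    weight[:] = candidate * shrink
    return weight, history
```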
20 changes: 13 additions & 7 deletions python/mxnet/optimizer.py → python/mxnet/optimizer/optimizer.py
@@ -23,13 +23,19 @@
import pickle
import warnings
import numpy
from .base import py_str
from .ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply)
from .ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
                      mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
                      signsgd_update, signum_update)
from .ndarray import sparse
from .random import normal
from ..base import py_str
from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply)
from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
                       mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
                       signsgd_update, signum_update)
from ..ndarray import sparse
from ..random import normal

__all__ = [
    'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LBSGD',
    'NAG', 'NDArray', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD',
    'Signum', 'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register'
]


class Optimizer(object):
41 changes: 41 additions & 0 deletions python/mxnet/test_utils.py
@@ -1957,3 +1957,44 @@ def verify_generator(generator, buckets, probs, nsamples=1000000, nrepeat=5, suc
        % (str(cs_ret_l), str(obs_freq_l), str(expected_freq_l),
           str(buckets), str(probs)))
    return cs_ret_l

def compare_ndarray_tuple(t1, t2, rtol=None, atol=None):
    """Compare ndarray tuple."""
    if t1 is not None and t2 is not None:
        if isinstance(t1, tuple):
            for s1, s2 in zip(t1, t2):
                compare_ndarray_tuple(s1, s2, rtol, atol)
        else:
            assert_almost_equal(t1.asnumpy(), t2.asnumpy(), rtol=rtol, atol=atol)


def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='default',
                      rtol=1e-4, atol=1e-5, compare_states=True):
    """Compare opt1 and opt2."""
    if w_stype == 'default':
        w2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
        w1 = w2.copyto(default_context())
    elif w_stype == 'row_sparse' or w_stype == 'csr':
        w2 = rand_ndarray(shape, w_stype, density=1, dtype=dtype)
        w1 = w2.copyto(default_context()).tostype('default')
    else:
        raise Exception("type not supported yet")
    if g_stype == 'default':
        g2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
        g1 = g2.copyto(default_context())
    elif g_stype == 'row_sparse' or g_stype == 'csr':
        g2 = rand_ndarray(shape, g_stype, dtype=dtype)
        g1 = g2.copyto(default_context()).tostype('default')
    else:
        raise Exception("type not supported yet")

    state1 = opt1.create_state_multi_precision(0, w1)
    state2 = opt2.create_state_multi_precision(0, w2)
    if compare_states:
        compare_ndarray_tuple(state1, state2)

    opt1.update_multi_precision(0, w1, g1, state1)
    opt2.update_multi_precision(0, w2, g2, state2)
    if compare_states:
        compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
    assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol)
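
A test for the new optimizer might use these helpers roughly as follows
(a hedged sketch, not part of this commit; the shape and hyper-parameters are
illustrative, and a real test would compare two independent implementations
rather than two instances of the same class):

```python
import numpy as np
import mxnet as mx
from mxnet.test_utils import compare_optimizer

# Compare two instances on a 2-D weight, once with dense and once with
# row_sparse gradients, so both update paths are exercised.
for g_stype in ('default', 'row_sparse'):
    opt1 = mx.optimizer.contrib.ProximalGroupAdaGrad(learning_rate=0.1)
    opt2 = mx.optimizer.contrib.ProximalGroupAdaGrad(learning_rate=0.1)
    compare_optimizer(opt1, opt2, shape=(8, 4), dtype=np.float32,
                      g_stype=g_stype)
```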
