Commit 7d2d553

fix bugs
1 parent 64b70b4 commit 7d2d553

2 files changed: +54 -81 lines changed


tensorflow_addons/optimizers/gradient_accumulator.py

+42-45
@@ -49,20 +49,13 @@ def __init__(
         super().__init__(name, **kwargs)
         self._optimizer = tf.keras.optimizers.get(inner_optimizer)
         self._step = None
-        self._gradients = {}
         self._accum_steps = accum_steps
         self._reduction = reduction

         def _accum_grad(grads_and_vars):
-            with tf.init_scope():
-                if not self._gradients:
-                    for grad, var in grads_and_vars:
-                        self._gradients[var.ref()] = tf.Variable(
-                            tf.zeros_like(var), trainable=False
-                        )
             new_grads_and_vars = []
             for grad, var in grads_and_vars:
-                handle = self._gradients[var.ref()]
+                handle = self.get_slot(var, "ga")

                 if isinstance(grad, tf.IndexedSlices):
                     handle.scatter_add(grad)
@@ -84,9 +77,11 @@ def _get_grad():
                         values = tf.gather(new_grad, indices)
                         dense_shape = tf.constant(new_grad.shape.as_list())
                         handle.assign(
-                            tf.zeros_like(handle), use_locking=self._use_locking
+                            tf.zeros_like(handle),
+                            use_locking=self._use_locking,
+                            read_value=False,
                         )
-                        return values, tf.cast(indices, tf.int32), dense_shape
+                        return values, tf.cast(indices, grad.indices.dtype), dense_shape

                     values, indices, dense_shape = tf.cond(
                         self.step % self._accum_steps == 0,
@@ -100,14 +95,18 @@ def _get_grad():
                     new_grad = tf.IndexedSlices(values, indices, dense_shape)
                     new_grads_and_vars.append((new_grad, var))
                 else:
-                    handle.assign_add(grad)
+                    handle.assign_add(
+                        grad, use_locking=self._use_locking, read_value=False
+                    )

                     def _get_grad():
                         new_grad = handle.read_value()
                         if self._reduction == "MEAN":
                             new_grad /= tf.cast(self._accum_steps, new_grad.dtype)
                         handle.assign(
-                            tf.zeros_like(handle), use_locking=self._use_locking
+                            tf.zeros_like(handle),
+                            use_locking=self._use_locking,
+                            read_value=False,
                         )
                         return new_grad

@@ -119,11 +118,39 @@ def _get_grad():
                     new_grads_and_vars.append((new_grad, var))
             return new_grads_and_vars

-        self._optimizer.gradient_transformers.append(_accum_grad)
+        self.gradient_transformers.append(_accum_grad)
         self._iterations = self._optimizer.iterations

     def _create_slots(self, var_list):
         self._optimizer._create_slots(var_list=var_list)
+        for var in var_list:
+            self.add_slot(var, "ga")
+
+    def _resource_apply_dense(self, grad, handle, apply_state):
+        if "apply_state" in self._optimizer._dense_apply_args:
+            return self.inner_optimizer._resource_apply_dense(grad, handle, apply_state)
+        else:
+            return self.inner_optimizer._resource_apply_dense(grad, handle)
+
+    def _resource_apply_sparse(self, grad, handle, indices, apply_state):
+        if "apply_state" in self._optimizer._sparse_apply_args:
+            return self.inner_optimizer._resource_apply_sparse(
+                grad, handle, indices, apply_state=apply_state
+            )
+        else:
+            return self.inner_optimizer._resource_apply_sparse(grad, handle, indices)
+
+    def _resource_apply_sparse_duplicate_indices(
+        self, grad, handle, indices, apply_state=None
+    ):
+        if "apply_state" in self._optimizer._sparse_apply_args:
+            return self.inner_optimizer._resource_apply_sparse_duplicate_indices(
+                grad, handle, indices, apply_state=apply_state
+            )
+        else:
+            return self.inner_optimizer._resource_apply_sparse_duplicate_indices(
+                grad, handle, indices
+            )

     @property
     def step(self):
@@ -151,49 +178,19 @@ def step(self, variable):
         self._step = variable
         self._weights.append(self._step)

-    @property
-    def gradients(self):
-        """The accumulated gradients on the current replica."""
-        if not self._gradients:
-            raise ValueError(
-                "The accumulator should be called first to initialize the gradients"
-            )
-        return list(
-            gradient.read_value() if gradient is not None else gradient
-            for _, gradient in self._gradients
-        )
-
     def apply_gradients(self, grads_and_vars, name=None, **kwargs):
-        train_op = self._optimizer.apply_gradients(grads_and_vars, name, **kwargs)
+        train_op = super().apply_gradients(grads_and_vars, name, **kwargs)
         with tf.control_dependencies([train_op]):
             with tf.control_dependencies(
                 [
-                    self._optimizer.iterations.assign_add(
+                    self.iterations.assign_add(
                         tf.cast(self.step % self._accum_steps == 0, tf.int64),
                         read_value=False,
                     )
                 ]
             ):
                 return self.step.assign_add(1, read_value=False)

-    def reset(self):
-        """Resets the accumulated gradients on the current replica."""
-        assign_ops = []
-        if not self._gradients:
-            return assign_ops
-
-        for _, gradient in self._gradients:
-            if gradient is not None:
-                assign_ops.append(
-                    gradient.assign(
-                        tf.zeros_like(gradient),
-                        use_locking=self._use_locking,
-                        read_value=False,
-                    )
-                )
-
-        return tf.group(assign_ops)
-
     @property
     def inner_optimizer(self):
         """The optimizer that this LossScaleOptimizer is wrapping."""

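For context, a minimal usage sketch consistent with the updated optimizer and the reworked tests below (illustrative only; the variable, gradient, and accum_steps values are arbitrary). It assumes the constructor shown in the tests, GradientAccumulator(inner_optimizer, accum_steps), and routes apply_gradients through the active distribution strategy the same way the tests now do:

    import tensorflow as tf
    from tensorflow_addons.optimizers import GradientAccumulator

    # Wrap any Keras optimizer; gradients are accumulated in the per-variable
    # "ga" slots added in this commit and handed to the inner optimizer only
    # on accumulation boundaries (every `accum_steps` calls).
    opt = GradientAccumulator(tf.keras.optimizers.SGD(learning_rate=1.0), accum_steps=4)

    var = tf.Variable([1.0, 2.0])
    grads_and_vars = [(tf.constant([0.1, 0.1]), var)]

    # The default strategy works on a single device; under MirroredStrategy the
    # same call runs the update on every replica, as in the updated tests.
    strategy = tf.distribute.get_strategy()
    for _ in range(4):
        strategy.run(opt.apply_gradients, [grads_and_vars])
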
tensorflow_addons/optimizers/tests/gradient_accumulator_test.py

+12-36
@@ -17,12 +17,12 @@
 import numpy as np
 import pytest
 import tensorflow as tf
-from tensorflow_addons.utils import test_utils

 from tensorflow_addons.optimizers import GradientAccumulator


 @pytest.mark.usefixtures("maybe_run_functions_eagerly")
+@pytest.mark.with_device(["cpu", "gpu", tf.distribute.MirroredStrategy])
 def test_run():
     var0 = tf.Variable([1.0, 2.0])
     var1 = tf.Variable([3.0, 4.0])
@@ -35,14 +35,16 @@ def test_run():

     opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=1.0), accum_steps)

+    strategy = tf.distribute.get_strategy()
     for _ in range(accum_steps + 1):
-        opt.apply_gradients(grads_and_vars)
+        strategy.run(opt.apply_gradients, [grads_and_vars])

     np.testing.assert_allclose(var0.read_value(), [0.6, 1.6])
     np.testing.assert_allclose(var1.read_value(), [2.96, 3.96])


 @pytest.mark.usefixtures("maybe_run_functions_eagerly")
+@pytest.mark.with_device(["cpu", "gpu", tf.distribute.MirroredStrategy])
 def test_sparse():
     var0 = tf.Variable([[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]])
     var1 = tf.Variable([[3.0, 4.0, 0.0]])
@@ -60,38 +62,13 @@ def test_sparse():

     grads_and_vars = list(zip([grads0, grads1], [var0, var1]))
     opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=1.0))
+    strategy = tf.distribute.get_strategy()
     for _ in range(8):
-        opt.apply_gradients(grads_and_vars)
+        strategy.run(opt.apply_gradients, [grads_and_vars])
     np.testing.assert_allclose(var0.read_value(), [[1.0, 2.0, 0.0], [0.2, 1.2, 0.0]])
     np.testing.assert_allclose(var1.read_value(), [[2.92, 3.92, 0.0]])


-@pytest.mark.usefixtures("maybe_run_functions_eagerly")
-@pytest.mark.needs_gpu
-def test_sparse_multi_gpus():
-    strategy = tf.distribute.MirroredStrategy(test_utils.gpus_for_testing())
-    with strategy.scope():
-        var0 = tf.Variable([[1.0, 2.0, 0.0]])
-        var1 = tf.Variable([[3.0, 4.0, 0.0]])
-
-        grads0 = tf.IndexedSlices(
-            tf.constant([[0.1, 0.1, 0.0]]),
-            tf.constant([0]),
-            tf.constant([1, 3]),
-        )
-        grads1 = tf.IndexedSlices(
-            tf.constant([[0.01, 0.01, 0.0]]),
-            tf.constant([0]),
-            tf.constant([1, 3]),
-        )
-
-        grads_and_vars = list(zip([grads0, grads1], [var0, var1]))
-        opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=1.0))
-        strategy.run(opt.apply_gradients, [grads_and_vars])
-        np.testing.assert_allclose(var0.read_value(), [[1.0, 2.0, 0.0]])
-        np.testing.assert_allclose(var1.read_value(), [[3.0, 4.0, 0.0]])
-
-
 @pytest.mark.usefixtures("maybe_run_functions_eagerly")
 def test_dense():
     grad = tf.Variable([[0.1]])
@@ -133,7 +110,7 @@ def test_config():


 @pytest.mark.usefixtures("maybe_run_functions_eagerly")
-@pytest.mark.needs_gpu
+@pytest.mark.with_device([tf.distribute.MirroredStrategy])
 def test_fit_simple_linear_model():
     seed = 0x2019
     np.random.seed(seed)
@@ -142,13 +119,12 @@ def test_fit_simple_linear_model():
     x = np.random.standard_normal((num_examples, 3))
     w = np.random.standard_normal((3, 1))
     y = np.dot(x, w) + np.random.standard_normal((num_examples, 1)) * 1e-4
-    strategy = tf.distribute.MirroredStrategy(test_utils.gpus_for_testing())
-    with strategy.scope():
-        model = tf.keras.models.Sequential()
-        model.add(tf.keras.layers.Dense(input_shape=(3,), units=1))

-        opt = GradientAccumulator("sgd")
-        model.compile(opt, loss="mse")
+    model = tf.keras.models.Sequential()
+    model.add(tf.keras.layers.Dense(input_shape=(3,), units=1))
+
+    opt = GradientAccumulator("sgd")
+    model.compile(opt, loss="mse")

     model.fit(x, y, epochs=5)

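To make the mechanism being refactored easier to follow, here is a rough standalone sketch of the accumulate-then-apply pattern with MEAN reduction. It is a simplification under stated assumptions, not the wrapper's actual code path: the real implementation above uses optimizer slots, gradient_transformers, IndexedSlices handling, and the use_locking/read_value flags.

    import tensorflow as tf

    accum_steps = 4
    opt = tf.keras.optimizers.SGD(learning_rate=1.0)
    var = tf.Variable([1.0, 2.0])
    # One accumulator per variable, analogous to the "ga" slot added in this commit.
    accum = tf.Variable(tf.zeros_like(var), trainable=False)
    step = tf.Variable(0, dtype=tf.int64)

    def accumulate_and_maybe_apply(grad):
        accum.assign_add(grad)
        step.assign_add(1)
        if step % accum_steps == 0:
            # MEAN reduction: average the accumulated gradient before applying it.
            opt.apply_gradients([(accum / tf.cast(accum_steps, accum.dtype), var)])
            accum.assign(tf.zeros_like(accum))

    for _ in range(8):
        accumulate_and_maybe_apply(tf.constant([0.1, 0.1]))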