
Commit 64b70b4

decrease memory usage
1 parent e62cc95 commit 64b70b4

1 file changed: +4 −16 lines


tensorflow_addons/optimizers/gradient_accumulator.py (+4 −16)
@@ -57,24 +57,12 @@ def _accum_grad(grads_and_vars):
             with tf.init_scope():
                 if not self._gradients:
                     for grad, var in grads_and_vars:
-                        if tf.distribute.has_strategy():
-                            for v in var.values:
-                                self._gradients[v.ref()] = tf.Variable(
-                                    tf.zeros_like(v), trainable=False
-                                )
-                        else:
-                            self._gradients[var.ref()] = tf.Variable(
-                                tf.zeros_like(var), trainable=False
-                            )
+                        self._gradients[var.ref()] = tf.Variable(
+                            tf.zeros_like(var), trainable=False
+                        )
             new_grads_and_vars = []
             for grad, var in grads_and_vars:
-                if tf.distribute.has_strategy():
-                    replica_id = tf.get_static_value(
-                        tf.distribute.get_replica_context().replica_id_in_sync_group
-                    )
-                    handle = self._gradients[var.values[replica_id].ref()]
-                else:
-                    handle = self._gradients[var.ref()]
+                handle = self._gradients[var.ref()]
 
                 if isinstance(grad, tf.IndexedSlices):
                     handle.scatter_add(grad)
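
The removed branches allocated one accumulator tf.Variable per replica copy of each model variable whenever a tf.distribute strategy was active, and resolved the handle by replica id on every step; after this commit a single accumulator keyed by var.ref() serves all replicas, which is where the memory saving comes from. Below is a minimal, self-contained sketch of the simplified accumulation path, assuming eager TF 2.x. The module-level _gradients dict, the accumulate helper, and the dense assign_add branch are illustrative assumptions (the diff context is cut off before the dense case) and not the exact GradientAccumulator class from this file.

import tensorflow as tf

# Sketch only: mirrors the post-commit logic, not the actual class layout.
_gradients = {}  # assumed stand-in for self._gradients

def accumulate(grads_and_vars):
    # Lazily create one accumulator per model variable, keyed by var.ref().
    # No per-replica copies are allocated any more.
    with tf.init_scope():
        if not _gradients:
            for _, var in grads_and_vars:
                _gradients[var.ref()] = tf.Variable(
                    tf.zeros_like(var), trainable=False
                )
    for grad, var in grads_and_vars:
        handle = _gradients[var.ref()]
        if isinstance(grad, tf.IndexedSlices):
            # Sparse gradients: add only the touched rows.
            handle.scatter_add(grad)
        else:
            # Dense gradients (assumed branch, truncated in the diff above).
            handle.assign_add(grad)

# Toy usage: accumulate gradients for two variables over two steps.
v = tf.Variable(tf.zeros([3]))
w = tf.Variable(tf.zeros([2, 2]))
for _ in range(2):
    accumulate([(tf.ones([3]), v), (tf.ones([2, 2]), w)])

Under a distribution strategy this keys the buffer on the distributed variable itself rather than on each of its per-replica components, so only one accumulator per logical variable is kept.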
