File tree 1 file changed +4
-16
lines changed
tensorflow_addons/optimizers
1 file changed +4
-16
lines changed Original file line number Diff line number Diff line change @@ -57,24 +57,12 @@ def _accum_grad(grads_and_vars):
57
57
with tf .init_scope ():
58
58
if not self ._gradients :
59
59
for grad , var in grads_and_vars :
60
- if tf .distribute .has_strategy ():
61
- for v in var .values :
62
- self ._gradients [v .ref ()] = tf .Variable (
63
- tf .zeros_like (v ), trainable = False
64
- )
65
- else :
66
- self ._gradients [var .ref ()] = tf .Variable (
67
- tf .zeros_like (var ), trainable = False
68
- )
60
+ self ._gradients [var .ref ()] = tf .Variable (
61
+ tf .zeros_like (var ), trainable = False
62
+ )
69
63
new_grads_and_vars = []
70
64
for grad , var in grads_and_vars :
71
- if tf .distribute .has_strategy ():
72
- replica_id = tf .get_static_value (
73
- tf .distribute .get_replica_context ().replica_id_in_sync_group
74
- )
75
- handle = self ._gradients [var .values [replica_id ].ref ()]
76
- else :
77
- handle = self ._gradients [var .ref ()]
65
+ handle = self ._gradients [var .ref ()]
78
66
79
67
if isinstance (grad , tf .IndexedSlices ):
80
68
handle .scatter_add (grad )
You can’t perform that action at this time.
0 commit comments