@@ -49,20 +49,13 @@ def __init__(
49
49
super ().__init__ (name , ** kwargs )
50
50
self ._optimizer = tf .keras .optimizers .get (inner_optimizer )
51
51
self ._step = None
52
- self ._gradients = {}
53
52
self ._accum_steps = accum_steps
54
53
self ._reduction = reduction
55
54
56
55
def _accum_grad (grads_and_vars ):
57
- with tf .init_scope ():
58
- if not self ._gradients :
59
- for grad , var in grads_and_vars :
60
- self ._gradients [var .ref ()] = tf .Variable (
61
- tf .zeros_like (var ), trainable = False
62
- )
63
56
new_grads_and_vars = []
64
57
for grad , var in grads_and_vars :
65
- handle = self ._gradients [ var . ref ()]
58
+ handle = self .get_slot ( var , "ga" )
66
59
67
60
if isinstance (grad , tf .IndexedSlices ):
68
61
handle .scatter_add (grad )
@@ -84,9 +77,11 @@ def _get_grad():
84
77
values = tf .gather (new_grad , indices )
85
78
dense_shape = tf .constant (new_grad .shape .as_list ())
86
79
handle .assign (
87
- tf .zeros_like (handle ), use_locking = self ._use_locking
80
+ tf .zeros_like (handle ),
81
+ use_locking = self ._use_locking ,
82
+ read_value = False ,
88
83
)
89
- return values , tf .cast (indices , tf . int32 ), dense_shape
84
+ return values , tf .cast (indices , grad . indices . dtype ), dense_shape
90
85
91
86
values , indices , dense_shape = tf .cond (
92
87
self .step % self ._accum_steps == 0 ,
@@ -100,14 +95,18 @@ def _get_grad():
100
95
new_grad = tf .IndexedSlices (values , indices , dense_shape )
101
96
new_grads_and_vars .append ((new_grad , var ))
102
97
else :
103
- handle .assign_add (grad )
98
+ handle .assign_add (
99
+ grad , use_locking = self ._use_locking , read_value = False
100
+ )
104
101
105
102
def _get_grad ():
106
103
new_grad = handle .read_value ()
107
104
if self ._reduction == "MEAN" :
108
105
new_grad /= tf .cast (self ._accum_steps , new_grad .dtype )
109
106
handle .assign (
110
- tf .zeros_like (handle ), use_locking = self ._use_locking
107
+ tf .zeros_like (handle ),
108
+ use_locking = self ._use_locking ,
109
+ read_value = False ,
111
110
)
112
111
return new_grad
113
112
@@ -119,11 +118,39 @@ def _get_grad():
119
118
new_grads_and_vars .append ((new_grad , var ))
120
119
return new_grads_and_vars
121
120
122
- self ._optimizer . gradient_transformers .append (_accum_grad )
121
+ self .gradient_transformers .append (_accum_grad )
123
122
self ._iterations = self ._optimizer .iterations
124
123
125
124
def _create_slots (self , var_list ):
126
125
self ._optimizer ._create_slots (var_list = var_list )
126
+ for var in var_list :
127
+ self .add_slot (var , "ga" )
128
+
129
+ def _resource_apply_dense (self , grad , handle , apply_state ):
130
+ if "apply_state" in self ._optimizer ._dense_apply_args :
131
+ return self .inner_optimizer ._resource_apply_dense (grad , handle , apply_state )
132
+ else :
133
+ return self .inner_optimizer ._resource_apply_dense (grad , handle )
134
+
135
+ def _resource_apply_sparse (self , grad , handle , indices , apply_state ):
136
+ if "apply_state" in self ._optimizer ._sparse_apply_args :
137
+ return self .inner_optimizer ._resource_apply_sparse (
138
+ grad , handle , indices , apply_state = apply_state
139
+ )
140
+ else :
141
+ return self .inner_optimizer ._resource_apply_sparse (grad , handle , indices )
142
+
143
+ def _resource_apply_sparse_duplicate_indices (
144
+ self , grad , handle , indices , apply_state = None
145
+ ):
146
+ if "apply_state" in self ._optimizer ._sparse_apply_args :
147
+ return self .inner_optimizer ._resource_apply_sparse_duplicate_indices (
148
+ grad , handle , indices , apply_state = apply_state
149
+ )
150
+ else :
151
+ return self .inner_optimizer ._resource_apply_sparse_duplicate_indices (
152
+ grad , handle , indices
153
+ )
127
154
128
155
@property
129
156
def step (self ):
@@ -151,49 +178,19 @@ def step(self, variable):
151
178
self ._step = variable
152
179
self ._weights .append (self ._step )
153
180
154
- @property
155
- def gradients (self ):
156
- """The accumulated gradients on the current replica."""
157
- if not self ._gradients :
158
- raise ValueError (
159
- "The accumulator should be called first to initialize the gradients"
160
- )
161
- return list (
162
- gradient .read_value () if gradient is not None else gradient
163
- for _ , gradient in self ._gradients
164
- )
165
-
166
181
def apply_gradients (self , grads_and_vars , name = None , ** kwargs ):
167
- train_op = self . _optimizer .apply_gradients (grads_and_vars , name , ** kwargs )
182
+ train_op = super () .apply_gradients (grads_and_vars , name , ** kwargs )
168
183
with tf .control_dependencies ([train_op ]):
169
184
with tf .control_dependencies (
170
185
[
171
- self ._optimizer . iterations .assign_add (
186
+ self .iterations .assign_add (
172
187
tf .cast (self .step % self ._accum_steps == 0 , tf .int64 ),
173
188
read_value = False ,
174
189
)
175
190
]
176
191
):
177
192
return self .step .assign_add (1 , read_value = False )
178
193
179
- def reset (self ):
180
- """Resets the accumulated gradients on the current replica."""
181
- assign_ops = []
182
- if not self ._gradients :
183
- return assign_ops
184
-
185
- for _ , gradient in self ._gradients :
186
- if gradient is not None :
187
- assign_ops .append (
188
- gradient .assign (
189
- tf .zeros_like (gradient ),
190
- use_locking = self ._use_locking ,
191
- read_value = False ,
192
- )
193
- )
194
-
195
- return tf .group (assign_ops )
196
-
197
194
@property
198
195
def inner_optimizer (self ):
199
196
"""The optimizer that this LossScaleOptimizer is wrapping."""
0 commit comments