@@ -115,6 +115,7 @@ def _wrap_for_dist_loader(self, train_dataloader):
         return dist_loader

     def _wrap_for_auto(self, model, train_dataloader):
+        logger.info("Wrapping model for auto parallel")
         dist_loader = self._wrap_for_dist_loader(train_dataloader)

         if ShardingOption.SHARD_OP in self.args.sharding:
@@ -135,6 +136,15 @@ def _wrap_for_auto(self, model, train_dataloader):
         if self.args.to_static:
             unified_strategy = dist.Strategy()
             unified_strategy._from_legacy_strategy(self.args.strategy)
+
+            # same logic as autocast_smart_context_manager() in trainer.py
+            if self.enable_autocast_context_manager:
+                unified_strategy.amp.custom_black_list.extend(["reduce_sum", "c_softmax_with_cross_entropy"])
+                if self.args.fp16_opt_level == "O2":
+                    print("custom_white_list", unified_strategy.amp.custom_white_list, flush=True)
+                    unified_strategy.amp.custom_white_list.extend(["lookup_table", "lookup_table_v2"])
+                    print("custom_white_list", unified_strategy.amp.custom_white_list, flush=True)
+
             # dist.to_static() obtains the input spec information through next(dataloader), but this has side effects
             # on the passed-in dataloader, altering the state of the sampler of the dataloader. In some cases, once
             # the state of the sampler is changed, it cannot be reverted. Therefore, a temporary dataloader is
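For reference, the custom black/white lists extended in the hunk above mirror what the dynamic-graph path does when it opens an autocast scope. Below is a minimal sketch of that equivalent logic, assuming Paddle's paddle.amp.auto_cast API; args, model, and batch are placeholder names for illustration, not objects taken from this diff.

    import paddle

    # Illustrative stand-in for the trainer's dynamic-graph autocast logic.
    custom_black_list = ["reduce_sum", "c_softmax_with_cross_entropy"]
    custom_white_list = []
    if args.fp16_opt_level == "O2":
        # Under O2, keep embedding lookups on the low-precision white list,
        # matching the ops added to unified_strategy.amp above.
        custom_white_list.extend(["lookup_table", "lookup_table_v2"])

    with paddle.amp.auto_cast(
        enable=True,
        custom_white_list=custom_white_list,
        custom_black_list=custom_black_list,
        level=args.fp16_opt_level,
    ):
        loss = model(**batch)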
@@ -156,9 +166,10 @@ def _wrap_amp_model(self, args, model):
             master_grad=self.args.amp_master_grad,
             excluded_layers=QuantizationLinear,
         )
+        self.enable_autocast_context_manager = True
+
         if args.to_static:
             return
-        self.enable_autocast_context_manager = True
         self.do_grad_scaling = True if self.args.fp16 else False
         self.scaler = dist.shard_scaler(paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss))

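The do_grad_scaling flag and the scaler set at the end of _wrap_amp_model follow Paddle's standard dynamic loss-scaling pattern; the diff additionally wraps the scaler with dist.shard_scaler for auto parallel, which the sketch below omits. model, batch, and optimizer are placeholder names, and the actual trainer wires this up inside its own training step.

    import paddle

    scaler = paddle.amp.GradScaler(init_loss_scaling=2.0**15)

    with paddle.amp.auto_cast(enable=True):
        loss = model(**batch)

    scaled = scaler.scale(loss)   # scale the loss so fp16 gradients do not underflow
    scaled.backward()
    scaler.step(optimizer)        # unscale gradients and run the optimizer step
    scaler.update()               # adjust the loss scaling factor for the next step
    optimizer.clear_grad()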