Commit 26b73c2

align default custom black/white list for dygraph and static graph (#9340)
1 parent ec25cb8

2 files changed: 15 additions, 1 deletion

llm/auto_parallel/llama/README.md

Lines changed: 3 additions & 0 deletions

````diff
@@ -5,6 +5,9 @@
 - Unified dynamic/static auto-parallel model definition [modeling_auto.py](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/llama/modeling_auto.py). It currently focuses on pre-training, covering both dynamic-graph and dynamic-to-static training, and will be extended to SFT and other workflows in the future.
 
 ## 2. Pre-training preparation
+
+Install the latest Paddle (the nightly build is recommended); see the [Paddle website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/develop/install/pip/linux-pip.html) for installation instructions.
+
 Download the pre-processed data and extract it into the `./data` directory:
 ```shell
 # Download the llama model data
````

paddlenlp/trainer/auto_trainer.py

Lines changed: 12 additions & 1 deletion

```diff
@@ -115,6 +115,7 @@ def _wrap_for_dist_loader(self, train_dataloader):
         return dist_loader
 
     def _wrap_for_auto(self, model, train_dataloader):
+        logger.info("Wrapping model for auto paralle")
         dist_loader = self._wrap_for_dist_loader(train_dataloader)
 
         if ShardingOption.SHARD_OP in self.args.sharding:
@@ -135,6 +136,15 @@ def _wrap_for_auto(self, model, train_dataloader):
         if self.args.to_static:
             unified_strategy = dist.Strategy()
             unified_strategy._from_legacy_strategy(self.args.strategy)
+
+            # same logic as autocast_smart_context_manager() in trainer.py
+            if self.enable_autocast_context_manager:
+                unified_strategy.amp.custom_black_list.extend(["reduce_sum", "c_softmax_with_cross_entropy"])
+                if self.args.fp16_opt_level == "O2":
+                    print("custom_white_list", unified_strategy.amp.custom_white_list, flush=1)
+                    unified_strategy.amp.custom_white_list.extend(["lookup_table", "lookup_table_v2"])
+                    print("custom_white_list", unified_strategy.amp.custom_white_list, flush=1)
+
             # dist.to_static() obtains the input spec information through next(dataloader), but this has side effects
             # on the passed-in dataloader, altering the state of the sampler of the dataloader. In some cases, once
             # the state of the sampler is changed, it cannot be reverted. Therefore, a temporary dataloader is
@@ -156,9 +166,10 @@ def _wrap_amp_model(self, args, model):
             master_grad=self.args.amp_master_grad,
             excluded_layers=QuantizationLinear,
         )
+        self.enable_autocast_context_manager = True
+
         if args.to_static:
             return
-        self.enable_autocast_context_manager = True
         self.do_grad_scaling = True if self.args.fp16 else False
         self.scaler = dist.shard_scaler(paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss))
 
```
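
For context, a minimal sketch of the dynamic-graph behaviour this change mirrors. In dygraph mode, `autocast_smart_context_manager()` in `trainer.py` passes the same custom black/white lists to `paddle.amp.auto_cast`, and this commit copies them into the static-graph `dist.Strategy` so both modes cast the same ops. The helper `build_amp_lists` below is illustrative only (not a PaddleNLP API), and the exact dygraph wiring is inferred from the in-diff comment:

```python
import paddle


def build_amp_lists(fp16_opt_level: str):
    """Illustrative helper (not a PaddleNLP API): the custom AMP lists that
    dygraph autocast and the static-graph dist.Strategy now share."""
    # Ops forced to stay in FP32 under AMP.
    custom_black_list = ["reduce_sum", "c_softmax_with_cross_entropy"]
    # Under O2, embedding lookups are explicitly allowed to run in low precision.
    custom_white_list = []
    if fp16_opt_level == "O2":
        custom_white_list.extend(["lookup_table", "lookup_table_v2"])
    return custom_black_list, custom_white_list


black_list, white_list = build_amp_lists("O2")

# Dynamic-graph side: the lists are handed to paddle.amp.auto_cast.
with paddle.amp.auto_cast(
    custom_black_list=black_list,
    custom_white_list=white_list,
    level="O2",
):
    pass  # the model forward pass would run here
```

Because `self.enable_autocast_context_manager` is now set before the `to_static` early return, the static-graph path also sees the flag and appends the same lists to `unified_strategy.amp` before `dist.to_static()` is called.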
