finetune_moss_for_training.py
"""
An example of training Moss with CoLLie (using the LOMO optimizer with ZeRO-3 enabled)
"""
import sys
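# Make the parent directory importable so the collie package is found
# when the script is run from inside the repository (e.g. its examples folder)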
sys.path.append('..')
from transformers import AutoTokenizer
from collie.config import CollieConfig
from collie.data import CollieDatasetForTraining
from collie.optim.lomo import Lomo
from collie.controller.trainer import Trainer
from collie.controller.evaluator import EvaluatorForPerplexity, EvaluatorForGeneration
from collie.models.moss_moon import Moss003MoonForCausalLM
from collie.utils.monitor import StepTimeMonitor, TGSMonitor, MemoryMonitor, LossMonitor, EvalMonitor
from collie.metrics import DecodeMetric, PPLMetric
from collie.module import GPTLMLoss
# 1. Set paths
# 1.1 Path to the pretrained model
pretrained_model = "fnlp/moss-moon-003-sft"
# 2. Set up the configuration
# 2.1 Load the configuration
config = CollieConfig.from_pretrained(pretrained_model, trust_remote_code=True)
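# Tensor-, data-, and pipeline-parallel degrees; all 1 here, so model
# partitioning is handled by ZeRO-3 (configured below)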
config.tp_size = 1
config.dp_size = 1
config.pp_size = 1
config.train_epochs = 1
config.eval_per_n_steps = 0
config.eval_per_n_epochs = 1
config.train_micro_batch_size = 16
config.eval_batch_size = 1
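# DeepSpeed settings: fp16 mixed precision and ZeRO stage 3 (parameters,
# gradients and optimizer states partitioned across ranks) with optimizer
# state offloaded to CPU; the two zero_* flags let ZeRO run with a custom
# (non-DeepSpeed) optimizer such as LOMO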
config.ds_config = {
"fp16": {
"enabled": True
},
"zero_allow_untested_optimizer": True,
"zero_force_ds_cpu_optimizer": False,
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": False
}
}
}
# 3. Set up the tokenizer
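# trust_remote_code is required because the MOSS tokenizer implementation
# lives in the Hub repository rather than in transformers itself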
tokenizer = AutoTokenizer.from_pretrained("fnlp/moss-moon-003-sft", trust_remote_code=True)
# 4. Load the dataset
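# A toy dataset of 10,000 identical input/output pairs;
# CollieDatasetForTraining tokenizes the raw dicts into training samples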
train_dataset = [
{
'input': 'Collie is a python package for ',
'output': 'finetuning large language models.'
} for _ in range(10000)
]
train_dataset = CollieDatasetForTraining(train_dataset, tokenizer)
eval_dataset = train_dataset[:32]
# 5. Load the pretrained model
model = Moss003MoonForCausalLM.from_pretrained(pretrained_model, config=config)
# 6. Set up the optimizer
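# LOMO (LOw-Memory Optimization) fuses gradient computation with the
# parameter update so full gradients and optimizer states need not be stored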
optimizer = Lomo(
model,
lr = 0.001,
clip_grad_norm = 5.0
)
# 7. Add monitors
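# Monitors log step time, throughput (TGS, tokens per GPU per second),
# GPU memory usage, training loss, and evaluation results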
monitors = [
StepTimeMonitor(config),
TGSMonitor(config),
MemoryMonitor(config),
LossMonitor(config),
EvalMonitor(config)
]
# 8. Add evaluators
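# Two evaluators share eval_dataset: one reports perplexity (PPLMetric),
# the other decodes and logs generations (DecodeMetric)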
evaluator_ppl = EvaluatorForPerplexity(
model = model,
config = config,
dataset = eval_dataset,
monitors = [
EvalMonitor(config)
],
metrics = {
'ppl': PPLMetric()
}
)
evaluator_decode = EvaluatorForGeneration(
model = model,
config = config,
tokenizer = tokenizer,
dataset = eval_dataset,
monitors = [
EvalMonitor(config)
],
metrics = {
'decode': DecodeMetric()
}
)
# 9. Instantiate the trainer
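# GPTLMLoss(-100): the argument is the ignore index, so label positions
# set to -100 (e.g. padding) do not contribute to the loss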
trainer = Trainer(
model = model,
config = config,
loss_fn = GPTLMLoss(-100),
optimizer = optimizer,
train_dataset = train_dataset,
monitors = monitors,
evaluators = [evaluator_ppl, evaluator_decode]
)
# 10. Train / evaluate
trainer.train()
# Launch command: CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --rdzv_backend=c10d --rdzv_endpoint=localhost:29402 --nnodes=1 --nproc_per_node=4 finetune_moss_for_training.py
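# i.e. 4 processes on a single node (one per visible GPU), using a c10d
# rendezvous at localhost:29402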