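"""Fine-tune a causal language model on a self-instruct dataset.

Loads a base model and tokenizer (overridable from the command line), builds
train/validation splits from a CSV of instruction records, and trains with the
Hugging Face Trainer. With --lora (the default), the model is loaded in 8-bit
and wrapped with PEFT LoRA adapters so that only the adapter weights train.
"""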
import argparse
import pathlib
from typing import Tuple, Union

import pandas as pd
import torch
from torch.utils.data import random_split
from transformers import Trainer, TrainingArguments
from peft import get_peft_model, prepare_model_for_int8_training

from self_instruct.cfg import BasicConfig, TrainConfig
from self_instruct.dataset import InstructDataset
from self_instruct.utils import (
    SavePeftModelCallback,
    create_outpath,
    get_lora_config,
    load_model,
    load_tokenizer,
)

Pathable = Union[str, pathlib.Path]
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model', nargs="?", default="")
    parser.add_argument('-t', '--tokenizer', nargs="?", default="")
    # Boolean flag (Python 3.9+): LoRA training is on by default; pass
    # --no-lora to fine-tune the full model instead.
    parser.add_argument('-l', '--lora', action=argparse.BooleanOptionalAction, default=True)
    args = parser.parse_args()
    return args
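# Example invocations (model and tokenizer names are illustrative):
#   python main.py --model bigscience/bloom-560m --tokenizer bigscience/bloom-560m
#   python main.py --no-lora   # full-precision fine-tuning without adapters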
def prepare_dataset(input_path: Pathable, **kwargs) -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]:
    data = pd.read_csv(input_path)
    dict_data = data.to_dict(orient="records")
    dataset = InstructDataset(
        data=dict_data,
        tokenizer=kwargs['tokenizer'],
        max_length=kwargs['config'].max_length,
    )
    return split_data(dataset=dataset, ratio=kwargs['config'].split_ratio)
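# Sketch of a direct call (path and names are illustrative; the exact CSV
# schema is defined by InstructDataset, which tokenizes each record up to
# config.max_length tokens):
#   train_ds, val_ds = prepare_dataset("data/instructions.csv", tokenizer=tok, config=cfg)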
def split_data(dataset: torch.utils.data.Dataset, ratio: float = 0.9) -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]:
    train_size: int = int(ratio * len(dataset))
    val_size: int = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    return train_dataset, val_dataset
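# With the default ratio of 0.9, a 1,000-example dataset yields 900 training
# and 100 validation examples.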
def main():
    config = BasicConfig()
    trconfig = TrainConfig()
    args = get_args()
    # Fall back to the configured defaults when no CLI override is given.
    model_name = args.model or config.model_name
    tokenizer_name = args.tokenizer or config.tokenizer_name
    tokenizer = load_tokenizer(tokenizer_name)
    train_dataset, val_dataset = prepare_dataset(input_path=config.dataset, tokenizer=tokenizer, config=config)
    if args.lora:
        # Load the base model in 8-bit and wrap it with LoRA adapters so that
        # only the low-rank adapter weights are trained.
        lora_config = get_lora_config()
        lora_config.inference_mode = False
        model = load_model(model_name, load_in_8bit=True)
        model = prepare_model_for_int8_training(model)
        model = get_peft_model(model, lora_config)
        # Caching is incompatible with the gradient checkpointing enabled by
        # prepare_model_for_int8_training.
        model.config.use_cache = False
    else:
        model = load_model(model_name, load_in_8bit=False)
    # Create the output path for checkpoints.
    output_path = create_outpath(model_name=model_name)
    training_args = TrainingArguments(
        output_dir=output_path,
        num_train_epochs=trconfig.n_epochs,
        logging_steps=trconfig.logging_steps,
        per_device_train_batch_size=trconfig.train_batch_size,
        per_device_eval_batch_size=trconfig.eval_batch_size,
        fp16=trconfig.fp16,
        warmup_steps=trconfig.warmup_steps,
        weight_decay=trconfig.weight_decay,
        learning_rate=trconfig.learning_rate,
        logging_dir=trconfig.logging_dir,
        save_total_limit=trconfig.save_total_limit,
        report_to="wandb",
    )
    # Define the trainer. The collator stacks the (input_ids, attention_mask)
    # tuples produced by InstructDataset and reuses input_ids as labels for
    # causal language modeling.
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=lambda data: {
            'input_ids': torch.stack([f[0] for f in data]),
            'attention_mask': torch.stack([f[1] for f in data]),
            'labels': torch.stack([f[0] for f in data]),
        },
        callbacks=[SavePeftModelCallback(output_path=output_path)],
    )
    trainer.train(resume_from_checkpoint='checkpoint' in model_name)
if __name__ == "__main__":
    main()