train_model2.py
"""
@author: akash
"""
from transformers import GPT2Config, GPT2LMHeadModel, TrainingArguments, Trainer, EarlyStoppingCallback
from datasets import Dataset, load_dataset
import polars as pl
import numpy as np
import pyarrow
import argparse
import torch
### COMMAND-LINE ARGUMENTS
parser = argparse.ArgumentParser()
parser.add_argument(
    '--percent_sample',
    type=float,
    help="Fraction of the full training/validation sets that was sampled (e.g. 0.1 selects the 10 percent parquet files)",
    required=True
)
parser.add_argument(
    '--dataset_size',
    type=int,
    help="Number of rows in the sampled training set (used to compute max_steps for the streamed dataset)",
    required=True
)
args = parser.parse_args()
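# Example invocation (illustrative only; paths and values are placeholders, GPU count matches num_gpus below):
#   torchrun --nproc_per_node=2 train_model2.py --percent_sample 0.1 --dataset_size 1000000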
### CONFIGURING GPT-2 VARIANT
config = GPT2Config(
    vocab_size=64,      # small fixed token alphabet
    n_positions=64,     # maximum sequence length
    n_ctx=64,
    n_embd=768,         # hidden size
    n_layer=8,          # number of transformer blocks
    n_head=16,          # attention heads per block
    resid_pdrop=0.1,
    embd_pdrop=0.1,
    attn_pdrop=0.1      # dropout rates
)
model = GPT2LMHeadModel(config)
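# With n_embd=768 and n_layer=8 this is roughly a ~57M-parameter decoder (rough estimate).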
### LOADING TRAINING AND VALIDATION SET
# Stream the (potentially large) training split instead of loading it fully into memory.
train_small_dataset = load_dataset(
    "parquet",
    data_files=f"./data/sampled/training_final_8x8_{int(args.percent_sample * 100)}_percent.parquet",
    split="train",
    streaming=True
)
# The smaller validation split is materialised in memory via polars -> Arrow -> datasets.Dataset.
val_small_dataset = pl.read_parquet(f"./data/sampled/val_final_8x8_{int(args.percent_sample * 100)}_percent.parquet")
val_small_dataset = Dataset(val_small_dataset.to_arrow())
print(f"Finished setting up the {int(args.percent_sample * 100)}% sampled training (streamed) and validation datasets")
### TRAINING BLOCK
num_gpus = 2
batch_size = 128
training_epochs = 1
num_gradient_accumulation = 1
# With a streamed (iterable) training set the Trainer cannot infer an epoch length,
# so the total number of optimisation steps must be estimated explicitly.
max_steps = (args.dataset_size // (num_gpus * batch_size)) * training_epochs
max_steps //= num_gradient_accumulation
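# e.g. (illustrative numbers) dataset_size=1_000_000, 2 GPUs, batch 128, no accumulation -> 1_000_000 // 256 = 3906 steps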
training_args = TrainingArguments(
    output_dir=f"./trained{int(args.percent_sample * 100)}/results/",
    overwrite_output_dir=True,
    eval_strategy="steps",
    num_train_epochs=training_epochs,
    max_steps=max_steps,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=num_gradient_accumulation,
    # metric_for_best_model="eval_loss",
    # greater_is_better=False,
    # learning_rate=5e-4,
    learning_rate=1e-4,
    warmup_steps=500,
    lr_scheduler_type="cosine",
    weight_decay=0.1,
    eval_steps=1000,
    save_steps=1000,
    logging_dir=f"./trained{int(args.percent_sample * 100)}/logs",
    logging_steps=100,
    fp16=True,
    ddp_find_unused_parameters=False,
    # dataloader_num_workers=1,
    # dataloader_pin_memory=True,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_small_dataset,
    eval_dataset=val_small_dataset,
    # callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
)
trainer.train()
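# Checkpoints are written to output_dir every `save_steps` steps; if interrupted, training can be
# resumed with trainer.train(resume_from_checkpoint=True).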