"""
Nanotron training script.
Usage:
```
export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations
torchrun --nproc_per_node=8 run_train.py --config-file examples/config_tiny_llama.yaml
```
"""
import argparse

from nanotron import logging
from nanotron.config import (
    PretrainDatasetsArgs,
)
from nanotron.dataloader import (
    clm_process,
    dummy_infinite_data_generator,
    get_datasets,
    get_train_dataloader,
)
from nanotron.logging import log_rank
from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks
from nanotron.trainer import DistributedTrainer
from nanotron.utils import (
    main_rank_first,
)

try:
    from huggingface_hub import __version__ as hf_hub_version
    from transformers import AutoTokenizer
    from transformers import __version__ as tf_version
except ImportError:
    hf_hub_version = None
    tf_version = None

logger = logging.get_logger(__name__)


def get_dataloader(trainer: DistributedTrainer):
    """Returns a dataloader for training."""

    # First, we need to know which ranks to feed the dataloader to
    input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model)

    # Case 1: Dummy data generator
    if trainer.config.data.dataset is None:
        log_rank("Using dummy data generator", logger=logger, level=logging.INFO, rank=0)
        dataloader = dummy_infinite_data_generator(
            micro_batch_size=trainer.micro_batch_size,
            sequence_length=trainer.sequence_length,
            input_pp_rank=input_pp_rank,
            output_pp_rank=output_pp_rank,
            vocab_size=trainer.model_config.vocab_size,
            seed=trainer.config.data.seed,
            parallel_context=trainer.parallel_context,
        )()
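        # NOTE: `dummy_infinite_data_generator` returns a factory, hence the trailing
        # `()` above: calling it yields an endless stream of random token batches,
        # handy for smoke-testing parallelism settings without a real dataset.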

    # Case 2: HuggingFace datasets
    elif isinstance(trainer.config.data.dataset, PretrainDatasetsArgs):
        log_rank("Using `datasets` library", logger=logger, level=logging.INFO, rank=0)
        tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path
        log_rank(
            f"Loading tokenizer from {tokenizer_path} and transformers/hf_hub versions {tf_version, hf_hub_version}",
            logger=logger,
            level=logging.INFO,
            rank=0,
        )

        # We need the first rank to process and cache the dataset; the other ranks then load it from the cache
        with main_rank_first(trainer.parallel_context.world_pg):
            # TODO @nouamanetazi: this may timeout before the first rank finishes processing the dataset. Can we have a ctxmanager to modify the timeout?
            # TODO: generalise to include validation/test splits

            # We load the raw dataset
            raw_dataset = get_datasets(
                hf_dataset_or_datasets=trainer.config.data.dataset.hf_dataset_or_datasets,
                splits=trainer.config.data.dataset.hf_dataset_splits,
            )["train"]

            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.padding_side = "left"
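            # NOTE: Llama-style tokenizers often ship without a pad token, so EOS is
            # reused for padding here; the padding side matters little since the CLM
            # preprocessing below packs tokens into full-length sequences.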

            # We apply the Causal Language Modeling preprocessing
            train_dataset = clm_process(
                raw_dataset=raw_dataset,
                tokenizer=tokenizer,
                text_column_name=trainer.config.data.dataset.text_column_name,
                dataset_processing_num_proc_per_process=trainer.config.data.dataset.dataset_processing_num_proc_per_process,
                dataset_overwrite_cache=trainer.config.data.dataset.dataset_overwrite_cache,
                sequence_length=trainer.sequence_length,
            )
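            # NOTE: `clm_process` tokenizes `text_column_name` and packs the token
            # stream into fixed blocks of `sequence_length` (standard causal-LM
            # preprocessing), so each resulting sample is one full-length sequence.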

            # We load the processed dataset on the ranks requiring it
            dataloader = get_train_dataloader(
                train_dataset=train_dataset,
                sequence_length=trainer.sequence_length,
                parallel_context=trainer.parallel_context,
                input_pp_rank=input_pp_rank,
                output_pp_rank=output_pp_rank,
                micro_batch_size=trainer.micro_batch_size,
                consumed_train_samples=trainer.consumed_train_samples,
                dataloader_num_workers=trainer.config.data.num_loading_workers,
                seed_worker=trainer.config.data.seed,
                dataloader_drop_last=True,
            )
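            # NOTE: `consumed_train_samples` lets the dataloader skip samples already
            # seen before a checkpoint restart; `dataloader_drop_last=True` keeps
            # every micro-batch at exactly `micro_batch_size`.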

            # Check that we have enough samples for train_steps
            total_tokens_dataset = len(dataloader.dataset) * trainer.sequence_length
            num_tokens_needed_for_training = (
                (trainer.config.tokens.train_steps - trainer.start_iteration_step)
                * trainer.global_batch_size
                * trainer.sequence_length
            )
            assert num_tokens_needed_for_training <= total_tokens_dataset, (
                f"Dataset is too small for the requested train_steps ({total_tokens_dataset} tokens available < {num_tokens_needed_for_training} tokens needed). "
                f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.start_iteration_step}"
            )
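            # Example (illustrative numbers): with train_steps=1000,
            # start_iteration_step=0, global_batch_size=512 and sequence_length=2048,
            # training needs 1000 * 512 * 2048 = 1,048,576,000 (~1.05B) tokens.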

    else:
        raise ValueError(f"Unhandled case of `trainer.config.data.dataset`. Got: {trainer.config.data.dataset}")

    return dataloader


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    config_file = args.config_file

    # Load trainer and data
    trainer = DistributedTrainer(config_file)
    dataloader = get_dataloader(trainer)

    # Train
    trainer.train(dataloader)
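
# A minimal smoke test, assuming a single-GPU machine (the config path below is the
# example config referenced in the usage docstring; adjust to taste):
#
#   CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=1 run_train.py \
#       --config-file examples/config_tiny_llama.yaml
#
# Leaving `data.dataset` unset in the YAML config falls back to Case 1's dummy data
# generator, which exercises the full training loop without downloading anything.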