
enable verbose by default, missing final check on LR scaling
bm-synth committed Feb 10, 2025
1 parent c4b26a8 commit fef0495
Showing 2 changed files with 42 additions and 27 deletions.
First changed file:
@@ -113,6 +113,7 @@ def is_microbatch_valid(metrics):
microbatches += mbs

# make sure we give the same number of (micro-)batches to each dataloader by trimming the dataset
assert len(microbatches) >= effective_batch_size, "not enough datapoints to create a single sample per dataloader"
microbatches = microbatches[:len(microbatches) - len(microbatches) % effective_batch_size]

# compute the effective batch size for each microbatch.
@@ -129,11 +130,11 @@ def is_microbatch_valid(metrics):
batch_sizes.append(batch_size)
batch_max_seqlens.append(batch_max_seqlen)
microbatch_ids += batch_and_mb_ids
n_tokens_in_batch = sum([m[0] for m in mbs[0]])
assert n_tokens_in_batch <= max_tokens
if verbose:
n_tokens_per_mb = [sum([m[0] for m in mb]) for mb in mbs]
assert all([n <= max_tokens for n in n_tokens_per_mb]), "size of microbatch exceeds max tokens"
logger.info(
f"Batch id {batch_id}, samples {batch_size}, tokens {n_tokens_in_batch} tokens, samples: {dataset_filter_ids}"
f"Batch id {batch_id}, batch_size: {batch_size} sentences, n_tokens per microbatch {n_tokens_per_mb} tokens, sentence ids per microbatch: {dataset_filter_ids}"
)

# return the sample ids of each microbatch, and the batch sizes
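The assert added in this hunk guards the trimming step described in the comment above: the list of packed microbatches is cut back to an exact multiple of the effective batch size, so every data-parallel rank and gradient accumulation step receives the same number of microbatches. A minimal, self-contained sketch of that trim logic with toy values (the `mb*` placeholders are not from the real code):

```python
# Toy illustration of the trim-to-multiple step guarded by the new assert.
# `microbatches` and `effective_batch_size` mirror the names in the diff; values are made up.
microbatches = [["mb0"], ["mb1"], ["mb2"], ["mb3"], ["mb4"], ["mb5"], ["mb6"]]
effective_batch_size = 3  # e.g. data-parallel degree * gradient accumulation steps

assert len(microbatches) >= effective_batch_size, \
    "not enough datapoints to create a single sample per dataloader"

# drop the tail so every dataloader receives the same number of microbatches
microbatches = microbatches[:len(microbatches) - len(microbatches) % effective_batch_size]
print(len(microbatches))  # 6: the 7th microbatch is discarded
```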
@@ -291,18 +292,21 @@ def step(self, epoch=None):
group['lr'] = scale_lr(self.base_batch_size, batch_size, group['lr'], self.lr_scaling_method)

if self.verbose:
logger.info(f"Batch id {self.last_epoch}, unscaled LRs {unscaled_lrs}, scaled LRs {self.get_lr()}")
logger.info(
f"Batch id {self.last_epoch}. Reference: batch_size {self.base_batch_size}, lr {unscaled_lrs}. Scaled: batch_size {batch_size}, lr {self.get_lr()}"
)
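The scaled values logged above come from `scale_lr`, which is not shown in this diff. The following is only a hedged sketch of the scaling rules such a helper conventionally implements (the real function may differ), with a small worked example:

```python
# Hedged sketch of batch-size-based LR scaling; the real `scale_lr` lives in the
# module being patched and may differ from this illustration.
def scale_lr(base_batch_size: int, batch_size: int, base_lr: float,
             method: str = "linear") -> float:
    if method == "linear":  # linear scaling rule: LR proportional to batch size
        return base_lr * batch_size / base_batch_size
    if method == "sqrt":    # square-root scaling, a common alternative (assumption)
        return base_lr * (batch_size / base_batch_size) ** 0.5
    return base_lr          # no scaling: keep the reference LR

# Reference LR 1e-3 tuned for batches of 8 sentences; the current packed batch has 20.
print(scale_lr(8, 20, 1e-3))           # 0.0025
print(scale_lr(8, 20, 1e-3, "sqrt"))   # ~0.00158
```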


def lr_scheduler_for_variable_batch_size(base_batch_size,
batch_sizes,
dataloader,
lr_scheduler_or_optimizer,
lr_scaling_method='linear'):
lr_scaling_method='linear',
verbose=False):
"""
returns a class that provides an LR scheduler that scales learning rate at every
epoch taking into account the batch size of each epoch.
If learning rate is constant, ie no LR scheduler, then the LR will be taken from the
returns a class that provides an LR scheduler that scales the learning rate at every
iteration taking into account the batch size of that iteration.
If learning rate is constant, ie no LR scheduler, then the base LR will be taken from the
constant LR values in the optimizer param groups. Otherwise from the scheduler's LR.
Arguments:
Expand Down Expand Up @@ -332,7 +336,8 @@ def get_lr(self) -> float:
base_batch_size=base_batch_size,
batch_sizes=batch_sizes,
dataloader=dataloader,
lr_scaling_method=lr_scaling_method)
lr_scaling_method=lr_scaling_method,
verbose=verbose)


def get_dataloader_and_lr_scheduler_for_variable_batch_size_deepspeed(dataset,
@@ -472,6 +477,7 @@ def get_dataloader_and_lr_scheduler_for_variable_batch_size(
batch_sizes=batch_sizes,
lr_scaling_method=lr_scaling_method,
lr_scheduler_or_optimizer=lr_scheduler_or_optimizer,
dataloader=dataloader)
dataloader=dataloader,
verbose=verbose)

return dataloader, lr_scheduler, deepspeed_io_kwargs
Second changed file:
@@ -26,7 +26,8 @@ def __init__(self, seq_count, min_seqlen=1, max_seqlen=20, embed_dim=5, seed=0):
self.embed_dim = embed_dim
self.seqs = []
for _ in range(seq_count):
self.seqs.append(torch.ones(data_random.randrange(min_seqlen, max_seqlen), embed_dim))
seqlen = data_random.randrange(min_seqlen, max_seqlen)
self.seqs.append(torch.ones(seqlen, embed_dim))

__len__ = lambda self: len(self.seqs) # noqa
__getitem__ = lambda self, idx: (self.seqs[idx], len(self.seqs[idx])) # noqa
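For orientation, each item of the dataset above is a pair (all-ones tensor of shape [seqlen, embed_dim], seqlen). A quick usage check, meant to be run inside this test script where `TestData` is defined:

```python
# Each sample is (torch.ones(seqlen, embed_dim), seqlen); lengths vary per sample.
ds = TestData(seq_count=5, min_seqlen=3, max_seqlen=10, embed_dim=5, seed=0)
for seq, label in ds:
    assert seq.shape == (label, ds.embed_dim)
    print(list(seq.shape), label)  # e.g. [7, 5] 7
```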
@@ -101,12 +102,15 @@ def to_layers(self):
device = f"cuda:{dist.get_local_rank()}"
assert dist.get_local_rank() < torch.cuda.device_count(), "needs at least 1 GPU per process"

pipeline_num_stages = 0
max_seqlen = 50
dataset = TestData(seq_count=1000, min_seqlen=3, max_seqlen=max_seqlen)
# dummy dataset with 2000 sequences of random lengths between 3 and 10 tokens per sentence.
max_seqlen = 10
dataset = TestData(seq_count=2000, min_seqlen=3, max_seqlen=max_seqlen, seed=42)

model = AttentionHeadAndFeedForward(max_seqlen, dataset.embed_dim).to(device)
loss_fn = lambda x, y: F.mse_loss(x.float(), y.float()) # noqa

# number of pipeline stages. 0 or 1 to disable pipelining, >1 to enable
pipeline_num_stages = 0
if pipeline_num_stages > 1:
model = PipelineModule(layers=model.to_layers(), num_stages=pipeline_num_stages, loss_fn=loss_fn)

@@ -133,7 +137,7 @@ def to_layers(self):
# enables or disables dynamic batching
"enabled": True,
# how many tokens we need to fill a pack of sequences (that will be collated together as a sample)
"max_tokens": 100,
"max_tokens": 50,
# Input/output path used to read or write the length of every sequence.
# Sequence lengths will be loaded from: {metrics_path}/seqlen/seqlen_sample_to_metric.bin and *.idx
# If files dont exist, they'll be computed and saved on the first run, and loaded on subsequent runs.
@@ -145,7 +149,7 @@ def to_layers(self):
# - dataloader: by same order as they come in with the dataloader
# - seqlen: by sequence length (shortest to longest)
# - random: random order using the seed in config['data_efficiency']['seed']
"sentence_picking_order": "dataloader", # "random" / "seqlen" / "dataloader"
"sentence_picking_order": "seqlen", # "random" / "seqlen" / "dataloader"
# minimum number of sequences required to reach `max_tokens`. If sentence pack is smaller, it's discarded.
"min_batch_size": 1,
# maximum number of sequences allowed to reach `max_tokens`. If sentence pack is larger, it's discarded.
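Collected in one place, a hedged sketch of the dynamic batching options excerpted above as they stand after this commit (key names come from the diff; the enclosing config structure, and any value not visible in this excerpt, are assumptions):

```python
# Dynamic batching options from the excerpt above, gathered into a single dict.
# Only keys visible in the diff are shown; everything else is an assumption.
dynamic_batching = {
    "enabled": True,                     # turn dynamic batching on
    "max_tokens": 50,                    # token budget per packed microbatch (was 100 before this commit)
    "sentence_picking_order": "seqlen",  # "random" / "seqlen" / "dataloader"
    "min_batch_size": 1,                 # packs with fewer sentences are discarded
    # "max_batch_size": ...,             # upper bound on sentences per pack (value collapsed in this view)
}
```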
@@ -174,26 +178,31 @@ def to_layers(self):
sample_padding_fn=dataset.sample_padding_fn,
batch_seqlens_fn=dataset.batch_seqlens_fn)
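The `sample_padding_fn` and `batch_seqlens_fn` callables passed above are not shown in this diff; the following is only a hedged guess at the roles their names and arguments suggest, with made-up signatures (the real `TestData` helpers may look different):

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-ins for the dataset helpers referenced above; signatures are guesses.
def sample_padding_fn(sample, size):
    """Pad one (seq, label) sample with zero rows up to `size` tokens."""
    seq, label = sample
    padded = F.pad(seq, (0, 0, 0, size - seq.shape[0]))  # pad the time dimension only
    return padded, label

def batch_seqlens_fn(batch):
    """Return the (unpadded) length of every sequence in a batch of (seq, label) pairs."""
    return [seq.shape[0] for seq, _ in batch]
```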

# train on 2 epochs with 20 iterations per epoch
for epoch in range(2):
lr_scheduler.step(0) # point LR scheduler to first batch
for it in range(10):
data_iter = iter(dataloader) # point data iterator to first batch
# train loop with 3 epochs
n_iterations = len(dataloader) // engine.gradient_accumulation_steps()
for epoch in range(3):
data_iter = iter(dataloader) # point data iterator to beginning of dataset
# lr_scheduler.step(0) # optional: reset dynamic LR scheduler to point to the first batch in dataset
for it in range(n_iterations):
lr_kwargs = {} # optional: epoch argument to pass to engine.lr_scheduler.step() e.g. {"epoch": it}
if pipeline_num_stages > 0:
engine.reset_activation_shape() # reset, as each batch has a diff BxT dimension
loss = engine.train_batch(data_iter=data_iter) # lr_kwargs={"epoch": batch_id}
loss = engine.train_batch(data_iter=data_iter, lr_kwargs=lr_kwargs)
else:
for gas in range(engine.gradient_accumulation_steps()):
seqs, labels = next(data_iter)
n_tokens = (seqs[:, :, 0] != 0).sum().item()
seqs, labels = seqs.to(device), labels.to(device)
outputs = engine(seqs)
loss = loss_fn(outputs, labels)
engine.backward(loss)
engine.step() # lr_kwargs={"epoch": it})
print(
f"- acc step {gas}, dp_rank {dp_rank}: shape {list(seqs.shape)}, n_tokens {n_tokens}, lr {lr_scheduler.get_lr()}"
)
engine.step(lr_kwargs=lr_kwargs)

# optional: output some information about dynamic batching
n_tokens = (seqs[:, :, 0] != 0).sum().item()
lr = lr_scheduler.get_lr()[0]
shape = list(seqs.shape)
print(f"- {epoch}.{it}.{gas}, rank {dp_rank}: shape {shape}, n_tokens {n_tokens}, lr {lr}")
dist.barrier() # optional

if dp_rank == 0:
print(f"epoch {epoch}, iteration {it}, loss {loss.item()}, lrs {lr_scheduler.get_lr()}")
