Merged
12 changes: 12 additions & 0 deletions QEfficient/cloud/finetune.py
@@ -103,6 +103,18 @@ def main(**kwargs):
 # print the datatype of the model parameters
 # print(get_parameter_dtypes(model))

+# Note: this must be called before PeftModel.from_pretrained or get_peft_model, because
+# both check model.is_gradient_checkpointing (set to True by the call below) and use it to
+# attach gradient-checkpointing hooks to the input embeddings. Without that ordering,
+# training fails with "No inf checks were recorded for this optimizer."
+# Enable gradient checkpointing
+if train_config.gradient_checkpointing:
+    # Note: the attribute and method below are only available on HuggingFace Transformers models.
+    if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing:
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
+    else:
+        raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.")
+
 if train_config.use_peft:
     # Load the pre-trained peft model checkpoint and setup its configuration
     if train_config.from_peft_checkpoint:
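The ordering described in the note above matters in practice, so here is a minimal standalone sketch of the sequence this change enforces (the model name and LoRA settings are illustrative assumptions, not the repo's defaults):

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative model choice

# 1) Enable activation checkpointing first. preserve_rng_state=False skips saving RNG
#    state per checkpointed segment, which is cheaper when dropout determinism across
#    the recomputation is not required.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})

# 2) Only then wrap with PEFT, so it sees model.is_gradient_checkpointing == True and
#    registers the input-embedding hooks that keep gradients flowing into the adapters.
model = get_peft_model(model, LoraConfig(r=8, target_modules=["c_attn"]))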
1 change: 1 addition & 0 deletions QEfficient/finetune/configs/training.py
@@ -15,6 +15,7 @@ class train_config:
     batch_size_training: int = 1
     context_length: int = None
     gradient_accumulation_steps: int = 4
+    gradient_checkpointing: bool = False
     num_epochs: int = 1
     max_train_step: int = 0
     max_eval_step: int = 0
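A quick sketch of opting in through the new field; how CLI kwargs reach train_config at runtime is outside this diff, so the direct assignment below is illustrative only:

from QEfficient.finetune.configs.training import train_config

cfg = train_config()
cfg.gradient_checkpointing = True  # off by default; trades recompute time for lower activation memory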
8 changes: 7 additions & 1 deletion QEfficient/finetune/dataset/dataset_config.py
@@ -21,6 +21,9 @@
 from QEfficient.finetune.dataset.samsum_dataset import (
     get_preprocessed_samsum as get_samsum_dataset,
 )
+from QEfficient.finetune.dataset.samsum_dataset import (
+    get_samsum_collate_fn,
+)

 DATASET_PREPROC = {
     "alpaca_dataset": partial(get_alpaca_dataset),
@@ -29,4 +32,7 @@
"gsm8k_dataset": get_gsm8k_dataset,
"custom_dataset": get_custom_dataset,
}
DATALOADER_COLLATE_FUNC = {"custom_dataset": get_data_collator}
DATALOADER_COLLATE_FUNC = {
"custom_dataset": get_data_collator,
"samsum_dataset": get_samsum_collate_fn,
}
21 changes: 21 additions & 0 deletions QEfficient/finetune/dataset/samsum_dataset.py
@@ -6,6 +6,8 @@
 # -----------------------------------------------------------------------------

 import datasets
+import torch
+from torch.nn.utils.rnn import pad_sequence


 def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None):
@@ -46,3 +48,22 @@ def tokenize_add_label(sample):
     dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))

     return dataset
+
+
+def collate_fn(batch):
+    eos_token = batch[0]["input_ids"][-1]
+
+    input_ids = pad_sequence(
+        [torch.tensor(b["input_ids"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=eos_token
+    )
+    attn_mask = pad_sequence(
+        [torch.tensor(b["attention_mask"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=0
+    )
+    labels = pad_sequence(
+        [torch.tensor(b["labels"], dtype=torch.long) for b in batch], batch_first=True, padding_value=eos_token
+    )
+    return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels}
+
+
+def get_samsum_collate_fn(dataset_processer, dataset_config):
+    return collate_fn
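collate_fn pads input_ids and labels with the EOS id taken from the first sample (tokenize_add_label appends EOS, so the last input id is EOS) and pads attention_mask with zeros; because labels are padded with EOS rather than -100, the padded positions still contribute to the loss unless they are masked downstream. A small sanity check with made-up token ids, assuming 2 is the EOS id:

from QEfficient.finetune.dataset.samsum_dataset import collate_fn

batch = [
    {"input_ids": [5, 6, 2], "attention_mask": [1, 1, 1], "labels": [5, 6, 2]},
    {"input_ids": [7, 2], "attention_mask": [1, 1], "labels": [7, 2]},
]
out = collate_fn(batch)
print(out["input_ids"])       # tensor([[5, 6, 2], [7, 2, 2]], dtype=torch.int32) -- padded with EOS (2)
print(out["attention_mask"])  # tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.int32) -- pads masked out
print(out["labels"])          # tensor([[5, 6, 2], [7, 2, 2]]) -- EOS-padded, not -100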
7 changes: 1 addition & 6 deletions QEfficient/finetune/utils/train_utils.py
@@ -178,10 +178,7 @@ def train(
     # adjust atol & rtol as required
     atol=1e-1,
     use_ref_output_on_mismatch=True,
-    # report all mismatches
-    max_failures=None,
-    # generate unittest for each op once
-    repeat_same_op=True,
+    filter_config=qaic_debug.DispatchFilterConfig.default(device),
     dump_root_dir=train_config.dump_root_dir + str(step),
 ) as verifier:
     loss = model(**batch).loss  # Forward call
@@ -297,8 +294,6 @@ def train(
 eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(
     model, train_config, eval_dataloader, local_rank, tokenizer, device
 )
-dist.barrier()
-dist.all_reduce(eval_epoch_loss, op=dist.ReduceOp.SUM)
 if local_rank == 0:
     tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
