e2e unit test for multiple datasets with multiple files
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Abhishek-TAMU committed Dec 11, 2024
1 parent ce82af1 commit 3fe7425
Showing 1 changed file with 71 additions and 1 deletion.
72 changes: 71 additions & 1 deletion tests/test_sft_trainer.py
@@ -27,19 +27,27 @@
import pytest
import torch
import transformers
import yaml

# First Party
from build.utils import serialize_args
from scripts.run_inference import TunedCausalLM
from tests.artifacts.predefined_data_configs import (
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
)
from tests.artifacts.testdata import (
EMPTY_DATA,
MALFORMATTED_DATA,
MODEL_NAME,
TWITTER_COMPLAINTS_DATA_ARROW,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
TWITTER_COMPLAINTS_DATA_JSON,
TWITTER_COMPLAINTS_DATA_JSONL,
TWITTER_COMPLAINTS_DATA_PARQUET,
TWITTER_COMPLAINTS_TOKENIZED_ARROW,
TWITTER_COMPLAINTS_TOKENIZED_JSON,
TWITTER_COMPLAINTS_TOKENIZED_JSONL,
TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
)

# Local
@@ -683,6 +691,8 @@ def test_successful_lora_target_modules_default_from_main():
    [
        TWITTER_COMPLAINTS_DATA_JSONL,
        TWITTER_COMPLAINTS_DATA_JSON,
        TWITTER_COMPLAINTS_DATA_PARQUET,
        TWITTER_COMPLAINTS_DATA_ARROW,
    ],
)
def test_run_causallm_ft_and_inference(dataset_path):
@@ -719,7 +729,12 @@ def test_run_causallm_ft_save_with_save_model_dir_save_strategy_no():

@pytest.mark.parametrize(
"dataset_path",
[TWITTER_COMPLAINTS_TOKENIZED_JSONL, TWITTER_COMPLAINTS_TOKENIZED_JSON],
[
TWITTER_COMPLAINTS_TOKENIZED_JSONL,
TWITTER_COMPLAINTS_TOKENIZED_JSON,
TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
TWITTER_COMPLAINTS_TOKENIZED_ARROW,
],
)
def test_run_causallm_ft_pretokenized(dataset_path):
"""Check if we can bootstrap and finetune causallm models using pretokenized data"""
Expand Down Expand Up @@ -754,6 +769,61 @@ def test_run_causallm_ft_pretokenized(dataset_path):
assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference


@pytest.mark.parametrize(
"datafiles, datasetconfigname",
[
(
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
],
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
)
],
)
def test_run_causallm_ft_and_inference_with_multiple_dataset(
datasetconfigname, datafiles
):
"""Check if we can finetune causallm models using multiple datasets with multiple files"""
with tempfile.TemporaryDirectory() as tempdir:
data_formatting_args = copy.deepcopy(DATA_ARGS)

# set training_data_path and response_template to none
data_formatting_args.response_template = None
data_formatting_args.training_data_path = None

# add data_paths in data_config file
with tempfile.NamedTemporaryFile(
"w", delete=False, suffix=".yaml"
) as temp_yaml_file:
with open(datasetconfigname, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
datasets = data["datasets"]
for _, d in enumerate(datasets):
d["data_paths"] = datafiles
yaml.dump(data, temp_yaml_file)
data_formatting_args.data_config_path = temp_yaml_file.name

train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

sft_trainer.train(MODEL_ARGS, data_formatting_args, train_args)

# validate full ft configs
_validate_training(tempdir)
_, checkpoint_path = _get_latest_checkpoint_trainer_state(tempdir)

# Load the model
loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)

# Run inference on the text
output_inference = loaded_model.run(
"### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
)
assert len(output_inference) > 0
assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference


############################# Helper functions #############################
def _test_run_causallm_ft(training_args, model_args, data_args, tempdir):
    train_args = copy.deepcopy(training_args)

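Note: the new test rewrites the predefined sampling data config so that every dataset entry points at the two data files passed in via `datafiles`. A minimal standalone sketch of that rewrite, assuming only what the test itself relies on (a top-level `datasets` list whose entries take a `data_paths` key); the `sampling` values and file names below are hypothetical placeholders, not taken from the repository:

# Minimal sketch of the data_config rewrite performed in the test above.
# Assumption: the YAML has a top-level "datasets" list whose entries accept
# a "data_paths" key; sampling ratios and file names here are hypothetical.
import tempfile

import yaml

base_config = {
    "datasets": [
        {"name": "dataset_1", "sampling": 0.3},
        {"name": "dataset_2", "sampling": 0.7},
    ]
}
datafiles = ["complaints_part_1.jsonl", "complaints_part_2.jsonl"]  # hypothetical paths

with tempfile.NamedTemporaryFile("w", delete=False, suffix=".yaml") as temp_yaml_file:
    for d in base_config["datasets"]:
        d["data_paths"] = datafiles  # each dataset now spans multiple files
    yaml.dump(base_config, temp_yaml_file)
    print("patched data config written to", temp_yaml_file.name)

The parametrized test itself can be run in isolation with `pytest tests/test_sft_trainer.py -k multiple_dataset`.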