diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 69ccbf4fa..6f8047c1a 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -27,19 +27,27 @@ import pytest import torch import transformers +import yaml # First Party from build.utils import serialize_args from scripts.run_inference import TunedCausalLM +from tests.artifacts.predefined_data_configs import ( + DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, +) from tests.artifacts.testdata import ( EMPTY_DATA, MALFORMATTED_DATA, MODEL_NAME, + TWITTER_COMPLAINTS_DATA_ARROW, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_JSONL, + TWITTER_COMPLAINTS_DATA_PARQUET, + TWITTER_COMPLAINTS_TOKENIZED_ARROW, TWITTER_COMPLAINTS_TOKENIZED_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSONL, + TWITTER_COMPLAINTS_TOKENIZED_PARQUET, ) # Local @@ -683,6 +691,8 @@ def test_successful_lora_target_modules_default_from_main(): [ TWITTER_COMPLAINTS_DATA_JSONL, TWITTER_COMPLAINTS_DATA_JSON, + TWITTER_COMPLAINTS_DATA_PARQUET, + TWITTER_COMPLAINTS_DATA_ARROW, ], ) def test_run_causallm_ft_and_inference(dataset_path): @@ -719,7 +729,12 @@ def test_run_causallm_ft_save_with_save_model_dir_save_strategy_no(): @pytest.mark.parametrize( "dataset_path", - [TWITTER_COMPLAINTS_TOKENIZED_JSONL, TWITTER_COMPLAINTS_TOKENIZED_JSON], + [ + TWITTER_COMPLAINTS_TOKENIZED_JSONL, + TWITTER_COMPLAINTS_TOKENIZED_JSON, + TWITTER_COMPLAINTS_TOKENIZED_PARQUET, + TWITTER_COMPLAINTS_TOKENIZED_ARROW, + ], ) def test_run_causallm_ft_pretokenized(dataset_path): """Check if we can bootstrap and finetune causallm models using pretokenized data""" @@ -754,6 +769,61 @@ def test_run_causallm_ft_pretokenized(dataset_path): assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference +@pytest.mark.parametrize( + "datafiles, datasetconfigname", + [ + ( + [ + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + ], + DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, + ) + ], +) +def test_run_causallm_ft_and_inference_with_multiple_dataset( + datasetconfigname, datafiles +): + """Check if we can finetune causallm models using multiple datasets with multiple files""" + with tempfile.TemporaryDirectory() as tempdir: + data_formatting_args = copy.deepcopy(DATA_ARGS) + + # set training_data_path and response_template to none + data_formatting_args.response_template = None + data_formatting_args.training_data_path = None + + # add data_paths in data_config file + with tempfile.NamedTemporaryFile( + "w", delete=False, suffix=".yaml" + ) as temp_yaml_file: + with open(datasetconfigname, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + datasets = data["datasets"] + for _, d in enumerate(datasets): + d["data_paths"] = datafiles + yaml.dump(data, temp_yaml_file) + data_formatting_args.data_config_path = temp_yaml_file.name + + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + sft_trainer.train(MODEL_ARGS, data_formatting_args, train_args) + + # validate full ft configs + _validate_training(tempdir) + _, checkpoint_path = _get_latest_checkpoint_trainer_state(tempdir) + + # Load the model + loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME) + + # Run inference on the text + output_inference = loaded_model.run( + "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50 + ) + assert len(output_inference) > 0 + assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference + + ############################# Helper functions ############################# def _test_run_causallm_ft(training_args, model_args, data_args, tempdir): train_args = copy.deepcopy(training_args)