From 4441948a61721a3dffce454f9f0d40158cc4eb67 Mon Sep 17 00:00:00 2001
From: Abhishek Maurya <124327945+Abhishek-TAMU@users.noreply.github.com>
Date: Fri, 13 Dec 2024 11:38:51 -0500
Subject: [PATCH] test: Add unit tests to test multiple files in single/multiple datasets (#412)

* test: Add unit tests to test multiple files in single/multiple datasets

Signed-off-by: Abhishek

* e2e testing unit test for multiple datasets with multiple files

Signed-off-by: Abhishek

* test: multiple datasets with multiple datafiles column names

Signed-off-by: Will Johnson

* PR changes

Signed-off-by: Abhishek

* PR Changes

Signed-off-by: Abhishek

* fix: fmt

Signed-off-by: Abhishek

* Merge test_process_dataconfig_multiple_files_varied_data_formats

Signed-off-by: Abhishek

---------

Signed-off-by: Abhishek
Signed-off-by: Will Johnson
Co-authored-by: Will Johnson
---
 tests/artifacts/testdata/__init__.py | 32 +-
 .../twitter_complaints_input_output.arrow | Bin
 .../twitter_complaints_small.arrow | Bin
 ..._tokenized_with_maykeye_tinyllama_v0.arrow | Bin
 .../testdata/{ => json}/empty_data.json | 0
 .../{ => json}/malformatted_data.json | 0
 .../twitter_complaints_input_output.json | 0
 .../{ => json}/twitter_complaints_small.json | 0
 ...s_tokenized_with_maykeye_tinyllama_v0.json | 0
 .../twitter_complaints_input_output.jsonl | 0
 .../twitter_complaints_small.jsonl | 0
 ..._tokenized_with_maykeye_tinyllama_v0.jsonl | 0
 tests/data/test_data_preprocessing_utils.py | 310 +++++++++++++++---
 tests/test_sft_trainer.py | 96 +++++-
 tuning/data/data_handlers.py | 2 +
 15 files changed, 380 insertions(+), 60 deletions(-)
 rename tests/artifacts/testdata/{ => arrow}/twitter_complaints_input_output.arrow (100%)
 rename tests/artifacts/testdata/{ => arrow}/twitter_complaints_small.arrow (100%)
 rename tests/artifacts/testdata/{ => arrow}/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow (100%)
 rename tests/artifacts/testdata/{ => json}/empty_data.json (100%)
 rename tests/artifacts/testdata/{ => json}/malformatted_data.json (100%)
 rename tests/artifacts/testdata/{ => json}/twitter_complaints_input_output.json (100%)
 rename tests/artifacts/testdata/{ => json}/twitter_complaints_small.json (100%)
 rename tests/artifacts/testdata/{ => json}/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json (100%)
 rename tests/artifacts/testdata/{ => jsonl}/twitter_complaints_input_output.jsonl (100%)
 rename tests/artifacts/testdata/{ => jsonl}/twitter_complaints_small.jsonl (100%)
 rename tests/artifacts/testdata/{ => jsonl}/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl (100%)

diff --git a/tests/artifacts/testdata/__init__.py b/tests/artifacts/testdata/__init__.py
index 39895f6f1..762f88ab9 100644
--- a/tests/artifacts/testdata/__init__.py
+++ b/tests/artifacts/testdata/__init__.py
@@ -19,37 +19,47 @@
 ### Constants used for data
 DATA_DIR = os.path.join(os.path.dirname(__file__))
+JSON_DATA_DIR = os.path.join(os.path.dirname(__file__), "json")
+JSONL_DATA_DIR = os.path.join(os.path.dirname(__file__), "jsonl")
+ARROW_DATA_DIR = os.path.join(os.path.dirname(__file__), "arrow")
 PARQUET_DATA_DIR = os.path.join(os.path.dirname(__file__), "parquet")
-TWITTER_COMPLAINTS_DATA_JSON = os.path.join(DATA_DIR, "twitter_complaints_small.json")
-TWITTER_COMPLAINTS_DATA_JSONL = os.path.join(DATA_DIR, "twitter_complaints_small.jsonl")
-TWITTER_COMPLAINTS_DATA_ARROW = os.path.join(DATA_DIR, "twitter_complaints_small.arrow")
+
+TWITTER_COMPLAINTS_DATA_JSON = os.path.join(
+    JSON_DATA_DIR, "twitter_complaints_small.json"
+)
+TWITTER_COMPLAINTS_DATA_JSONL = os.path.join(
+    JSONL_DATA_DIR, "twitter_complaints_small.jsonl"
+)
+TWITTER_COMPLAINTS_DATA_ARROW = os.path.join(
+    ARROW_DATA_DIR, "twitter_complaints_small.arrow"
+)
 TWITTER_COMPLAINTS_DATA_PARQUET = os.path.join(
     PARQUET_DATA_DIR, "twitter_complaints_small.parquet"
 )
 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON = os.path.join(
-    DATA_DIR, "twitter_complaints_input_output.json"
+    JSON_DATA_DIR, "twitter_complaints_input_output.json"
 )
 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL = os.path.join(
-    DATA_DIR, "twitter_complaints_input_output.jsonl"
+    JSONL_DATA_DIR, "twitter_complaints_input_output.jsonl"
 )
 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW = os.path.join(
-    DATA_DIR, "twitter_complaints_input_output.arrow"
+    ARROW_DATA_DIR, "twitter_complaints_input_output.arrow"
 )
 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET = os.path.join(
     PARQUET_DATA_DIR, "twitter_complaints_input_output.parquet"
 )
 TWITTER_COMPLAINTS_TOKENIZED_JSON = os.path.join(
-    DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json"
+    JSON_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json"
 )
 TWITTER_COMPLAINTS_TOKENIZED_JSONL = os.path.join(
-    DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl"
+    JSONL_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl"
 )
 TWITTER_COMPLAINTS_TOKENIZED_ARROW = os.path.join(
-    DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow"
+    ARROW_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow"
 )
 TWITTER_COMPLAINTS_TOKENIZED_PARQUET = os.path.join(
     PARQUET_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.parquet"
 )
-EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json")
-MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json")
+EMPTY_DATA = os.path.join(JSON_DATA_DIR, "empty_data.json")
+MALFORMATTED_DATA = os.path.join(JSON_DATA_DIR, "malformatted_data.json")
 MODEL_NAME = "Maykeye/TinyLLama-v0"
diff --git a/tests/artifacts/testdata/twitter_complaints_input_output.arrow b/tests/artifacts/testdata/arrow/twitter_complaints_input_output.arrow
similarity index 100%
rename from tests/artifacts/testdata/twitter_complaints_input_output.arrow
rename to tests/artifacts/testdata/arrow/twitter_complaints_input_output.arrow
diff --git a/tests/artifacts/testdata/twitter_complaints_small.arrow b/tests/artifacts/testdata/arrow/twitter_complaints_small.arrow
similarity index 100%
rename from tests/artifacts/testdata/twitter_complaints_small.arrow
rename to tests/artifacts/testdata/arrow/twitter_complaints_small.arrow
diff --git a/tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow b/tests/artifacts/testdata/arrow/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow
similarity index 100%
rename from tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow
rename to tests/artifacts/testdata/arrow/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow
diff --git a/tests/artifacts/testdata/empty_data.json b/tests/artifacts/testdata/json/empty_data.json
similarity index 100%
rename from tests/artifacts/testdata/empty_data.json
rename to tests/artifacts/testdata/json/empty_data.json
diff --git a/tests/artifacts/testdata/malformatted_data.json b/tests/artifacts/testdata/json/malformatted_data.json
similarity index 100%
rename from tests/artifacts/testdata/malformatted_data.json
rename to tests/artifacts/testdata/json/malformatted_data.json
diff --git a/tests/artifacts/testdata/twitter_complaints_input_output.json b/tests/artifacts/testdata/json/twitter_complaints_input_output.json
similarity index 100%
rename from tests/artifacts/testdata/twitter_complaints_input_output.json
rename to tests/artifacts/testdata/json/twitter_complaints_input_output.json
diff --git a/tests/artifacts/testdata/twitter_complaints_small.json b/tests/artifacts/testdata/json/twitter_complaints_small.json
similarity index 100%
rename from tests/artifacts/testdata/twitter_complaints_small.json
rename to tests/artifacts/testdata/json/twitter_complaints_small.json
diff --git a/tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json b/tests/artifacts/testdata/json/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json
similarity index 100%
rename from tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json
rename to tests/artifacts/testdata/json/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json
diff --git a/tests/artifacts/testdata/twitter_complaints_input_output.jsonl b/tests/artifacts/testdata/jsonl/twitter_complaints_input_output.jsonl
similarity index 100%
rename from tests/artifacts/testdata/twitter_complaints_input_output.jsonl
rename to tests/artifacts/testdata/jsonl/twitter_complaints_input_output.jsonl
diff --git a/tests/artifacts/testdata/twitter_complaints_small.jsonl b/tests/artifacts/testdata/jsonl/twitter_complaints_small.jsonl
similarity index 100%
rename from tests/artifacts/testdata/twitter_complaints_small.jsonl
rename to tests/artifacts/testdata/jsonl/twitter_complaints_small.jsonl
diff --git a/tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl b/tests/artifacts/testdata/jsonl/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl
similarity index 100%
rename from tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl
rename to tests/artifacts/testdata/jsonl/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl
diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py
index c34204f4f..5559ac8ec 100644
--- a/tests/data/test_data_preprocessing_utils.py
+++ b/tests/data/test_data_preprocessing_utils.py
@@ -432,9 +432,11 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
         (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
         (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
         (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
+        (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW),
         (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
         (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
         (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
+        (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_ARROW),
         (
             DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
@@ -447,6 +449,10 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
             DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
         ),
+        (
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+        ),
     ],
 )
 def test_process_dataconfig_file(data_config_path, data_path):
@@ -491,6 +497,262 @@ def test_process_dataconfig_file(data_config_path, data_path):
         assert formatted_dataset_field in set(train_set.column_names)
 
 
+@pytest.mark.parametrize(
+    "data_config_path, data_path_list",
+    [
+        (
+            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
+            [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_JSON],
+        ),
+        (
+            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
+            [
+                TWITTER_COMPLAINTS_DATA_JSONL,
+                TWITTER_COMPLAINTS_DATA_JSONL,
+                TWITTER_COMPLAINTS_DATA_JSONL,
+            ],
+        ),
+        (
+            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
+            [TWITTER_COMPLAINTS_DATA_PARQUET, TWITTER_COMPLAINTS_DATA_PARQUET],
+        ),
+        (
+            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
+            [TWITTER_COMPLAINTS_DATA_ARROW, TWITTER_COMPLAINTS_DATA_ARROW],
+        ),
+        (
+            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
+            [TWITTER_COMPLAINTS_TOKENIZED_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON],
+        ),
+        (
+            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
+            [TWITTER_COMPLAINTS_TOKENIZED_JSONL, TWITTER_COMPLAINTS_TOKENIZED_JSONL],
+        ),
+        (
+            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
+            [
+                TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
+                TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
+                TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
+            ],
+        ),
+        (
+            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
+            [TWITTER_COMPLAINTS_TOKENIZED_ARROW, TWITTER_COMPLAINTS_TOKENIZED_ARROW],
+        ),
+        (
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+            ],
+        ),
+        (
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+            ],
+        ),
+        (
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            ],
+        ),
+        (
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+            ],
+        ),
+    ],
+)
+def test_process_dataconfig_multiple_files(data_config_path, data_path_list):
+    """Ensure that datasets with multiple files are formatted and validated correctly based on the arguments passed in config file."""
+    with open(data_config_path, "r") as f:
+        yaml_content = yaml.safe_load(f)
+    yaml_content["datasets"][0]["data_paths"] = data_path_list
+    datasets_name = yaml_content["datasets"][0]["name"]
+
+    # Modify input_field_name and output_field_name according to dataset
+    if datasets_name == "text_dataset_input_output_masking":
+        yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
+            "input_field_name": "input",
+            "output_field_name": "output",
+        }
+
+    # Modify dataset_text_field and template according to dataset
+    formatted_dataset_field = "formatted_data_field"
+    if datasets_name == "apply_custom_data_template":
+        template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
+        yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
+            "dataset_text_field": formatted_dataset_field,
+            "template": template,
+        }
+
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".yaml"
+    ) as temp_yaml_file:
+        yaml.dump(yaml_content, temp_yaml_file)
+        temp_yaml_file_path = temp_yaml_file.name
+        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
+
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)
+    assert isinstance(train_set, Dataset)
+    if datasets_name == "text_dataset_input_output_masking":
+        column_names = set(["input_ids", "attention_mask", "labels"])
+        assert set(train_set.column_names) == column_names
+    elif datasets_name == "pretokenized_dataset":
+        assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
+    elif datasets_name == "apply_custom_data_template":
+        assert formatted_dataset_field in set(train_set.column_names)
+
+
+@pytest.mark.parametrize(
+    "datafiles, datasetconfigname",
+    [
+        (
+            [
+                [
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+                ],
+                [
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+                ],
+                [
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                ],
+            ],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+    ],
+)
+def test_process_dataconfig_multiple_datasets_datafiles_sampling(
+    datafiles, datasetconfigname
+):
+    """Ensure that multiple datasets with multiple files are formatted and validated correctly."""
+    with open(datasetconfigname, "r") as f:
+        yaml_content = yaml.safe_load(f)
+    yaml_content["datasets"][0]["data_paths"] = datafiles[0]
+    yaml_content["datasets"][1]["data_paths"] = datafiles[1]
+    yaml_content["datasets"][2]["data_paths"] = datafiles[2]
+
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".yaml"
+    ) as temp_yaml_file:
+        yaml.dump(yaml_content, temp_yaml_file)
+        temp_yaml_file_path = temp_yaml_file.name
+        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
+
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    TRAIN_ARGS = configs.TrainingArguments(
+        packing=False,
+        max_seq_length=1024,
+        output_dir="tmp",
+    )
+    (train_set, eval_set, _, _, _, _) = process_dataargs(
+        data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
+    )
+
+    assert isinstance(train_set, Dataset)
+    if eval_set:
+        assert isinstance(eval_set, Dataset)
+
+    assert set(["input_ids", "attention_mask", "labels"]).issubset(
+        set(train_set.column_names)
+    )
+    if eval_set:
+        assert set(["input_ids", "attention_mask", "labels"]).issubset(
+            set(eval_set.column_names)
+        )
+
+
+@pytest.mark.parametrize(
+    "data_config_path, data_path_list",
+    [
+        (
+            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
+            [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_PARQUET],
+        ),
+        (
+            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
+            [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON],
+        ),
+        (
+            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
+            [
+                TWITTER_COMPLAINTS_TOKENIZED_JSONL,
+                TWITTER_COMPLAINTS_TOKENIZED_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            ],
+        ),
+        (
+            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
+            [
+                TWITTER_COMPLAINTS_TOKENIZED_JSON,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+            ],
+        ),
+        (
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+            ],
+        ),
+        (
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            [TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_JSON],
+        ),
+    ],
+)
+def test_process_dataconfig_multiple_files_varied_data_formats(
+    data_config_path, data_path_list
+):
+    """Ensure that datasets with multiple files with different formats raise assertion error when passed in config file."""
+    with open(data_config_path, "r") as f:
+        yaml_content = yaml.safe_load(f)
+    yaml_content["datasets"][0]["data_paths"] = data_path_list
+    datasets_name = yaml_content["datasets"][0]["name"]
+
+    # Modify input_field_name and output_field_name according to dataset
+    if datasets_name == "text_dataset_input_output_masking":
+        yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
+            "input_field_name": "input",
+            "output_field_name": "output",
+        }
+
+    # Modify dataset_text_field and template according to dataset
+    formatted_dataset_field = "formatted_data_field"
+    if datasets_name == "apply_custom_data_template":
+        template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
+        yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
+            "dataset_text_field": formatted_dataset_field,
+            "template": template,
+        }
+
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".yaml"
+    ) as temp_yaml_file:
+        yaml.dump(yaml_content, temp_yaml_file)
+        temp_yaml_file_path = temp_yaml_file.name
+        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
+
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    with pytest.raises(
+        (AssertionError, datasets.exceptions.DatasetGenerationCastError)
+    ):
+        (_, _, _) = _process_dataconfig_file(data_args, tokenizer)
+
+
 @pytest.mark.parametrize(
     "data_args",
     [
@@ -764,51 +1026,3 @@ def test_process_dataset_configs_with_sampling_error(
         (_, _, _, _, _, _) = process_dataargs(
             data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
         )
-
-
-@pytest.mark.parametrize(
-    "datafiles, datasetconfigname",
-    [
-        (
-            [
-                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
-                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
-                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
-            ],
-            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
-        ),
-    ],
-)
-def test_process_dataset_configs_with_sampling(datafiles, datasetconfigname):
-
-    data_args = configs.DataArguments()
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    TRAIN_ARGS = configs.TrainingArguments(
-        packing=False,
-        max_seq_length=1024,
-        output_dir="tmp",  # Not needed but positional
-    )
-
-    with tempfile.NamedTemporaryFile(
-        "w", delete=False, suffix=".yaml"
-    ) as temp_yaml_file:
-        with open(datasetconfigname, "r") as f:
-            data = yaml.safe_load(f)
-            datasets = data["datasets"]
-            for i in range(len(datasets)):
-                d = datasets[i]
-                d["data_paths"][0] = datafiles[i]
-            yaml.dump(data, temp_yaml_file)
-        data_args.data_config_path = temp_yaml_file.name
-
-    (train_set, eval_set, _, _, _, _) = process_dataargs(
-        data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
-    )
-
-    assert isinstance(train_set, Dataset)
-    if eval_set:
-        assert isinstance(eval_set, Dataset)
-
-    assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
-    if eval_set:
-        assert set(["input_ids", "labels"]).issubset(set(eval_set.column_names))
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 5dbc1144c..8dcdf3087 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -30,19 +30,30 @@
 import pytest
 import torch
 import transformers
+import yaml
 
 # First Party
 from build.utils import serialize_args
 from scripts.run_inference import TunedCausalLM
+from tests.artifacts.predefined_data_configs import (
+    DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+)
 from tests.artifacts.testdata import (
     EMPTY_DATA,
     MALFORMATTED_DATA,
     MODEL_NAME,
+    TWITTER_COMPLAINTS_DATA_ARROW,
+    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
     TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
     TWITTER_COMPLAINTS_DATA_JSON,
     TWITTER_COMPLAINTS_DATA_JSONL,
+    TWITTER_COMPLAINTS_DATA_PARQUET,
+    TWITTER_COMPLAINTS_TOKENIZED_ARROW,
     TWITTER_COMPLAINTS_TOKENIZED_JSON,
     TWITTER_COMPLAINTS_TOKENIZED_JSONL,
+    TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
 )
 
 # Local
@@ -693,6 +704,8 @@ def test_successful_lora_target_modules_default_from_main():
     [
         TWITTER_COMPLAINTS_DATA_JSONL,
         TWITTER_COMPLAINTS_DATA_JSON,
+        TWITTER_COMPLAINTS_DATA_PARQUET,
+        TWITTER_COMPLAINTS_DATA_ARROW,
     ],
 )
 def test_run_causallm_ft_and_inference(dataset_path):
@@ -729,7 +742,12 @@ def test_run_causallm_ft_save_with_save_model_dir_save_strategy_no():
 
 @pytest.mark.parametrize(
     "dataset_path",
-    [TWITTER_COMPLAINTS_TOKENIZED_JSONL, TWITTER_COMPLAINTS_TOKENIZED_JSON],
+    [
+        TWITTER_COMPLAINTS_TOKENIZED_JSONL,
+        TWITTER_COMPLAINTS_TOKENIZED_JSON,
+        TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
+        TWITTER_COMPLAINTS_TOKENIZED_ARROW,
+    ],
 )
 def test_run_causallm_ft_pretokenized(dataset_path):
     """Check if we can bootstrap and finetune causallm models using pretokenized data"""
@@ -764,6 +782,82 @@ def test_run_causallm_ft_pretokenized(dataset_path):
     assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference
 
 
+@pytest.mark.parametrize(
+    "datafiles, datasetconfigname",
+    [
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+            ],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+            ],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+            ],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            ],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+    ],
+)
+def test_run_causallm_ft_and_inference_with_multiple_dataset(
+    datasetconfigname, datafiles
+):
+    """Check if we can finetune causallm models using multiple datasets with multiple files"""
+    with tempfile.TemporaryDirectory() as tempdir:
+        data_formatting_args = copy.deepcopy(DATA_ARGS)
+
+        # set training_data_path and response_template to none
+        data_formatting_args.response_template = None
+        data_formatting_args.training_data_path = None
+
+        # add data_paths in data_config file
+        with tempfile.NamedTemporaryFile(
+            "w", delete=False, suffix=".yaml"
+        ) as temp_yaml_file:
+            with open(datasetconfigname, "r", encoding="utf-8") as f:
+                data = yaml.safe_load(f)
+                datasets = data["datasets"]
+                for _, d in enumerate(datasets):
+                    d["data_paths"] = datafiles
+                yaml.dump(data, temp_yaml_file)
+        data_formatting_args.data_config_path = temp_yaml_file.name
+
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        sft_trainer.train(MODEL_ARGS, data_formatting_args, train_args)
+
+        # validate full ft configs
+        _validate_training(tempdir)
+        _, checkpoint_path = _get_latest_checkpoint_trainer_state(tempdir)
+
+        # Load the model
+        loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
+
+        # Run inference on the text
+        output_inference = loaded_model.run(
+            "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
+        )
+        assert len(output_inference) > 0
+        assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference
+
+
 ############################# Helper functions #############################
 def _test_run_causallm_ft(training_args, model_args, data_args, tempdir):
     train_args = copy.deepcopy(training_args)
diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py
index f0100072b..6a821ec5c 100644
--- a/tuning/data/data_handlers.py
+++ b/tuning/data/data_handlers.py
@@ -90,6 +90,8 @@ def apply_dataset_formatting(
     dataset_text_field: str,
     **kwargs,
 ):
+    if dataset_text_field not in element:
+        raise KeyError(f"Dataset should contain {dataset_text_field} field.")
     return {
         f"{dataset_text_field}": element[f"{dataset_text_field}"] + tokenizer.eos_token
     }
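
The guard added to apply_dataset_formatting above makes a missing text column fail fast with a readable KeyError instead of an opaque lookup failure deeper inside the dataset map call. A minimal, self-contained sketch of that behavior; the helper name and the plain eos_token string stand in for the real handler's signature and tokenizer, and are assumptions for illustration only:

    # Sketch only: mirrors the guard and return shape of apply_dataset_formatting,
    # with tokenizer.eos_token replaced by a plain string so the snippet stands alone.
    def format_row(element: dict, eos_token: str, dataset_text_field: str) -> dict:
        if dataset_text_field not in element:
            raise KeyError(f"Dataset should contain {dataset_text_field} field.")
        return {dataset_text_field: element[dataset_text_field] + eos_token}

    # A row lacking the configured field now surfaces a clear error up front.
    try:
        format_row({"Tweet text": "@NortonSupport Thanks much."}, "</s>", "output")
    except KeyError as err:
        print(err)  # 'Dataset should contain output field.'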
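
The new multi-file tests all follow the same recipe: load one of the predefined data config YAMLs, point each dataset entry's data_paths at a list of files of the same format, write the modified config to a temporary file, and pass its path on as data_config_path. A condensed sketch of that recipe, assuming the test artifacts are importable exactly as the test suite imports them:

    # Condensed from the tests above; the file constants come from the test artifact packages.
    import tempfile

    import yaml

    from tests.artifacts.predefined_data_configs import (
        DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
    )
    from tests.artifacts.testdata import TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON

    with open(DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Give every dataset entry the same pair of JSON files; mixing formats within one
    # dataset is the failure case exercised by the varied-data-formats test.
    for dataset in config["datasets"]:
        dataset["data_paths"] = [
            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
        ]

    with tempfile.NamedTemporaryFile("w", delete=False, suffix=".yaml") as tmp:
        yaml.dump(config, tmp)
        data_config_path = tmp.name  # hand this to DataArguments(data_config_path=...)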