From ce82af11dce4a30d80317d7a12cb26f7d14ee076 Mon Sep 17 00:00:00 2001
From: Abhishek
Date: Tue, 10 Dec 2024 16:24:11 -0500
Subject: [PATCH 1/6] test: Add unit tests for multiple files in a single dataset

Signed-off-by: Abhishek

---
 tests/artifacts/testdata/__init__.py          |  32 ++-
 .../twitter_complaints_input_output.arrow     | Bin
 .../twitter_complaints_small.arrow            | Bin
 ..._tokenized_with_maykeye_tinyllama_v0.arrow | Bin
 .../testdata/{ => json}/empty_data.json       |   0
 .../{ => json}/malformatted_data.json         |   0
 .../twitter_complaints_input_output.json      |   0
 .../{ => json}/twitter_complaints_small.json  |   0
 ...s_tokenized_with_maykeye_tinyllama_v0.json |   0
 .../twitter_complaints_input_output.jsonl     |   0
 .../twitter_complaints_small.jsonl            |   0
 ..._tokenized_with_maykeye_tinyllama_v0.jsonl |   0
 tests/data/test_data_preprocessing_utils.py   | 234 ++++++++++++++++++
 tuning/data/data_handlers.py                  |   2 +
 14 files changed, 257 insertions(+), 11 deletions(-)
 rename tests/artifacts/testdata/{ => arrow}/twitter_complaints_input_output.arrow (100%)
 rename tests/artifacts/testdata/{ => arrow}/twitter_complaints_small.arrow (100%)
 rename tests/artifacts/testdata/{ => arrow}/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow (100%)
 rename tests/artifacts/testdata/{ => json}/empty_data.json (100%)
 rename tests/artifacts/testdata/{ => json}/malformatted_data.json (100%)
 rename tests/artifacts/testdata/{ => json}/twitter_complaints_input_output.json (100%)
 rename tests/artifacts/testdata/{ => json}/twitter_complaints_small.json (100%)
 rename tests/artifacts/testdata/{ => json}/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json (100%)
 rename tests/artifacts/testdata/{ => jsonl}/twitter_complaints_input_output.jsonl (100%)
 rename tests/artifacts/testdata/{ => jsonl}/twitter_complaints_small.jsonl (100%)
 rename tests/artifacts/testdata/{ => jsonl}/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl (100%)

diff --git a/tests/artifacts/testdata/__init__.py b/tests/artifacts/testdata/__init__.py
index 39895f6f1..762f88ab9 100644
--- a/tests/artifacts/testdata/__init__.py
+++ b/tests/artifacts/testdata/__init__.py
@@ -19,37 +19,47 @@
 ### Constants used for data
 DATA_DIR = os.path.join(os.path.dirname(__file__))
+JSON_DATA_DIR = os.path.join(os.path.dirname(__file__), "json")
+JSONL_DATA_DIR = os.path.join(os.path.dirname(__file__), "jsonl")
+ARROW_DATA_DIR = os.path.join(os.path.dirname(__file__), "arrow")
 PARQUET_DATA_DIR = os.path.join(os.path.dirname(__file__), "parquet")
-TWITTER_COMPLAINTS_DATA_JSON = os.path.join(DATA_DIR, "twitter_complaints_small.json")
-TWITTER_COMPLAINTS_DATA_JSONL = os.path.join(DATA_DIR, "twitter_complaints_small.jsonl")
-TWITTER_COMPLAINTS_DATA_ARROW = os.path.join(DATA_DIR, "twitter_complaints_small.arrow")
+
+TWITTER_COMPLAINTS_DATA_JSON = os.path.join(
+    JSON_DATA_DIR, "twitter_complaints_small.json"
+)
+TWITTER_COMPLAINTS_DATA_JSONL = os.path.join(
+    JSONL_DATA_DIR, "twitter_complaints_small.jsonl"
+)
+TWITTER_COMPLAINTS_DATA_ARROW = os.path.join(
+    ARROW_DATA_DIR, "twitter_complaints_small.arrow"
+)
 TWITTER_COMPLAINTS_DATA_PARQUET = os.path.join(
     PARQUET_DATA_DIR, "twitter_complaints_small.parquet"
 )
 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON = os.path.join(
-    DATA_DIR, "twitter_complaints_input_output.json"
+    JSON_DATA_DIR, "twitter_complaints_input_output.json"
 )
 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL = os.path.join(
-    DATA_DIR, "twitter_complaints_input_output.jsonl"
+    JSONL_DATA_DIR, "twitter_complaints_input_output.jsonl"
 )
 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW = os.path.join(
-    DATA_DIR, "twitter_complaints_input_output.arrow"
+    ARROW_DATA_DIR, "twitter_complaints_input_output.arrow"
 )
 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET = os.path.join(
     PARQUET_DATA_DIR, "twitter_complaints_input_output.parquet"
 )
 TWITTER_COMPLAINTS_TOKENIZED_JSON = os.path.join(
-    DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json"
+    JSON_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json"
 )
 TWITTER_COMPLAINTS_TOKENIZED_JSONL = os.path.join(
-    DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl"
+    JSONL_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl"
 )
 TWITTER_COMPLAINTS_TOKENIZED_ARROW = os.path.join(
-    DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow"
+    ARROW_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow"
 )
 TWITTER_COMPLAINTS_TOKENIZED_PARQUET = os.path.join(
     PARQUET_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.parquet"
 )
-EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json")
-MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json")
+EMPTY_DATA = os.path.join(JSON_DATA_DIR, "empty_data.json")
+MALFORMATTED_DATA = os.path.join(JSON_DATA_DIR, "malformatted_data.json")
 MODEL_NAME = "Maykeye/TinyLLama-v0"

diff --git a/tests/artifacts/testdata/twitter_complaints_input_output.arrow b/tests/artifacts/testdata/arrow/twitter_complaints_input_output.arrow
similarity index 100%
rename from tests/artifacts/testdata/twitter_complaints_input_output.arrow
rename to tests/artifacts/testdata/arrow/twitter_complaints_input_output.arrow
diff --git a/tests/artifacts/testdata/twitter_complaints_small.arrow b/tests/artifacts/testdata/arrow/twitter_complaints_small.arrow
similarity index 100%
rename from tests/artifacts/testdata/twitter_complaints_small.arrow
rename to tests/artifacts/testdata/arrow/twitter_complaints_small.arrow
diff --git a/tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow b/tests/artifacts/testdata/arrow/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow
similarity index 100%
rename from tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow
rename to tests/artifacts/testdata/arrow/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow
diff --git a/tests/artifacts/testdata/empty_data.json b/tests/artifacts/testdata/json/empty_data.json
similarity index 100%
rename from tests/artifacts/testdata/empty_data.json
rename to tests/artifacts/testdata/json/empty_data.json
diff --git a/tests/artifacts/testdata/malformatted_data.json b/tests/artifacts/testdata/json/malformatted_data.json
similarity index 100%
rename from tests/artifacts/testdata/malformatted_data.json
rename to tests/artifacts/testdata/json/malformatted_data.json
diff --git a/tests/artifacts/testdata/twitter_complaints_input_output.json b/tests/artifacts/testdata/json/twitter_complaints_input_output.json
similarity index 100%
rename from tests/artifacts/testdata/twitter_complaints_input_output.json
rename to tests/artifacts/testdata/json/twitter_complaints_input_output.json
diff --git a/tests/artifacts/testdata/twitter_complaints_small.json b/tests/artifacts/testdata/json/twitter_complaints_small.json
similarity index 100%
rename from tests/artifacts/testdata/twitter_complaints_small.json
rename to tests/artifacts/testdata/json/twitter_complaints_small.json
diff --git a/tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json b/tests/artifacts/testdata/json/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json
similarity index 100%
rename from tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json
rename to tests/artifacts/testdata/json/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json
diff --git a/tests/artifacts/testdata/twitter_complaints_input_output.jsonl b/tests/artifacts/testdata/jsonl/twitter_complaints_input_output.jsonl
similarity index 100%
rename from tests/artifacts/testdata/twitter_complaints_input_output.jsonl
rename to tests/artifacts/testdata/jsonl/twitter_complaints_input_output.jsonl
diff --git a/tests/artifacts/testdata/twitter_complaints_small.jsonl b/tests/artifacts/testdata/jsonl/twitter_complaints_small.jsonl
similarity index 100%
rename from tests/artifacts/testdata/twitter_complaints_small.jsonl
rename to tests/artifacts/testdata/jsonl/twitter_complaints_small.jsonl
diff --git a/tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl b/tests/artifacts/testdata/jsonl/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl
similarity index 100%
rename from tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl
rename to tests/artifacts/testdata/jsonl/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl
diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py
index c34204f4f..fbb73f649 100644
--- a/tests/data/test_data_preprocessing_utils.py
+++ b/tests/data/test_data_preprocessing_utils.py
@@ -432,9 +432,11 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
         (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
         (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
         (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
+        (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW),
         (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
         (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
         (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
+        (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_ARROW),
         (
             DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
@@ -447,6 +449,10 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
             DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
         ),
+        (
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+        ),
     ],
 )
 def test_process_dataconfig_file(data_config_path, data_path):
@@ -491,6 +497,234 @@ def test_process_dataconfig_file(data_config_path, data_path):
     assert formatted_dataset_field in set(train_set.column_names)
 
 
+@pytest.mark.parametrize(
+    "data_config_path, list_data_path",
+    [
+        (
+            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
+            [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_JSON],
+        ),
+        (
+            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
+            [TWITTER_COMPLAINTS_DATA_JSONL, TWITTER_COMPLAINTS_DATA_JSONL],
+        ),
+        (
+            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
+            [TWITTER_COMPLAINTS_DATA_PARQUET, TWITTER_COMPLAINTS_DATA_PARQUET],
+        ),
+        (
+            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
+            [TWITTER_COMPLAINTS_DATA_ARROW, TWITTER_COMPLAINTS_DATA_ARROW],
+        ),
+        (
+            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
+            [TWITTER_COMPLAINTS_TOKENIZED_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON],
+        ),
+        (
+            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
+            [TWITTER_COMPLAINTS_TOKENIZED_JSONL, TWITTER_COMPLAINTS_TOKENIZED_JSONL],
+        ),
+        (
+            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
+            [
+                TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
+                TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
+            ],
+        ),
+        (
+            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
+            [TWITTER_COMPLAINTS_TOKENIZED_ARROW, TWITTER_COMPLAINTS_TOKENIZED_ARROW],
+        ),
+        (
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+            ],
+        ),
+        (
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+            ],
+        ),
+        (
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            ],
+        ),
+        (
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+            ],
+        ),
+    ],
+)
+def test_process_dataconfig_multiple_files(data_config_path, list_data_path):
+    """Ensure that datasets with multiple files are formatted and validated correctly based on the arguments passed in the config file."""
+    with open(data_config_path, "r") as f:
+        yaml_content = yaml.safe_load(f)
+    yaml_content["datasets"][0]["data_paths"] = list_data_path
+    datasets_name = yaml_content["datasets"][0]["name"]
+
+    # Modify input_field_name and output_field_name according to dataset
+    if datasets_name == "text_dataset_input_output_masking":
+        yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
+            "input_field_name": "input",
+            "output_field_name": "output",
+        }
+
+    # Modify dataset_text_field and template according to dataset
+    formatted_dataset_field = "formatted_data_field"
+    if datasets_name == "apply_custom_data_template":
+        template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
+        yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
+            "dataset_text_field": formatted_dataset_field,
+            "template": template,
+        }
+
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".yaml"
+    ) as temp_yaml_file:
+        yaml.dump(yaml_content, temp_yaml_file)
+        temp_yaml_file_path = temp_yaml_file.name
+        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
+
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)
+    assert isinstance(train_set, Dataset)
+    if datasets_name == "text_dataset_input_output_masking":
+        column_names = set(["input_ids", "attention_mask", "labels"])
+        assert set(train_set.column_names) == column_names
+    elif datasets_name == "pretokenized_dataset":
+        assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
+    elif datasets_name == "apply_custom_data_template":
+        assert formatted_dataset_field in set(train_set.column_names)
+
+
+@pytest.mark.parametrize(
+    "data_config_path, list_data_path",
+    [
+        (
+            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
+            [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_PARQUET],
+        ),
+        (
+            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
+            [
+                TWITTER_COMPLAINTS_TOKENIZED_JSONL,
+                TWITTER_COMPLAINTS_TOKENIZED_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            ],
+        ),
+        (
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+            ],
+        ),
+    ],
+)
+def test_process_dataconfig_multiple_files_varied_data_formats(
+    data_config_path, list_data_path
+):
+    """Ensure that a dataset whose files have mixed formats raises an AssertionError when passed in the config file."""
+    with open(data_config_path, "r") as f:
+        yaml_content = yaml.safe_load(f)
+    yaml_content["datasets"][0]["data_paths"] = list_data_path
+    datasets_name = yaml_content["datasets"][0]["name"]
+
+    # Modify input_field_name and output_field_name according to dataset
+    if datasets_name == "text_dataset_input_output_masking":
+        yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
+            "input_field_name": "input",
+            "output_field_name": "output",
+        }
+
+    # Modify dataset_text_field and template according to dataset
+    formatted_dataset_field = "formatted_data_field"
+    if datasets_name == "apply_custom_data_template":
+        template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
+        yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
+            "dataset_text_field": formatted_dataset_field,
+            "template": template,
+        }
+
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".yaml"
+    ) as temp_yaml_file:
+        yaml.dump(yaml_content, temp_yaml_file)
+        temp_yaml_file_path = temp_yaml_file.name
+        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
+
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    with pytest.raises(AssertionError):
+        (_, _, _) = _process_dataconfig_file(data_args, tokenizer)
+
+
+@pytest.mark.parametrize(
+    "data_config_path, list_data_path",
+    [
+        (
+            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
+            [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON],
+        ),
+        (
+            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
+            [
+                TWITTER_COMPLAINTS_TOKENIZED_JSON,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+            ],
+        ),
+        (
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            [TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_JSON],
+        ),
+    ],
+)
+def test_process_dataconfig_multiple_files_varied_types(
+    data_config_path, list_data_path
+):
+    """Ensure that a dataset whose files share a format but have mismatched column schemas raises a DatasetGenerationCastError when passed in the config file."""
+    with open(data_config_path, "r") as f:
+        yaml_content = yaml.safe_load(f)
+    yaml_content["datasets"][0]["data_paths"] = list_data_path
+    datasets_name = yaml_content["datasets"][0]["name"]
+
+    # Modify input_field_name and output_field_name according to dataset
+    if datasets_name == "text_dataset_input_output_masking":
+        yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
+            "input_field_name": "input",
+            "output_field_name": "output",
+        }
+
+    # Modify dataset_text_field and template according to dataset
+    formatted_dataset_field = "formatted_data_field"
+    if datasets_name == "apply_custom_data_template":
+        template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
+        yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
+            "dataset_text_field": formatted_dataset_field,
+            "template": template,
+        }
+
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".yaml"
+    ) as temp_yaml_file:
+        yaml.dump(yaml_content, temp_yaml_file)
+        temp_yaml_file_path = temp_yaml_file.name
+        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
+
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    with pytest.raises(datasets.exceptions.DatasetGenerationCastError):
+        (_, _, _) = _process_dataconfig_file(data_args, tokenizer)
+
+
 @pytest.mark.parametrize(
     "data_args",
     [
diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py
index f0100072b..6a821ec5c 100644
--- a/tuning/data/data_handlers.py
+++ b/tuning/data/data_handlers.py
@@ -90,6 +90,8 @@ def apply_dataset_formatting(
     dataset_text_field: str,
     **kwargs,
 ):
+    if dataset_text_field not in element:
+        raise KeyError(f"Dataset should contain the '{dataset_text_field}' field.")
     return {
         f"{dataset_text_field}": element[f"{dataset_text_field}"] + tokenizer.eos_token
     }

From 3fe7425146f5edad4f28aa6ffccad420b5dca46f Mon Sep 17 00:00:00 2001
From: Abhishek
Date: Wed, 11 Dec 2024 13:20:22 -0500
Subject: [PATCH 2/6] test: Add e2e unit test for multiple datasets with multiple files

Signed-off-by: Abhishek

---
 tests/test_sft_trainer.py | 72 ++++++++++++++++++++++++++++++++++++++-
 1 file changed, 71 insertions(+), 1 deletion(-)

diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 69ccbf4fa..6f8047c1a 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -27,19 +27,27 @@
 import pytest
 import torch
 import transformers
+import yaml
 
 # First Party
 from build.utils import serialize_args
 from scripts.run_inference import TunedCausalLM
+from tests.artifacts.predefined_data_configs import (
+    DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+)
 from tests.artifacts.testdata import (
     EMPTY_DATA,
     MALFORMATTED_DATA,
     MODEL_NAME,
+    TWITTER_COMPLAINTS_DATA_ARROW,
     TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
     TWITTER_COMPLAINTS_DATA_JSON,
     TWITTER_COMPLAINTS_DATA_JSONL,
+    TWITTER_COMPLAINTS_DATA_PARQUET,
+    TWITTER_COMPLAINTS_TOKENIZED_ARROW,
     TWITTER_COMPLAINTS_TOKENIZED_JSON,
     TWITTER_COMPLAINTS_TOKENIZED_JSONL,
+    TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
 )
 
 # Local
@@ -683,6 +691,8 @@ def test_successful_lora_target_modules_default_from_main():
     [
         TWITTER_COMPLAINTS_DATA_JSONL,
         TWITTER_COMPLAINTS_DATA_JSON,
+        TWITTER_COMPLAINTS_DATA_PARQUET,
+        TWITTER_COMPLAINTS_DATA_ARROW,
     ],
 )
 def test_run_causallm_ft_and_inference(dataset_path):
@@ -719,7 +729,12 @@ def test_run_causallm_ft_save_with_save_model_dir_save_strategy_no():
 
 @pytest.mark.parametrize(
     "dataset_path",
-    [TWITTER_COMPLAINTS_TOKENIZED_JSONL, TWITTER_COMPLAINTS_TOKENIZED_JSON],
+    [
+        TWITTER_COMPLAINTS_TOKENIZED_JSONL,
+        TWITTER_COMPLAINTS_TOKENIZED_JSON,
+        TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
+        TWITTER_COMPLAINTS_TOKENIZED_ARROW,
+    ],
 )
 def test_run_causallm_ft_pretokenized(dataset_path):
     """Check if we can bootstrap and finetune causallm models using pretokenized data"""
@@ -754,6 +769,61 @@ def test_run_causallm_ft_pretokenized(dataset_path):
     assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference
 
 
+@pytest.mark.parametrize(
+    "datafiles, datasetconfigname",
+    [
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+            ],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        )
+    ],
+)
+def test_run_causallm_ft_and_inference_with_multiple_dataset(
+    datasetconfigname, datafiles
+):
+    """Check if we can finetune causallm models using multiple datasets with multiple files"""
+    with tempfile.TemporaryDirectory() as tempdir:
+        data_formatting_args = copy.deepcopy(DATA_ARGS)
+
+        # set training_data_path and response_template to none
+        data_formatting_args.response_template = None
+        data_formatting_args.training_data_path = None
+
+        # add data_paths in data_config file
+        with tempfile.NamedTemporaryFile(
+            "w", delete=False, suffix=".yaml"
+        ) as temp_yaml_file:
+            with open(datasetconfigname, "r", encoding="utf-8") as f:
+                data = yaml.safe_load(f)
+            datasets = data["datasets"]
+            for _, d in enumerate(datasets):
+                d["data_paths"] = datafiles
+            yaml.dump(data, temp_yaml_file)
+        data_formatting_args.data_config_path = temp_yaml_file.name
+
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        sft_trainer.train(MODEL_ARGS, data_formatting_args, train_args)
+
+        # validate full ft configs
+        _validate_training(tempdir)
+        _, checkpoint_path = _get_latest_checkpoint_trainer_state(tempdir)
+
+        # Load the model
+        loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
+
+        # Run inference on the text
+        output_inference = loaded_model.run(
+            "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
+        )
+        assert len(output_inference) > 0
+        assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference
+
+
 ############################# Helper functions #############################
 def _test_run_causallm_ft(training_args, model_args, data_args, tempdir):
     train_args = copy.deepcopy(training_args)

From e89002de928f419601038b3b172c99a3d0b62429 Mon Sep 17 00:00:00 2001
From: Will Johnson
Date: Wed, 11 Dec 2024 15:48:46 -0500
Subject: [PATCH 3/6] test: Validate column names for multiple datasets with multiple datafiles

Signed-off-by: Will Johnson

---
 tests/data/test_data_preprocessing_utils.py | 44 +++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py
index fbb73f649..937fc5a70 100644
--- a/tests/data/test_data_preprocessing_utils.py
+++ b/tests/data/test_data_preprocessing_utils.py
@@ -607,6 +607,50 @@ def test_process_dataconfig_multiple_files(data_config_path, list_data_path):
     assert formatted_dataset_field in set(train_set.column_names)
 
 
+@pytest.mark.parametrize(
+    "datafiles, datasetconfigname",
+    [
+        (
+            [
+                [
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+                ],
+                [
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+                ],
+                [
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                ],
+            ],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+    ],
+)
+def test_process_dataconfig_multiple_datasets_datafiles(datafiles, datasetconfigname):
+    """Ensure that multiple datasets with multiple files are formatted and validated correctly."""
+    with open(datasetconfigname, "r") as f:
+        yaml_content = yaml.safe_load(f)
+    yaml_content["datasets"][0]["data_paths"] = datafiles[0]
+    yaml_content["datasets"][1]["data_paths"] = datafiles[1]
+    yaml_content["datasets"][2]["data_paths"] = datafiles[2]
+
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".yaml"
+    ) as temp_yaml_file:
+        yaml.dump(yaml_content, temp_yaml_file)
+        temp_yaml_file_path = temp_yaml_file.name
+        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
+
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)
+    assert isinstance(train_set, Dataset)
+    column_names = set(["input_ids", "attention_mask", "labels"])
+    assert set(train_set.column_names) == column_names
+
+
 @pytest.mark.parametrize(
     "data_config_path, list_data_path",
     [

From 4ba1c048b42436ef2b9a81333c66f5a898928ac6 Mon Sep 17 00:00:00 2001
From: Abhishek
Date: Wed, 11 Dec 2024 15:54:01 -0500
Subject: [PATCH 4/6] Address PR review comments

Signed-off-by: Abhishek

---
 .pylintrc                                   |  2 +-
 tests/data/test_data_preprocessing_utils.py |  8 ++++++-
 tests/test_sft_trainer.py                   | 26 ++++++++++++++++++++-
 3 files changed, 33 insertions(+), 3 deletions(-)

diff --git a/.pylintrc b/.pylintrc
index 222bdf6cb..f54599d18 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -333,7 +333,7 @@ indent-string='    '
 max-line-length=100
 
 # Maximum number of lines in a module.
-max-module-lines=1200
+max-module-lines=1400
 
 # Allow the body of a class to be on the same line as the declaration if body
 # contains single statement.
diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py
index 937fc5a70..45d24f698 100644
--- a/tests/data/test_data_preprocessing_utils.py
+++ b/tests/data/test_data_preprocessing_utils.py
@@ -506,7 +506,11 @@ def test_process_dataconfig_file(data_config_path, data_path):
         ),
         (
             DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
-            [TWITTER_COMPLAINTS_DATA_JSONL, TWITTER_COMPLAINTS_DATA_JSONL],
+            [
+                TWITTER_COMPLAINTS_DATA_JSONL,
+                TWITTER_COMPLAINTS_DATA_JSONL,
+                TWITTER_COMPLAINTS_DATA_JSONL,
+            ],
         ),
         (
             DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
@@ -529,6 +533,7 @@ def test_process_dataconfig_file(data_config_path, data_path):
             [
                 TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
                 TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
+                TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
             ],
         ),
         (
@@ -561,6 +566,7 @@ def test_process_dataconfig_file(data_config_path, data_path):
             [
                 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
                 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
             ],
         ),
     ],
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 6f8047c1a..1caa0392c 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -40,7 +40,10 @@
     MALFORMATTED_DATA,
     MODEL_NAME,
     TWITTER_COMPLAINTS_DATA_ARROW,
+    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
     TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
     TWITTER_COMPLAINTS_DATA_JSON,
     TWITTER_COMPLAINTS_DATA_JSONL,
     TWITTER_COMPLAINTS_DATA_PARQUET,
@@ -772,13 +775,34 @@ def test_run_causallm_ft_pretokenized(dataset_path):
 @pytest.mark.parametrize(
     "datafiles, datasetconfigname",
     [
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+            ],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
         (
             [
                 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
                 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
             ],
             DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
-        )
+        ),
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+            ],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            ],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
     ],
 )
 def test_run_causallm_ft_and_inference_with_multiple_dataset(

From 8b1d07e8655cd18c8e333f895973696052a4856c Mon Sep 17 00:00:00 2001
From: Abhishek
Date: Thu, 12 Dec 2024 10:20:41 -0500
Subject: [PATCH 5/6] Handle files passed via glob pattern

Signed-off-by: Abhishek

---
 .../jsonl/twitter_complaints_small_2.jsonl  | 10 ++++
 tests/data/test_data_preprocessing_utils.py | 47 +++++++++++++++++++
 tuning/data/data_config.py                  |  1 -
 3 files changed, 57 insertions(+), 1 deletion(-)
 create mode 100644 tests/artifacts/testdata/jsonl/twitter_complaints_small_2.jsonl

diff --git a/tests/artifacts/testdata/jsonl/twitter_complaints_small_2.jsonl b/tests/artifacts/testdata/jsonl/twitter_complaints_small_2.jsonl
new file mode 100644
index 000000000..0837217c3
--- /dev/null
+++ b/tests/artifacts/testdata/jsonl/twitter_complaints_small_2.jsonl
@@ -0,0 +1,10 @@
+{"Tweet text":"@NortonSupport Thanks much.","ID":10,"Label":2,"text_label":"no complaint","output":"### Text: @NortonSupport Thanks much.\n\n### Label: no complaint"}
+{"Tweet text":"@VerizonSupport all of a sudden I can't connect to my primary wireless network but guest one works","ID":11,"Label":2,"text_label":"no complaint","output":"### Text: @VerizonSupport all of a sudden I can't connect to my primary wireless network but guest one works\n\n### Label: no complaint"}
+{"Tweet text":"Aaaahhhhh!!!! My @Razer @PlayOverwatch d.va meka headset came in!!! I didn't even know it had shipped!!! So excited\u2026 https:\/\/t.co\/4gXy9xED8d","ID":12,"Label":2,"text_label":"no complaint","output":"### Text: Aaaahhhhh!!!! My @Razer @PlayOverwatch d.va meka headset came in!!! I didn't even know it had shipped!!! So excited\u2026 https:\/\/t.co\/4gXy9xED8d\n\n### Label: no complaint"}
+{"Tweet text":"@Lin_Manuel @jmessinaphoto @VAMNit Omg a little squish!!!!! Enjoy and congrats!!!! I miss mine being so young! \ufffd\ufffd\ufffd\ufffd\ufffd\ufffd","ID":13,"Label":2,"text_label":"no complaint","output":"### Text: @Lin_Manuel @jmessinaphoto @VAMNit Omg a little squish!!!!! Enjoy and congrats!!!! I miss mine being so young! \ufffd\ufffd\ufffd\ufffd\ufffd\ufffd\n\n### Label: no complaint"}
+{"Tweet text":"@IanJamesPoulter What's your secret to poaching eggs? Mine NEVER look that good.","ID":14,"Label":2,"text_label":"no complaint","output":"### Text: @IanJamesPoulter What's your secret to poaching eggs? Mine NEVER look that good.\n\n### Label: no complaint"}
+{"Tweet text":"@AWSSupport When will be able Kinesis Firehose compatible with Elasticsearch 6.0? Thank you!","ID":15,"Label":2,"text_label":"no complaint","output":"### Text: @AWSSupport When will be able Kinesis Firehose compatible with Elasticsearch 6.0? Thank you!\n\n### Label: no complaint"}
+{"Tweet text":"@NCIS_CBS https:\/\/t.co\/eeVL9Eu3bE","ID":16,"Label":2,"text_label":"no complaint","output":"### Text: @NCIS_CBS https:\/\/t.co\/eeVL9Eu3bE\n\n### Label: no complaint"}
+{"Tweet text":"@msetchell Via the settings? That\u2019s how I do it on master T\u2019s","ID":17,"Label":2,"text_label":"no complaint","output":"### Text: @msetchell Via the settings? That\u2019s how I do it on master T\u2019s\n\n### Label: no complaint"}
+{"Tweet text":"Today at work there was a low flying duck heading toward a crowd of people, and I yelled \"watch out! and I'm very disappointed with myself.","ID":18,"Label":2,"text_label":"no complaint","output":"### Text: Today at work there was a low flying duck heading toward a crowd of people, and I yelled \"watch out! and I'm very disappointed with myself.\n\n### Label: no complaint"}
+{"Tweet text":"@NortonSupport @NortonOnline What the hell is a dm 5-10 days to get money back bank account now overdrawn thanks guys","ID":19,"Label":1,"text_label":"complaint","output":"### Text: @NortonSupport @NortonOnline What the hell is a dm 5-10 days to get money back bank account now overdrawn thanks guys\n\n### Label: complaint"}
diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py
index 45d24f698..2688e6642 100644
--- a/tests/data/test_data_preprocessing_utils.py
+++ b/tests/data/test_data_preprocessing_utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 # Standard
+import glob
 import json
 import tempfile
 
@@ -613,6 +614,52 @@ def test_process_dataconfig_multiple_files(data_config_path, list_data_path):
     assert formatted_dataset_field in set(train_set.column_names)
 
 
+@pytest.mark.parametrize(
+    "data_config_path, data_path",
+    [
+        (
+            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
+            TWITTER_COMPLAINTS_DATA_JSON,
+        ),
+    ],
+)
+def test_process_dataconfig_multiple_files_with_globbing(data_config_path, data_path):
+    """Ensure that dataset files matching a globbing pattern are formatted and validated correctly based on the arguments passed in the config file."""
+    with open(data_config_path, "r") as f:
+        yaml_content = yaml.safe_load(f)
+
+    PATTERN_TWITTER_COMPLAINTS_DATA_JSON = data_path.replace(
+        "twitter_complaints_small.json", "*small*.json"
+    )
+    yaml_content["datasets"][0]["data_paths"][0] = PATTERN_TWITTER_COMPLAINTS_DATA_JSON
+
+    # Modify dataset_text_field and template according to dataset
+    formatted_dataset_field = "formatted_data_field"
+    template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
+    yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
+        "dataset_text_field": formatted_dataset_field,
+        "template": template,
+    }
+
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".yaml"
+    ) as temp_yaml_file:
+        yaml.dump(yaml_content, temp_yaml_file)
+        temp_yaml_file_path = temp_yaml_file.name
+        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
+
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)
+    assert isinstance(train_set, Dataset)
+    assert formatted_dataset_field in set(train_set.column_names)
+
+    data_len = sum(
+        len(json.load(open(file, "r")))
+        for file in glob.glob(PATTERN_TWITTER_COMPLAINTS_DATA_JSON)
+    )
+    assert len(train_set) == data_len
+
+
 @pytest.mark.parametrize(
     "datafiles, datasetconfigname",
     [
diff --git a/tuning/data/data_config.py b/tuning/data/data_config.py
index 4da83d720..ed8af3f8c 100644
--- a/tuning/data/data_config.py
+++ b/tuning/data/data_config.py
@@ -79,7 +79,6 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig:
         c.data_paths = []
         for p in data_paths:
             assert isinstance(p, str), f"path {p} should be of the type string"
-            assert os.path.exists(p), f"data_paths {p} does not exist"
             if not os.path.isabs(p):
                 _p = os.path.abspath(p)
                 logging.warning(

From 70eccb14ed2fbb15a9ca6c617c270f6887873bc4 Mon Sep 17 00:00:00 2001
From: Abhishek
Date: Thu, 12 Dec 2024 14:12:01 -0500
Subject: [PATCH 6/6] Fix globbing test to use JSONL data and correct file extension checks

Signed-off-by: Abhishek

---
 tests/data/test_data_preprocessing_utils.py | 15 +++++++++------
 tuning/utils/utils.py                       |  4 ++--
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py
index 2688e6642..4619392ca 100644
--- a/tests/data/test_data_preprocessing_utils.py
+++ b/tests/data/test_data_preprocessing_utils.py
@@ -15,6 +15,7 @@
 # Standard
 import glob
 import json
+import os
 import tempfile
 
 # Third Party
@@ -619,7 +620,9 @@ def test_process_dataconfig_multiple_files(data_config_path, list_data_path):
     [
         (
             DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
-            TWITTER_COMPLAINTS_DATA_JSON,
+            os.path.join(
+                os.path.dirname(TWITTER_COMPLAINTS_DATA_JSONL), "*small*.jsonl"
+            ),
         ),
     ],
 )
@@ -628,10 +631,10 @@ def test_process_dataconfig_multiple_files_with_globbing(data_config_path, data_
     with open(data_config_path, "r") as f:
         yaml_content = yaml.safe_load(f)
 
-    PATTERN_TWITTER_COMPLAINTS_DATA_JSON = data_path.replace(
-        "twitter_complaints_small.json", "*small*.json"
+    PATTERN_TWITTER_COMPLAINTS_DATA_JSONL = data_path.replace(
+        "twitter_complaints_small.jsonl", "*small*.jsonl"
     )
-    yaml_content["datasets"][0]["data_paths"][0] = PATTERN_TWITTER_COMPLAINTS_DATA_JSON
+    yaml_content["datasets"][0]["data_paths"][0] = PATTERN_TWITTER_COMPLAINTS_DATA_JSONL
 
     # Modify dataset_text_field and template according to dataset
     formatted_dataset_field = "formatted_data_field"
@@ -654,8 +657,8 @@ def test_process_dataconfig_multiple_files_with_globbing(data_config_path, data_
     assert formatted_dataset_field in set(train_set.column_names)
 
     data_len = sum(
-        len(json.load(open(file, "r")))
-        for file in glob.glob(PATTERN_TWITTER_COMPLAINTS_DATA_JSON)
+        sum(1 for _ in open(file, "r"))  # Count lines in each JSONL file
+        for file in glob.glob(PATTERN_TWITTER_COMPLAINTS_DATA_JSONL)
     )
     assert len(train_set) == data_len
 
diff --git a/tuning/utils/utils.py b/tuning/utils/utils.py
index 6eef6b2cf..b6c6a38b0 100644
--- a/tuning/utils/utils.py
+++ b/tuning/utils/utils.py
@@ -31,9 +31,9 @@ def get_loader_for_filepath(file_path: str) -> str:
         return "text"
     if ext in (".json", ".jsonl"):
         return "json"
-    if ext in (".arrow"):
+    if ext in (".arrow",):
         return "arrow"
-    if ext in (".parquet"):
+    if ext in (".parquet",):
         return "parquet"
     return ext
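
Reviewer note (not part of any commit above): taken together, this series lets a data config point one dataset at several concrete files or at a single glob pattern. The sketch below mirrors test_process_dataconfig_multiple_files_with_globbing from PATCH 5/6 and 6/6; it assumes the same imports and constants as that test module (configs, _process_dataconfig_file, AutoTokenizer, MODEL_NAME, DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL) and is illustrative rather than a definitive API reference.

import os
import tempfile

import yaml

# Start from the repo's custom-template data config and swap in a glob pattern;
# after PATCH 5/6, data_paths entries no longer have to be existing files.
with open(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, "r") as f:
    yaml_content = yaml.safe_load(f)
yaml_content["datasets"][0]["data_paths"] = [
    os.path.join(os.path.dirname(TWITTER_COMPLAINTS_DATA_JSONL), "*small*.jsonl")
]
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
    "dataset_text_field": "formatted_data_field",
    "template": "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}",
}

# Write the modified config out and process it exactly as the tests do.
with tempfile.NamedTemporaryFile("w", delete=False, suffix=".yaml") as tmp:
    yaml.dump(yaml_content, tmp)
    data_args = configs.DataArguments(data_config_path=tmp.name)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
train_set, _, _ = _process_dataconfig_file(data_args, tokenizer)
# Every *small*.jsonl file in the directory contributes rows to one train set.
print(len(train_set), train_set.column_names)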