From ce82af11dce4a30d80317d7a12cb26f7d14ee076 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Tue, 10 Dec 2024 16:24:11 -0500 Subject: [PATCH 1/7] test: Add unit tests to test multiple files in single dataset Signed-off-by: Abhishek --- tests/artifacts/testdata/__init__.py | 32 ++- .../twitter_complaints_input_output.arrow | Bin .../twitter_complaints_small.arrow | Bin ..._tokenized_with_maykeye_tinyllama_v0.arrow | Bin .../testdata/{ => json}/empty_data.json | 0 .../{ => json}/malformatted_data.json | 0 .../twitter_complaints_input_output.json | 0 .../{ => json}/twitter_complaints_small.json | 0 ...s_tokenized_with_maykeye_tinyllama_v0.json | 0 .../twitter_complaints_input_output.jsonl | 0 .../twitter_complaints_small.jsonl | 0 ..._tokenized_with_maykeye_tinyllama_v0.jsonl | 0 tests/data/test_data_preprocessing_utils.py | 234 ++++++++++++++++++ tuning/data/data_handlers.py | 2 + 14 files changed, 257 insertions(+), 11 deletions(-) rename tests/artifacts/testdata/{ => arrow}/twitter_complaints_input_output.arrow (100%) rename tests/artifacts/testdata/{ => arrow}/twitter_complaints_small.arrow (100%) rename tests/artifacts/testdata/{ => arrow}/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow (100%) rename tests/artifacts/testdata/{ => json}/empty_data.json (100%) rename tests/artifacts/testdata/{ => json}/malformatted_data.json (100%) rename tests/artifacts/testdata/{ => json}/twitter_complaints_input_output.json (100%) rename tests/artifacts/testdata/{ => json}/twitter_complaints_small.json (100%) rename tests/artifacts/testdata/{ => json}/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json (100%) rename tests/artifacts/testdata/{ => jsonl}/twitter_complaints_input_output.jsonl (100%) rename tests/artifacts/testdata/{ => jsonl}/twitter_complaints_small.jsonl (100%) rename tests/artifacts/testdata/{ => jsonl}/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl (100%) diff --git a/tests/artifacts/testdata/__init__.py b/tests/artifacts/testdata/__init__.py index 39895f6f1..762f88ab9 100644 --- a/tests/artifacts/testdata/__init__.py +++ b/tests/artifacts/testdata/__init__.py @@ -19,37 +19,47 @@ ### Constants used for data DATA_DIR = os.path.join(os.path.dirname(__file__)) +JSON_DATA_DIR = os.path.join(os.path.dirname(__file__), "json") +JSONL_DATA_DIR = os.path.join(os.path.dirname(__file__), "jsonl") +ARROW_DATA_DIR = os.path.join(os.path.dirname(__file__), "arrow") PARQUET_DATA_DIR = os.path.join(os.path.dirname(__file__), "parquet") -TWITTER_COMPLAINTS_DATA_JSON = os.path.join(DATA_DIR, "twitter_complaints_small.json") -TWITTER_COMPLAINTS_DATA_JSONL = os.path.join(DATA_DIR, "twitter_complaints_small.jsonl") -TWITTER_COMPLAINTS_DATA_ARROW = os.path.join(DATA_DIR, "twitter_complaints_small.arrow") + +TWITTER_COMPLAINTS_DATA_JSON = os.path.join( + JSON_DATA_DIR, "twitter_complaints_small.json" +) +TWITTER_COMPLAINTS_DATA_JSONL = os.path.join( + JSONL_DATA_DIR, "twitter_complaints_small.jsonl" +) +TWITTER_COMPLAINTS_DATA_ARROW = os.path.join( + ARROW_DATA_DIR, "twitter_complaints_small.arrow" +) TWITTER_COMPLAINTS_DATA_PARQUET = os.path.join( PARQUET_DATA_DIR, "twitter_complaints_small.parquet" ) TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON = os.path.join( - DATA_DIR, "twitter_complaints_input_output.json" + JSON_DATA_DIR, "twitter_complaints_input_output.json" ) TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL = os.path.join( - DATA_DIR, "twitter_complaints_input_output.jsonl" + JSONL_DATA_DIR, "twitter_complaints_input_output.jsonl" ) 
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW = os.path.join( - DATA_DIR, "twitter_complaints_input_output.arrow" + ARROW_DATA_DIR, "twitter_complaints_input_output.arrow" ) TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET = os.path.join( PARQUET_DATA_DIR, "twitter_complaints_input_output.parquet" ) TWITTER_COMPLAINTS_TOKENIZED_JSON = os.path.join( - DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json" + JSON_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json" ) TWITTER_COMPLAINTS_TOKENIZED_JSONL = os.path.join( - DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl" + JSONL_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl" ) TWITTER_COMPLAINTS_TOKENIZED_ARROW = os.path.join( - DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow" + ARROW_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow" ) TWITTER_COMPLAINTS_TOKENIZED_PARQUET = os.path.join( PARQUET_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.parquet" ) -EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json") -MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json") +EMPTY_DATA = os.path.join(JSON_DATA_DIR, "empty_data.json") +MALFORMATTED_DATA = os.path.join(JSON_DATA_DIR, "malformatted_data.json") MODEL_NAME = "Maykeye/TinyLLama-v0" diff --git a/tests/artifacts/testdata/twitter_complaints_input_output.arrow b/tests/artifacts/testdata/arrow/twitter_complaints_input_output.arrow similarity index 100% rename from tests/artifacts/testdata/twitter_complaints_input_output.arrow rename to tests/artifacts/testdata/arrow/twitter_complaints_input_output.arrow diff --git a/tests/artifacts/testdata/twitter_complaints_small.arrow b/tests/artifacts/testdata/arrow/twitter_complaints_small.arrow similarity index 100% rename from tests/artifacts/testdata/twitter_complaints_small.arrow rename to tests/artifacts/testdata/arrow/twitter_complaints_small.arrow diff --git a/tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow b/tests/artifacts/testdata/arrow/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow similarity index 100% rename from tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow rename to tests/artifacts/testdata/arrow/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow diff --git a/tests/artifacts/testdata/empty_data.json b/tests/artifacts/testdata/json/empty_data.json similarity index 100% rename from tests/artifacts/testdata/empty_data.json rename to tests/artifacts/testdata/json/empty_data.json diff --git a/tests/artifacts/testdata/malformatted_data.json b/tests/artifacts/testdata/json/malformatted_data.json similarity index 100% rename from tests/artifacts/testdata/malformatted_data.json rename to tests/artifacts/testdata/json/malformatted_data.json diff --git a/tests/artifacts/testdata/twitter_complaints_input_output.json b/tests/artifacts/testdata/json/twitter_complaints_input_output.json similarity index 100% rename from tests/artifacts/testdata/twitter_complaints_input_output.json rename to tests/artifacts/testdata/json/twitter_complaints_input_output.json diff --git a/tests/artifacts/testdata/twitter_complaints_small.json b/tests/artifacts/testdata/json/twitter_complaints_small.json similarity index 100% rename from tests/artifacts/testdata/twitter_complaints_small.json rename to tests/artifacts/testdata/json/twitter_complaints_small.json diff --git 
a/tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json b/tests/artifacts/testdata/json/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json similarity index 100% rename from tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json rename to tests/artifacts/testdata/json/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json diff --git a/tests/artifacts/testdata/twitter_complaints_input_output.jsonl b/tests/artifacts/testdata/jsonl/twitter_complaints_input_output.jsonl similarity index 100% rename from tests/artifacts/testdata/twitter_complaints_input_output.jsonl rename to tests/artifacts/testdata/jsonl/twitter_complaints_input_output.jsonl diff --git a/tests/artifacts/testdata/twitter_complaints_small.jsonl b/tests/artifacts/testdata/jsonl/twitter_complaints_small.jsonl similarity index 100% rename from tests/artifacts/testdata/twitter_complaints_small.jsonl rename to tests/artifacts/testdata/jsonl/twitter_complaints_small.jsonl diff --git a/tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl b/tests/artifacts/testdata/jsonl/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl similarity index 100% rename from tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl rename to tests/artifacts/testdata/jsonl/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index c34204f4f..fbb73f649 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -432,9 +432,11 @@ def test_process_data_args_throws_error_where_needed(data_args, packing): (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON), (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL), (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET), + (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW), (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON), (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL), (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET), + (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_ARROW), ( DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, @@ -447,6 +449,10 @@ def test_process_data_args_throws_error_where_needed(data_args, packing): DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, ), + ( + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, + ), ], ) def test_process_dataconfig_file(data_config_path, data_path): @@ -491,6 +497,234 @@ def test_process_dataconfig_file(data_config_path, data_path): assert formatted_dataset_field in set(train_set.column_names) +@pytest.mark.parametrize( + "data_config_path, list_data_path", + [ + ( + DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, + [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_JSON], + ), + ( + DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, + [TWITTER_COMPLAINTS_DATA_JSONL, TWITTER_COMPLAINTS_DATA_JSONL], + ), + ( + DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, + [TWITTER_COMPLAINTS_DATA_PARQUET, TWITTER_COMPLAINTS_DATA_PARQUET], + ), + ( + DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, + [TWITTER_COMPLAINTS_DATA_ARROW, TWITTER_COMPLAINTS_DATA_ARROW], + ), + ( + 
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, + [TWITTER_COMPLAINTS_TOKENIZED_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON], + ), + ( + DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, + [TWITTER_COMPLAINTS_TOKENIZED_JSONL, TWITTER_COMPLAINTS_TOKENIZED_JSONL], + ), + ( + DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, + [ + TWITTER_COMPLAINTS_TOKENIZED_PARQUET, + TWITTER_COMPLAINTS_TOKENIZED_PARQUET, + ], + ), + ( + DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, + [TWITTER_COMPLAINTS_TOKENIZED_ARROW, TWITTER_COMPLAINTS_TOKENIZED_ARROW], + ), + ( + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + [ + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + ], + ), + ( + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + [ + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + ], + ), + ( + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + [ + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, + ], + ), + ( + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + [ + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, + ], + ), + ], +) +def test_process_dataconfig_multiple_files(data_config_path, list_data_path): + """Ensure that datasets with multiple files are formatted and validated correctly based on the arguments passed in config file.""" + with open(data_config_path, "r") as f: + yaml_content = yaml.safe_load(f) + yaml_content["datasets"][0]["data_paths"] = list_data_path + datasets_name = yaml_content["datasets"][0]["name"] + + # Modify input_field_name and output_field_name according to dataset + if datasets_name == "text_dataset_input_output_masking": + yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { + "input_field_name": "input", + "output_field_name": "output", + } + + # Modify dataset_text_field and template according to dataset + formatted_dataset_field = "formatted_data_field" + if datasets_name == "apply_custom_data_template": + template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" + yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { + "dataset_text_field": formatted_dataset_field, + "template": template, + } + + with tempfile.NamedTemporaryFile( + "w", delete=False, suffix=".yaml" + ) as temp_yaml_file: + yaml.dump(yaml_content, temp_yaml_file) + temp_yaml_file_path = temp_yaml_file.name + data_args = configs.DataArguments(data_config_path=temp_yaml_file_path) + + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer) + assert isinstance(train_set, Dataset) + if datasets_name == "text_dataset_input_output_masking": + column_names = set(["input_ids", "attention_mask", "labels"]) + assert set(train_set.column_names) == column_names + elif datasets_name == "pretokenized_dataset": + assert set(["input_ids", "labels"]).issubset(set(train_set.column_names)) + elif datasets_name == "apply_custom_data_template": + assert formatted_dataset_field in set(train_set.column_names) + + +@pytest.mark.parametrize( + "data_config_path, list_data_path", + [ + ( + DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, + [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_PARQUET], + ), + ( + DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, + [ + TWITTER_COMPLAINTS_TOKENIZED_JSONL, + TWITTER_COMPLAINTS_TOKENIZED_ARROW, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, + ], + ), + ( + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + [ + 
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + ], + ), + ], +) +def test_process_dataconfig_multiple_files_varied_data_formats( + data_config_path, list_data_path +): + """Ensure that datasets with multiple files with different formats raise assertion error when passed in config file.""" + with open(data_config_path, "r") as f: + yaml_content = yaml.safe_load(f) + yaml_content["datasets"][0]["data_paths"] = list_data_path + datasets_name = yaml_content["datasets"][0]["name"] + + # Modify input_field_name and output_field_name according to dataset + if datasets_name == "text_dataset_input_output_masking": + yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { + "input_field_name": "input", + "output_field_name": "output", + } + + # Modify dataset_text_field and template according to dataset + formatted_dataset_field = "formatted_data_field" + if datasets_name == "apply_custom_data_template": + template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" + yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { + "dataset_text_field": formatted_dataset_field, + "template": template, + } + + with tempfile.NamedTemporaryFile( + "w", delete=False, suffix=".yaml" + ) as temp_yaml_file: + yaml.dump(yaml_content, temp_yaml_file) + temp_yaml_file_path = temp_yaml_file.name + data_args = configs.DataArguments(data_config_path=temp_yaml_file_path) + + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + with pytest.raises(AssertionError): + (_, _, _) = _process_dataconfig_file(data_args, tokenizer) + + +@pytest.mark.parametrize( + "data_config_path, list_data_path", + [ + ( + DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, + [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON], + ), + ( + DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, + [ + TWITTER_COMPLAINTS_TOKENIZED_JSON, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + ], + ), + ( + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + [TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_JSON], + ), + ], +) +def test_process_dataconfig_multiple_files_varied_types( + data_config_path, list_data_path +): + """Ensure that datasets with multiple files with different formats raise assertion error when passed in config file.""" + with open(data_config_path, "r") as f: + yaml_content = yaml.safe_load(f) + yaml_content["datasets"][0]["data_paths"] = list_data_path + datasets_name = yaml_content["datasets"][0]["name"] + + # Modify input_field_name and output_field_name according to dataset + if datasets_name == "text_dataset_input_output_masking": + yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { + "input_field_name": "input", + "output_field_name": "output", + } + + # Modify dataset_text_field and template according to dataset + formatted_dataset_field = "formatted_data_field" + if datasets_name == "apply_custom_data_template": + template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" + yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { + "dataset_text_field": formatted_dataset_field, + "template": template, + } + + with tempfile.NamedTemporaryFile( + "w", delete=False, suffix=".yaml" + ) as temp_yaml_file: + yaml.dump(yaml_content, temp_yaml_file) + temp_yaml_file_path = temp_yaml_file.name + data_args = configs.DataArguments(data_config_path=temp_yaml_file_path) + + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + with 
pytest.raises(datasets.exceptions.DatasetGenerationCastError): + (_, _, _) = _process_dataconfig_file(data_args, tokenizer) + + @pytest.mark.parametrize( "data_args", [ diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py index f0100072b..6a821ec5c 100644 --- a/tuning/data/data_handlers.py +++ b/tuning/data/data_handlers.py @@ -90,6 +90,8 @@ def apply_dataset_formatting( dataset_text_field: str, **kwargs, ): + if dataset_text_field not in element: + raise KeyError(f"Dataset should contain {dataset_text_field} field.") return { f"{dataset_text_field}": element[f"{dataset_text_field}"] + tokenizer.eos_token } From 3fe7425146f5edad4f28aa6ffccad420b5dca46f Mon Sep 17 00:00:00 2001 From: Abhishek Date: Wed, 11 Dec 2024 13:20:22 -0500 Subject: [PATCH 2/7] e2e testing unit test for multiple datasets with multiple files Signed-off-by: Abhishek --- tests/test_sft_trainer.py | 72 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 1 deletion(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 69ccbf4fa..6f8047c1a 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -27,19 +27,27 @@ import pytest import torch import transformers +import yaml # First Party from build.utils import serialize_args from scripts.run_inference import TunedCausalLM +from tests.artifacts.predefined_data_configs import ( + DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, +) from tests.artifacts.testdata import ( EMPTY_DATA, MALFORMATTED_DATA, MODEL_NAME, + TWITTER_COMPLAINTS_DATA_ARROW, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_JSONL, + TWITTER_COMPLAINTS_DATA_PARQUET, + TWITTER_COMPLAINTS_TOKENIZED_ARROW, TWITTER_COMPLAINTS_TOKENIZED_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSONL, + TWITTER_COMPLAINTS_TOKENIZED_PARQUET, ) # Local @@ -683,6 +691,8 @@ def test_successful_lora_target_modules_default_from_main(): [ TWITTER_COMPLAINTS_DATA_JSONL, TWITTER_COMPLAINTS_DATA_JSON, + TWITTER_COMPLAINTS_DATA_PARQUET, + TWITTER_COMPLAINTS_DATA_ARROW, ], ) def test_run_causallm_ft_and_inference(dataset_path): @@ -719,7 +729,12 @@ def test_run_causallm_ft_save_with_save_model_dir_save_strategy_no(): @pytest.mark.parametrize( "dataset_path", - [TWITTER_COMPLAINTS_TOKENIZED_JSONL, TWITTER_COMPLAINTS_TOKENIZED_JSON], + [ + TWITTER_COMPLAINTS_TOKENIZED_JSONL, + TWITTER_COMPLAINTS_TOKENIZED_JSON, + TWITTER_COMPLAINTS_TOKENIZED_PARQUET, + TWITTER_COMPLAINTS_TOKENIZED_ARROW, + ], ) def test_run_causallm_ft_pretokenized(dataset_path): """Check if we can bootstrap and finetune causallm models using pretokenized data""" @@ -754,6 +769,61 @@ def test_run_causallm_ft_pretokenized(dataset_path): assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference +@pytest.mark.parametrize( + "datafiles, datasetconfigname", + [ + ( + [ + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + ], + DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, + ) + ], +) +def test_run_causallm_ft_and_inference_with_multiple_dataset( + datasetconfigname, datafiles +): + """Check if we can finetune causallm models using multiple datasets with multiple files""" + with tempfile.TemporaryDirectory() as tempdir: + data_formatting_args = copy.deepcopy(DATA_ARGS) + + # set training_data_path and response_template to none + data_formatting_args.response_template = None + data_formatting_args.training_data_path = None + + # add data_paths in data_config file + with tempfile.NamedTemporaryFile( + "w", 
delete=False, suffix=".yaml" + ) as temp_yaml_file: + with open(datasetconfigname, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + datasets = data["datasets"] + for _, d in enumerate(datasets): + d["data_paths"] = datafiles + yaml.dump(data, temp_yaml_file) + data_formatting_args.data_config_path = temp_yaml_file.name + + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + sft_trainer.train(MODEL_ARGS, data_formatting_args, train_args) + + # validate full ft configs + _validate_training(tempdir) + _, checkpoint_path = _get_latest_checkpoint_trainer_state(tempdir) + + # Load the model + loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME) + + # Run inference on the text + output_inference = loaded_model.run( + "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50 + ) + assert len(output_inference) > 0 + assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference + + ############################# Helper functions ############################# def _test_run_causallm_ft(training_args, model_args, data_args, tempdir): train_args = copy.deepcopy(training_args) From e89002de928f419601038b3b172c99a3d0b62429 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 11 Dec 2024 15:48:46 -0500 Subject: [PATCH 3/7] test: multiple datasets with multiple datafiles column names Signed-off-by: Will Johnson --- tests/data/test_data_preprocessing_utils.py | 44 +++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index fbb73f649..937fc5a70 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -607,6 +607,50 @@ def test_process_dataconfig_multiple_files(data_config_path, list_data_path): assert formatted_dataset_field in set(train_set.column_names) +@pytest.mark.parametrize( + "datafiles, datasetconfigname", + [ + ( + [ + [ + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, + ], + [ + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + ], + [ + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + ], + ], + DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, + ), + ], +) +def test_process_dataconfig_multiple_datasets_datafiles(datafiles, datasetconfigname): + """Ensure that multiple datasets with multiple files are formatted and validated correctly.""" + with open(datasetconfigname, "r") as f: + yaml_content = yaml.safe_load(f) + yaml_content["datasets"][0]["data_paths"] = datafiles[0] + yaml_content["datasets"][1]["data_paths"] = datafiles[1] + yaml_content["datasets"][2]["data_paths"] = datafiles[2] + + with tempfile.NamedTemporaryFile( + "w", delete=False, suffix=".yaml" + ) as temp_yaml_file: + yaml.dump(yaml_content, temp_yaml_file) + temp_yaml_file_path = temp_yaml_file.name + data_args = configs.DataArguments(data_config_path=temp_yaml_file_path) + + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer) + assert isinstance(train_set, Dataset) + column_names = set(["input_ids", "attention_mask", "labels"]) + assert set(train_set.column_names) == column_names + + @pytest.mark.parametrize( "data_config_path, list_data_path", [ From 4ba1c048b42436ef2b9a81333c66f5a898928ac6 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Wed, 11 Dec 2024 15:54:01 -0500 Subject: [PATCH 4/7] PR changes 
Signed-off-by: Abhishek --- .pylintrc | 2 +- tests/data/test_data_preprocessing_utils.py | 8 ++++++- tests/test_sft_trainer.py | 26 ++++++++++++++++++++- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/.pylintrc b/.pylintrc index 222bdf6cb..f54599d18 100644 --- a/.pylintrc +++ b/.pylintrc @@ -333,7 +333,7 @@ indent-string=' ' max-line-length=100 # Maximum number of lines in a module. -max-module-lines=1200 +max-module-lines=1400 # Allow the body of a class to be on the same line as the declaration if body # contains single statement. diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index 937fc5a70..45d24f698 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -506,7 +506,11 @@ def test_process_dataconfig_file(data_config_path, data_path): ), ( DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, - [TWITTER_COMPLAINTS_DATA_JSONL, TWITTER_COMPLAINTS_DATA_JSONL], + [ + TWITTER_COMPLAINTS_DATA_JSONL, + TWITTER_COMPLAINTS_DATA_JSONL, + TWITTER_COMPLAINTS_DATA_JSONL, + ], ), ( DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, @@ -529,6 +533,7 @@ def test_process_dataconfig_file(data_config_path, data_path): [ TWITTER_COMPLAINTS_TOKENIZED_PARQUET, TWITTER_COMPLAINTS_TOKENIZED_PARQUET, + TWITTER_COMPLAINTS_TOKENIZED_PARQUET, ], ), ( @@ -561,6 +566,7 @@ def test_process_dataconfig_file(data_config_path, data_path): [ TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, ], ), ], diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 6f8047c1a..1caa0392c 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -40,7 +40,10 @@ MALFORMATTED_DATA, MODEL_NAME, TWITTER_COMPLAINTS_DATA_ARROW, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_JSONL, TWITTER_COMPLAINTS_DATA_PARQUET, @@ -772,13 +775,34 @@ def test_run_causallm_ft_pretokenized(dataset_path): @pytest.mark.parametrize( "datafiles, datasetconfigname", [ + ( + [ + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + ], + DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, + ), ( [ TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, ], DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, - ) + ), + ( + [ + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, + ], + DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, + ), + ( + [ + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, + ], + DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, + ), ], ) def test_run_causallm_ft_and_inference_with_multiple_dataset( From 3fce172390951c7245bcbf81d00649f5d71c7871 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Thu, 12 Dec 2024 18:37:40 -0500 Subject: [PATCH 5/7] PR Changes Signed-off-by: Abhishek --- .pylintrc | 2 +- tests/data/test_data_preprocessing_utils.py | 90 +++++++-------------- 2 files changed, 29 insertions(+), 63 deletions(-) diff --git a/.pylintrc b/.pylintrc index f54599d18..222bdf6cb 100644 --- a/.pylintrc +++ b/.pylintrc @@ -333,7 +333,7 @@ indent-string=' ' max-line-length=100 # Maximum number of lines in a module. 
-max-module-lines=1400 +max-module-lines=1200 # Allow the body of a class to be on the same line as the declaration if body # contains single statement. diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index 45d24f698..863f6ac7a 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -23,7 +23,9 @@ import datasets import pytest import yaml - +import sys +import os +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))) # First Party from tests.artifacts.predefined_data_configs import ( DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, @@ -498,7 +500,7 @@ def test_process_dataconfig_file(data_config_path, data_path): @pytest.mark.parametrize( - "data_config_path, list_data_path", + "data_config_path, data_path_list", [ ( DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, @@ -571,11 +573,11 @@ def test_process_dataconfig_file(data_config_path, data_path): ), ], ) -def test_process_dataconfig_multiple_files(data_config_path, list_data_path): +def test_process_dataconfig_multiple_files(data_config_path, data_path_list): """Ensure that datasets with multiple files are formatted and validated correctly based on the arguments passed in config file.""" with open(data_config_path, "r") as f: yaml_content = yaml.safe_load(f) - yaml_content["datasets"][0]["data_paths"] = list_data_path + yaml_content["datasets"][0]["data_paths"] = data_path_list datasets_name = yaml_content["datasets"][0]["name"] # Modify input_field_name and output_field_name according to dataset @@ -635,7 +637,7 @@ def test_process_dataconfig_multiple_files(data_config_path, list_data_path): ), ], ) -def test_process_dataconfig_multiple_datasets_datafiles(datafiles, datasetconfigname): +def test_process_dataconfig_multiple_datasets_datafiles_sampling(datafiles, datasetconfigname): """Ensure that multiple datasets with multiple files are formatted and validated correctly.""" with open(datasetconfigname, "r") as f: yaml_content = yaml.safe_load(f) @@ -651,14 +653,26 @@ def test_process_dataconfig_multiple_datasets_datafiles(datafiles, datasetconfig data_args = configs.DataArguments(data_config_path=temp_yaml_file_path) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer) + TRAIN_ARGS = configs.TrainingArguments( + packing=False, + max_seq_length=1024, + output_dir="tmp", + ) + (train_set, eval_set, _, _, _, _) = process_dataargs( + data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS + ) + assert isinstance(train_set, Dataset) - column_names = set(["input_ids", "attention_mask", "labels"]) - assert set(train_set.column_names) == column_names + if eval_set: + assert isinstance(eval_set, Dataset) + + assert set(["input_ids", "attention_mask", "labels"]).issubset(set(train_set.column_names)) + if eval_set: + assert set(["input_ids", "attention_mask", "labels"]).issubset(set(eval_set.column_names)) @pytest.mark.parametrize( - "data_config_path, list_data_path", + "data_config_path, data_path_list", [ ( DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, @@ -682,12 +696,12 @@ def test_process_dataconfig_multiple_datasets_datafiles(datafiles, datasetconfig ], ) def test_process_dataconfig_multiple_files_varied_data_formats( - data_config_path, list_data_path + data_config_path, data_path_list ): """Ensure that datasets with multiple files with different formats raise assertion error when passed in config file.""" with open(data_config_path, "r") as f: 
yaml_content = yaml.safe_load(f) - yaml_content["datasets"][0]["data_paths"] = list_data_path + yaml_content["datasets"][0]["data_paths"] = data_path_list datasets_name = yaml_content["datasets"][0]["name"] # Modify input_field_name and output_field_name according to dataset @@ -719,7 +733,7 @@ def test_process_dataconfig_multiple_files_varied_data_formats( @pytest.mark.parametrize( - "data_config_path, list_data_path", + "data_config_path, data_path_list", [ ( DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, @@ -739,12 +753,12 @@ def test_process_dataconfig_multiple_files_varied_data_formats( ], ) def test_process_dataconfig_multiple_files_varied_types( - data_config_path, list_data_path + data_config_path, data_path_list ): """Ensure that datasets with multiple files with different formats raise assertion error when passed in config file.""" with open(data_config_path, "r") as f: yaml_content = yaml.safe_load(f) - yaml_content["datasets"][0]["data_paths"] = list_data_path + yaml_content["datasets"][0]["data_paths"] = data_path_list datasets_name = yaml_content["datasets"][0]["name"] # Modify input_field_name and output_field_name according to dataset @@ -1048,51 +1062,3 @@ def test_process_dataset_configs_with_sampling_error( (_, _, _, _, _, _) = process_dataargs( data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS ) - - -@pytest.mark.parametrize( - "datafiles, datasetconfigname", - [ - ( - [ - TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, - TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, - TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, - ], - DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, - ), - ], -) -def test_process_dataset_configs_with_sampling(datafiles, datasetconfigname): - - data_args = configs.DataArguments() - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - TRAIN_ARGS = configs.TrainingArguments( - packing=False, - max_seq_length=1024, - output_dir="tmp", # Not needed but positional - ) - - with tempfile.NamedTemporaryFile( - "w", delete=False, suffix=".yaml" - ) as temp_yaml_file: - with open(datasetconfigname, "r") as f: - data = yaml.safe_load(f) - datasets = data["datasets"] - for i in range(len(datasets)): - d = datasets[i] - d["data_paths"][0] = datafiles[i] - yaml.dump(data, temp_yaml_file) - data_args.data_config_path = temp_yaml_file.name - - (train_set, eval_set, _, _, _, _) = process_dataargs( - data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS - ) - - assert isinstance(train_set, Dataset) - if eval_set: - assert isinstance(eval_set, Dataset) - - assert set(["input_ids", "labels"]).issubset(set(train_set.column_names)) - if eval_set: - assert set(["input_ids", "labels"]).issubset(set(eval_set.column_names)) From 68a0f503ff88fc495ec87676e5e98219b83c5c84 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Thu, 12 Dec 2024 18:41:51 -0500 Subject: [PATCH 6/7] fix: fmt Signed-off-by: Abhishek --- tests/data/test_data_preprocessing_utils.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index 863f6ac7a..8edb085c0 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -23,9 +23,7 @@ import datasets import pytest import yaml -import sys -import os -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))) + # First Party from tests.artifacts.predefined_data_configs import ( DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, @@ -637,7 +635,9 @@ def 
test_process_dataconfig_multiple_files(data_config_path, data_path_list): ), ], ) -def test_process_dataconfig_multiple_datasets_datafiles_sampling(datafiles, datasetconfigname): +def test_process_dataconfig_multiple_datasets_datafiles_sampling( + datafiles, datasetconfigname +): """Ensure that multiple datasets with multiple files are formatted and validated correctly.""" with open(datasetconfigname, "r") as f: yaml_content = yaml.safe_load(f) @@ -666,9 +666,13 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(datafiles, data if eval_set: assert isinstance(eval_set, Dataset) - assert set(["input_ids", "attention_mask", "labels"]).issubset(set(train_set.column_names)) + assert set(["input_ids", "attention_mask", "labels"]).issubset( + set(train_set.column_names) + ) if eval_set: - assert set(["input_ids", "attention_mask", "labels"]).issubset(set(eval_set.column_names)) + assert set(["input_ids", "attention_mask", "labels"]).issubset( + set(eval_set.column_names) + ) @pytest.mark.parametrize( From 5905e235ff7175ef19b00bc83496a87ef702d974 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Thu, 12 Dec 2024 18:58:24 -0500 Subject: [PATCH 7/7] Merge test_process_dataconfig_multiple_files_varied_data_formats Signed-off-by: Abhishek --- tests/data/test_data_preprocessing_utils.py | 68 +++++---------------- 1 file changed, 14 insertions(+), 54 deletions(-) diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index 8edb085c0..5559ac8ec 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -682,6 +682,10 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling( DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_PARQUET], ), + ( + DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, + [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON], + ), ( DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, [ @@ -691,63 +695,17 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling( ], ), ( - DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, [ - TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, - TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + TWITTER_COMPLAINTS_TOKENIZED_JSON, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, ], ), - ], -) -def test_process_dataconfig_multiple_files_varied_data_formats( - data_config_path, data_path_list -): - """Ensure that datasets with multiple files with different formats raise assertion error when passed in config file.""" - with open(data_config_path, "r") as f: - yaml_content = yaml.safe_load(f) - yaml_content["datasets"][0]["data_paths"] = data_path_list - datasets_name = yaml_content["datasets"][0]["name"] - - # Modify input_field_name and output_field_name according to dataset - if datasets_name == "text_dataset_input_output_masking": - yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { - "input_field_name": "input", - "output_field_name": "output", - } - - # Modify dataset_text_field and template according to dataset - formatted_dataset_field = "formatted_data_field" - if datasets_name == "apply_custom_data_template": - template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}" - yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = { - "dataset_text_field": formatted_dataset_field, - "template": template, - } - - with tempfile.NamedTemporaryFile( - "w", delete=False, suffix=".yaml" - ) as 
temp_yaml_file: - yaml.dump(yaml_content, temp_yaml_file) - temp_yaml_file_path = temp_yaml_file.name - data_args = configs.DataArguments(data_config_path=temp_yaml_file_path) - - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - with pytest.raises(AssertionError): - (_, _, _) = _process_dataconfig_file(data_args, tokenizer) - - -@pytest.mark.parametrize( - "data_config_path, data_path_list", - [ - ( - DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, - [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON], - ), ( - DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, [ - TWITTER_COMPLAINTS_TOKENIZED_JSON, - TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, ], ), ( @@ -756,7 +714,7 @@ def test_process_dataconfig_multiple_files_varied_data_formats( ), ], ) -def test_process_dataconfig_multiple_files_varied_types( +def test_process_dataconfig_multiple_files_varied_data_formats( data_config_path, data_path_list ): """Ensure that datasets with multiple files with different formats raise assertion error when passed in config file.""" @@ -789,7 +747,9 @@ def test_process_dataconfig_multiple_files_varied_types( data_args = configs.DataArguments(data_config_path=temp_yaml_file_path) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - with pytest.raises(datasets.exceptions.DatasetGenerationCastError): + with pytest.raises( + (AssertionError, datasets.exceptions.DatasetGenerationCastError) + ): (_, _, _) = _process_dataconfig_file(data_args, tokenizer)
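
Note on the pattern this series exercises end to end: one dataset entry in a data_config YAML lists several files of the same format under "data_paths", and preprocessing concatenates them before the data handlers run. The sketch below distills that flow outside pytest. The module paths, the handler name, and the sample data files are assumptions for illustration; the "arguments"/"fn_kwargs" nesting, the DataArguments/TrainingArguments fields, and the 6-tuple returned by process_dataargs are taken from the tests above.

import tempfile

import yaml
from transformers import AutoTokenizer

from tuning.config import configs  # assumed module path
from tuning.data.setup_dataprocessor import process_dataargs  # assumed module path

config = {
    "datasets": [
        {
            "name": "apply_custom_data_template",
            # Multiple files, all of one format; mixing formats or schemas is
            # exactly what the negative tests above assert against.
            "data_paths": [
                "data/twitter_complaints_part1.json",  # hypothetical files
                "data/twitter_complaints_part2.json",
            ],
            "data_handlers": [
                {
                    "name": "apply_custom_data_formatting_template",  # assumed handler name
                    "arguments": {
                        "fn_kwargs": {
                            "dataset_text_field": "formatted_data_field",
                            "template": "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}",
                        },
                    },
                }
            ],
        }
    ]
}

with tempfile.NamedTemporaryFile("w", delete=False, suffix=".yaml") as f:
    yaml.dump(config, f)
    data_args = configs.DataArguments(data_config_path=f.name)

tokenizer = AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0")
train_args = configs.TrainingArguments(
    packing=False, max_seq_length=1024, output_dir="tmp"
)
train_set, eval_set, *_ = process_dataargs(
    data_args=data_args, tokenizer=tokenizer, train_args=train_args
)
print(train_set.column_names)  # "formatted_data_field" should be among the columns

When files in one "data_paths" list disagree in format or in column schema, the flow above is what surfaces the AssertionError or datasets.exceptions.DatasetGenerationCastError that the merged negative test expects.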
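
For reference, the guard added to apply_dataset_formatting in PATCH 1/7 can be exercised on its own. The handler body below matches the patch; the signature is reconstructed (only its tail is visible in the hunk) and the driver lines are illustrative.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0")

def apply_dataset_formatting(element, tokenizer, dataset_text_field, **kwargs):
    # Fail fast when a row (e.g. from one of several concatenated files) is
    # missing the configured text column, rather than surfacing a bare
    # KeyError from deep inside datasets' map machinery.
    if dataset_text_field not in element:
        raise KeyError(f"Dataset should contain {dataset_text_field} field.")
    return {
        f"{dataset_text_field}": element[f"{dataset_text_field}"] + tokenizer.eos_token
    }

print(apply_dataset_formatting({"output": "complaint"}, tokenizer, "output"))
# apply_dataset_formatting({"text": "complaint"}, tokenizer, "output")  # raises KeyError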