From e89002de928f419601038b3b172c99a3d0b62429 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 11 Dec 2024 15:48:46 -0500 Subject: [PATCH] test: multiple datasets with multiple datafiles column names Signed-off-by: Will Johnson --- tests/data/test_data_preprocessing_utils.py | 44 +++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index fbb73f649..937fc5a70 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -607,6 +607,50 @@ def test_process_dataconfig_multiple_files(data_config_path, list_data_path): assert formatted_dataset_field in set(train_set.column_names) +@pytest.mark.parametrize( + "datafiles, datasetconfigname", + [ + ( + [ + [ + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, + ], + [ + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, + ], + [ + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, + ], + ], + DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, + ), + ], +) +def test_process_dataconfig_multiple_datasets_datafiles(datafiles, datasetconfigname): + """Ensure that multiple datasets with multiple files are formatted and validated correctly.""" + with open(datasetconfigname, "r") as f: + yaml_content = yaml.safe_load(f) + yaml_content["datasets"][0]["data_paths"] = datafiles[0] + yaml_content["datasets"][1]["data_paths"] = datafiles[1] + yaml_content["datasets"][2]["data_paths"] = datafiles[2] + + with tempfile.NamedTemporaryFile( + "w", delete=False, suffix=".yaml" + ) as temp_yaml_file: + yaml.dump(yaml_content, temp_yaml_file) + temp_yaml_file_path = temp_yaml_file.name + data_args = configs.DataArguments(data_config_path=temp_yaml_file_path) + + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer) + assert isinstance(train_set, Dataset) + column_names = set(["input_ids", "attention_mask", "labels"]) + assert set(train_set.column_names) == column_names + + @pytest.mark.parametrize( "data_config_path, list_data_path", [