Skip to content

Commit

Permalink
test: multiple datasets with multiple datafiles column names
Browse files Browse the repository at this point in the history
Signed-off-by: Will Johnson <mwjohnson728@gmail.com>
  • Loading branch information
willmj committed Dec 11, 2024
1 parent 3fe7425 commit e89002d
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions tests/data/test_data_preprocessing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,50 @@ def test_process_dataconfig_multiple_files(data_config_path, list_data_path):
assert formatted_dataset_field in set(train_set.column_names)


@pytest.mark.parametrize(
    "datafiles, datasetconfigname",
    [
        (
            [
                [
                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
                ],
                [
                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
                ],
                [
                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
                    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
                ],
            ],
            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
        ),
    ],
)
def test_process_dataconfig_multiple_datasets_datafiles(datafiles, datasetconfigname):
    """Ensure multiple datasets with multiple data files are processed correctly.

    Args:
        datafiles: One list of data file paths per dataset. The list at
            index ``i`` replaces ``data_paths`` of the ``i``-th dataset entry
            in the YAML config.
        datasetconfigname: Path to the YAML data config used as a template.

    The test passes when the processed train split is a ``Dataset`` whose
    columns are exactly ``input_ids``, ``attention_mask`` and ``labels``.
    """
    # Local import so this block does not depend on the file-level imports.
    import os

    with open(datasetconfigname, "r", encoding="utf-8") as f:
        yaml_content = yaml.safe_load(f)

    # Override the data paths of each dataset in the template. Indexing (as
    # opposed to zip) keeps the failure loud: an IndexError is raised if the
    # config declares fewer datasets than the parametrization provides.
    for index, data_paths in enumerate(datafiles):
        yaml_content["datasets"][index]["data_paths"] = data_paths

    # Write the patched config into a temporary directory so it is always
    # removed afterwards, even if processing or an assertion fails.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_yaml_file_path = os.path.join(temp_dir, "data_config.yaml")
        with open(temp_yaml_file_path, "w", encoding="utf-8") as temp_yaml_file:
            yaml.dump(yaml_content, temp_yaml_file)

        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # The config file must still exist while it is being processed, so
        # this call stays inside the temporary-directory context.
        (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)

    assert isinstance(train_set, Dataset)
    column_names = {"input_ids", "attention_mask", "labels"}
    assert set(train_set.column_names) == column_names


@pytest.mark.parametrize(
"data_config_path, list_data_path",
[
Expand Down

0 comments on commit e89002d

Please sign in to comment.