test: Add unit tests to test multiple files in single/multiple datasets (#412)

* test: Add unit tests to test multiple files in single/multiple datasets

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* test: Add e2e unit test for multiple datasets with multiple files

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* test: multiple datasets with multiple datafiles column names

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

* PR changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR Changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix: fmt

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* Merge test_process_dataconfig_multiple_files_varied_data_formats

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

---------

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Signed-off-by: Will Johnson <mwjohnson728@gmail.com>
Co-authored-by: Will Johnson <mwjohnson728@gmail.com>
Abhishek-TAMU and willmj authored Dec 13, 2024
1 parent a89f76b commit 4441948
Showing 15 changed files with 380 additions and 60 deletions.
32 changes: 21 additions & 11 deletions tests/artifacts/testdata/__init__.py
@@ -19,37 +19,47 @@

### Constants used for data
DATA_DIR = os.path.join(os.path.dirname(__file__))
JSON_DATA_DIR = os.path.join(os.path.dirname(__file__), "json")
JSONL_DATA_DIR = os.path.join(os.path.dirname(__file__), "jsonl")
ARROW_DATA_DIR = os.path.join(os.path.dirname(__file__), "arrow")
PARQUET_DATA_DIR = os.path.join(os.path.dirname(__file__), "parquet")
TWITTER_COMPLAINTS_DATA_JSON = os.path.join(DATA_DIR, "twitter_complaints_small.json")
TWITTER_COMPLAINTS_DATA_JSONL = os.path.join(DATA_DIR, "twitter_complaints_small.jsonl")
TWITTER_COMPLAINTS_DATA_ARROW = os.path.join(DATA_DIR, "twitter_complaints_small.arrow")

TWITTER_COMPLAINTS_DATA_JSON = os.path.join(
JSON_DATA_DIR, "twitter_complaints_small.json"
)
TWITTER_COMPLAINTS_DATA_JSONL = os.path.join(
JSONL_DATA_DIR, "twitter_complaints_small.jsonl"
)
TWITTER_COMPLAINTS_DATA_ARROW = os.path.join(
ARROW_DATA_DIR, "twitter_complaints_small.arrow"
)
TWITTER_COMPLAINTS_DATA_PARQUET = os.path.join(
PARQUET_DATA_DIR, "twitter_complaints_small.parquet"
)
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON = os.path.join(
DATA_DIR, "twitter_complaints_input_output.json"
JSON_DATA_DIR, "twitter_complaints_input_output.json"
)
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL = os.path.join(
DATA_DIR, "twitter_complaints_input_output.jsonl"
JSONL_DATA_DIR, "twitter_complaints_input_output.jsonl"
)
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW = os.path.join(
DATA_DIR, "twitter_complaints_input_output.arrow"
ARROW_DATA_DIR, "twitter_complaints_input_output.arrow"
)
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET = os.path.join(
PARQUET_DATA_DIR, "twitter_complaints_input_output.parquet"
)
TWITTER_COMPLAINTS_TOKENIZED_JSON = os.path.join(
DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json"
JSON_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json"
)
TWITTER_COMPLAINTS_TOKENIZED_JSONL = os.path.join(
DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl"
JSONL_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl"
)
TWITTER_COMPLAINTS_TOKENIZED_ARROW = os.path.join(
DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow"
ARROW_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow"
)
TWITTER_COMPLAINTS_TOKENIZED_PARQUET = os.path.join(
PARQUET_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.parquet"
)
EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json")
MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json")
EMPTY_DATA = os.path.join(JSON_DATA_DIR, "empty_data.json")
MALFORMATTED_DATA = os.path.join(JSON_DATA_DIR, "malformatted_data.json")
MODEL_NAME = "Maykeye/TinyLLama-v0"
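
Not part of the diff: a minimal sketch of the fixture layout the updated constants assume, with each data format moved into its own subdirectory. The file placement is inferred from the constants above, so treat the exact file lists as assumptions.

import os

# Assumed layout under tests/artifacts/testdata/ after this change; the
# per-format subdirectory names come from the *_DATA_DIR constants above.
DATA_DIR = os.path.dirname(__file__)

expected_layout = {
    "json": ["twitter_complaints_small.json", "empty_data.json"],
    "jsonl": ["twitter_complaints_small.jsonl"],
    "arrow": ["twitter_complaints_small.arrow"],
    "parquet": ["twitter_complaints_small.parquet"],
}

for subdir, names in expected_layout.items():
    for name in names:
        path = os.path.join(DATA_DIR, subdir, name)
        assert os.path.isfile(path), f"missing test fixture: {path}"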
File renamed without changes.
310 changes: 262 additions & 48 deletions tests/data/test_data_preprocessing_utils.py
@@ -432,9 +432,11 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_ARROW),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
@@ -447,6 +449,10 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
),
],
)
def test_process_dataconfig_file(data_config_path, data_path):
@@ -491,6 +497,262 @@ def test_process_dataconfig_file(data_config_path, data_path):
assert formatted_dataset_field in set(train_set.column_names)


@pytest.mark.parametrize(
"data_config_path, data_path_list",
[
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
[TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_JSON],
),
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
[
TWITTER_COMPLAINTS_DATA_JSONL,
TWITTER_COMPLAINTS_DATA_JSONL,
TWITTER_COMPLAINTS_DATA_JSONL,
],
),
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
[TWITTER_COMPLAINTS_DATA_PARQUET, TWITTER_COMPLAINTS_DATA_PARQUET],
),
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
[TWITTER_COMPLAINTS_DATA_ARROW, TWITTER_COMPLAINTS_DATA_ARROW],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[TWITTER_COMPLAINTS_TOKENIZED_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[TWITTER_COMPLAINTS_TOKENIZED_JSONL, TWITTER_COMPLAINTS_TOKENIZED_JSONL],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[
TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[TWITTER_COMPLAINTS_TOKENIZED_ARROW, TWITTER_COMPLAINTS_TOKENIZED_ARROW],
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
],
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
],
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
],
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
],
),
],
)
def test_process_dataconfig_multiple_files(data_config_path, data_path_list):
"""Ensure that datasets with multiple files are formatted and validated correctly based on the arguments passed in config file."""
with open(data_config_path, "r") as f:
yaml_content = yaml.safe_load(f)
yaml_content["datasets"][0]["data_paths"] = data_path_list
datasets_name = yaml_content["datasets"][0]["name"]

# Modify input_field_name and output_field_name according to dataset
if datasets_name == "text_dataset_input_output_masking":
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
"input_field_name": "input",
"output_field_name": "output",
}

# Modify dataset_text_field and template according to dataset
formatted_dataset_field = "formatted_data_field"
if datasets_name == "apply_custom_data_template":
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
"dataset_text_field": formatted_dataset_field,
"template": template,
}

with tempfile.NamedTemporaryFile(
"w", delete=False, suffix=".yaml"
) as temp_yaml_file:
yaml.dump(yaml_content, temp_yaml_file)
temp_yaml_file_path = temp_yaml_file.name
data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
(train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)
assert isinstance(train_set, Dataset)
if datasets_name == "text_dataset_input_output_masking":
column_names = set(["input_ids", "attention_mask", "labels"])
assert set(train_set.column_names) == column_names
elif datasets_name == "pretokenized_dataset":
assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
elif datasets_name == "apply_custom_data_template":
assert formatted_dataset_field in set(train_set.column_names)
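
Not part of the diff: for reference, a sketch of the dict this test dumps to the temporary YAML in the apply_custom_data_template case. The keys mirror those the test accesses; the handler name and the literal file paths are illustrative assumptions, not the contents of the checked-in YAML.

import yaml

yaml_content = {
    "datasets": [
        {
            "name": "apply_custom_data_template",
            "data_paths": [
                # Duplicated paths mimic the multi-file parametrization above.
                "tests/artifacts/testdata/json/twitter_complaints_small.json",
                "tests/artifacts/testdata/json/twitter_complaints_small.json",
            ],
            "data_handlers": [
                {
                    "name": "apply_custom_data_template",  # assumed handler name
                    "arguments": {
                        "fn_kwargs": {
                            "dataset_text_field": "formatted_data_field",
                            "template": "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}",
                        }
                    },
                }
            ],
        }
    ]
}

print(yaml.safe_dump(yaml_content, sort_keys=False))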


@pytest.mark.parametrize(
"datafiles, datasetconfigname",
[
(
[
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
],
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
],
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
],
],
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
),
],
)
def test_process_dataconfig_multiple_datasets_datafiles_sampling(
datafiles, datasetconfigname
):
"""Ensure that multiple datasets with multiple files are formatted and validated correctly."""
with open(datasetconfigname, "r") as f:
yaml_content = yaml.safe_load(f)
yaml_content["datasets"][0]["data_paths"] = datafiles[0]
yaml_content["datasets"][1]["data_paths"] = datafiles[1]
yaml_content["datasets"][2]["data_paths"] = datafiles[2]

with tempfile.NamedTemporaryFile(
"w", delete=False, suffix=".yaml"
) as temp_yaml_file:
yaml.dump(yaml_content, temp_yaml_file)
temp_yaml_file_path = temp_yaml_file.name
data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
TRAIN_ARGS = configs.TrainingArguments(
packing=False,
max_seq_length=1024,
output_dir="tmp",
)
(train_set, eval_set, _, _, _, _) = process_dataargs(
data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
)

assert isinstance(train_set, Dataset)
if eval_set:
assert isinstance(eval_set, Dataset)

assert set(["input_ids", "attention_mask", "labels"]).issubset(
set(train_set.column_names)
)
if eval_set:
assert set(["input_ids", "attention_mask", "labels"]).issubset(
set(eval_set.column_names)
)
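
Not part of the diff: a hypothetical sketch of the multi-dataset config after the test overrides each dataset's data_paths. The three-dataset shape and the per-format file lists follow the parametrization above; the dataset names and the sampling key are assumptions suggested by the config's filename, not confirmed by the diff.

yaml_content = {
    "datasets": [
        {
            "name": "dataset_1",  # assumed
            "sampling": 0.3,  # assumed key and value
            "data_paths": [
                "parquet/twitter_complaints_input_output.parquet",
                "parquet/twitter_complaints_input_output.parquet",
            ],
        },
        {
            "name": "dataset_2",  # assumed
            "sampling": 0.4,  # assumed key and value
            "data_paths": [
                "json/twitter_complaints_input_output.json",
                "json/twitter_complaints_input_output.json",
            ],
        },
        {
            "name": "dataset_3",  # assumed
            "sampling": 0.3,  # assumed key and value
            "data_paths": [
                "jsonl/twitter_complaints_input_output.jsonl",
                "jsonl/twitter_complaints_input_output.jsonl",
            ],
        },
    ]
}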


@pytest.mark.parametrize(
"data_config_path, data_path_list",
[
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
[TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_PARQUET],
),
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
[TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[
TWITTER_COMPLAINTS_TOKENIZED_JSONL,
TWITTER_COMPLAINTS_TOKENIZED_ARROW,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[
TWITTER_COMPLAINTS_TOKENIZED_JSON,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
],
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
],
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
[TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_JSON],
),
],
)
def test_process_dataconfig_multiple_files_varied_data_formats(
data_config_path, data_path_list
):
"""Ensure that datasets with multiple files with different formats raise assertion error when passed in config file."""
with open(data_config_path, "r") as f:
yaml_content = yaml.safe_load(f)
yaml_content["datasets"][0]["data_paths"] = data_path_list
datasets_name = yaml_content["datasets"][0]["name"]

# Modify input_field_name and output_field_name according to dataset
if datasets_name == "text_dataset_input_output_masking":
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
"input_field_name": "input",
"output_field_name": "output",
}

# Modify dataset_text_field and template according to dataset
formatted_dataset_field = "formatted_data_field"
if datasets_name == "apply_custom_data_template":
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
"dataset_text_field": formatted_dataset_field,
"template": template,
}

with tempfile.NamedTemporaryFile(
"w", delete=False, suffix=".yaml"
) as temp_yaml_file:
yaml.dump(yaml_content, temp_yaml_file)
temp_yaml_file_path = temp_yaml_file.name
data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
with pytest.raises(
(AssertionError, datasets.exceptions.DatasetGenerationCastError)
):
(_, _, _) = _process_dataconfig_file(data_args, tokenizer)
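
Not part of the diff: the expected failure above comes from HF datasets refusing to unify data files whose columns differ. A standalone sketch of the same failure mode, writing two small throwaway files with incompatible schemas to the working directory:

import json

from datasets import load_dataset
from datasets.exceptions import DatasetGenerationCastError

# One raw-text row vs. one pretokenized row: the column sets do not match.
with open("raw_text_rows.json", "w") as f:
    f.write(json.dumps({"Tweet text": "hi", "text_label": "complaint"}) + "\n")
with open("pretokenized_rows.json", "w") as f:
    f.write(json.dumps({"input_ids": [1, 2], "labels": [1, 2]}) + "\n")

try:
    load_dataset("json", data_files=["raw_text_rows.json", "pretokenized_rows.json"])
except DatasetGenerationCastError as err:
    print(f"Schema mismatch across data files: {err}")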


@pytest.mark.parametrize(
"data_args",
[
@@ -764,51 +1026,3 @@ def test_process_dataset_configs_with_sampling_error(
(_, _, _, _, _, _) = process_dataargs(
data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
)


@pytest.mark.parametrize(
"datafiles, datasetconfigname",
[
(
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
],
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
),
],
)
def test_process_dataset_configs_with_sampling(datafiles, datasetconfigname):

data_args = configs.DataArguments()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
TRAIN_ARGS = configs.TrainingArguments(
packing=False,
max_seq_length=1024,
output_dir="tmp", # Not needed but positional
)

with tempfile.NamedTemporaryFile(
"w", delete=False, suffix=".yaml"
) as temp_yaml_file:
with open(datasetconfigname, "r") as f:
data = yaml.safe_load(f)
datasets = data["datasets"]
for i in range(len(datasets)):
d = datasets[i]
d["data_paths"][0] = datafiles[i]
yaml.dump(data, temp_yaml_file)
data_args.data_config_path = temp_yaml_file.name

(train_set, eval_set, _, _, _, _) = process_dataargs(
data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
)

assert isinstance(train_set, Dataset)
if eval_set:
assert isinstance(eval_set, Dataset)

assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
if eval_set:
assert set(["input_ids", "labels"]).issubset(set(eval_set.column_names))