test: Add unit tests to test multiple files in single dataset #412

Merged
tests/artifacts/testdata/__init__.py (21 additions & 11 deletions)
@@ -19,37 +19,47 @@

 ### Constants used for data
 DATA_DIR = os.path.join(os.path.dirname(__file__))
+JSON_DATA_DIR = os.path.join(os.path.dirname(__file__), "json")
Comment from dushyantbehl (Contributor), Dec 12, 2024:
Thanks for doing this segregation.

+JSONL_DATA_DIR = os.path.join(os.path.dirname(__file__), "jsonl")
+ARROW_DATA_DIR = os.path.join(os.path.dirname(__file__), "arrow")
 PARQUET_DATA_DIR = os.path.join(os.path.dirname(__file__), "parquet")
-TWITTER_COMPLAINTS_DATA_JSON = os.path.join(DATA_DIR, "twitter_complaints_small.json")
-TWITTER_COMPLAINTS_DATA_JSONL = os.path.join(DATA_DIR, "twitter_complaints_small.jsonl")
-TWITTER_COMPLAINTS_DATA_ARROW = os.path.join(DATA_DIR, "twitter_complaints_small.arrow")

+TWITTER_COMPLAINTS_DATA_JSON = os.path.join(
+    JSON_DATA_DIR, "twitter_complaints_small.json"
+)
+TWITTER_COMPLAINTS_DATA_JSONL = os.path.join(
+    JSONL_DATA_DIR, "twitter_complaints_small.jsonl"
+)
+TWITTER_COMPLAINTS_DATA_ARROW = os.path.join(
+    ARROW_DATA_DIR, "twitter_complaints_small.arrow"
+)
 TWITTER_COMPLAINTS_DATA_PARQUET = os.path.join(
     PARQUET_DATA_DIR, "twitter_complaints_small.parquet"
 )
 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON = os.path.join(
-    DATA_DIR, "twitter_complaints_input_output.json"
+    JSON_DATA_DIR, "twitter_complaints_input_output.json"
 )
 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL = os.path.join(
-    DATA_DIR, "twitter_complaints_input_output.jsonl"
+    JSONL_DATA_DIR, "twitter_complaints_input_output.jsonl"
 )
 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW = os.path.join(
-    DATA_DIR, "twitter_complaints_input_output.arrow"
+    ARROW_DATA_DIR, "twitter_complaints_input_output.arrow"
 )
 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET = os.path.join(
     PARQUET_DATA_DIR, "twitter_complaints_input_output.parquet"
 )
 TWITTER_COMPLAINTS_TOKENIZED_JSON = os.path.join(
-    DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json"
+    JSON_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json"
 )
 TWITTER_COMPLAINTS_TOKENIZED_JSONL = os.path.join(
-    DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl"
+    JSONL_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl"
 )
 TWITTER_COMPLAINTS_TOKENIZED_ARROW = os.path.join(
-    DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow"
+    ARROW_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow"
 )
 TWITTER_COMPLAINTS_TOKENIZED_PARQUET = os.path.join(
     PARQUET_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.parquet"
 )
-EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json")
-MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json")
+EMPTY_DATA = os.path.join(JSON_DATA_DIR, "empty_data.json")
+MALFORMATTED_DATA = os.path.join(JSON_DATA_DIR, "malformatted_data.json")
MODEL_NAME = "Maykeye/TinyLLama-v0"
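
Since every constant above now points into a format-specific subdirectory, a quick standalone sanity check (illustrative only, not part of this PR) can confirm the artifacts were actually moved along with the constants:

# Illustrative sanity check (not part of this PR): confirm the relocated
# artifacts exist where the new constants point. Assumes the repo root is
# on sys.path so the tests package is importable.
import os

from tests.artifacts.testdata import (
    TWITTER_COMPLAINTS_DATA_ARROW,
    TWITTER_COMPLAINTS_DATA_JSON,
    TWITTER_COMPLAINTS_DATA_JSONL,
    TWITTER_COMPLAINTS_DATA_PARQUET,
)

for path in (
    TWITTER_COMPLAINTS_DATA_JSON,
    TWITTER_COMPLAINTS_DATA_JSONL,
    TWITTER_COMPLAINTS_DATA_ARROW,
    TWITTER_COMPLAINTS_DATA_PARQUET,
):
    assert os.path.isfile(path), f"missing test artifact: {path}"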
tests/data/test_data_preprocessing_utils.py (234 additions & 0 deletions)
@@ -432,9 +432,11 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
+(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
+(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_ARROW),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
@@ -447,6 +449,10 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
),
+(
+DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+),
],
)
def test_process_dataconfig_file(data_config_path, data_path):
@@ -491,6 +497,234 @@ def test_process_dataconfig_file(data_config_path, data_path):
assert formatted_dataset_field in set(train_set.column_names)


Comment from a Contributor:

nit: could we change the variable name list_data_path to data_path_list?

@pytest.mark.parametrize(
"data_config_path, list_data_path",
[
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
[TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_JSON],
),
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
[TWITTER_COMPLAINTS_DATA_JSONL, TWITTER_COMPLAINTS_DATA_JSONL],
),
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
[TWITTER_COMPLAINTS_DATA_PARQUET, TWITTER_COMPLAINTS_DATA_PARQUET],
),
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
[TWITTER_COMPLAINTS_DATA_ARROW, TWITTER_COMPLAINTS_DATA_ARROW],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[TWITTER_COMPLAINTS_TOKENIZED_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[TWITTER_COMPLAINTS_TOKENIZED_JSONL, TWITTER_COMPLAINTS_TOKENIZED_JSONL],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[
TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[TWITTER_COMPLAINTS_TOKENIZED_ARROW, TWITTER_COMPLAINTS_TOKENIZED_ARROW],
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
],
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
],
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
],
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
],
),
],
)
Comment from a Collaborator:

Might be worth adding a test with three files, just in case.

Reply from the Collaborator (Author):

As this test case already has multiple cases, I added a case with 3 files in the same unit test for all 3 handlers.

Reply from a Contributor:

We can reduce the number of test cases and possibly just have:

1. A mix of all three -> 1 test
2. Each dataset with multiple files, either 2 or 3, maybe in a random mix

I think 3-4 scenarios should be fine; the rest are similar anyway.

Reply from the Contributor:

Ah, I see there is one with varied data formats below... so maybe just reducing the number of tests here could also work.

Reply from the Collaborator (Author):

I have optimized the unit tests based on the comments below. If we still need to optimize more, do let me know here. @dushyantbehl

def test_process_dataconfig_multiple_files(data_config_path, list_data_path):
"""Ensure that datasets with multiple files are formatted and validated correctly based on the arguments passed in config file."""
with open(data_config_path, "r") as f:
yaml_content = yaml.safe_load(f)
yaml_content["datasets"][0]["data_paths"] = list_data_path
datasets_name = yaml_content["datasets"][0]["name"]

# Modify input_field_name and output_field_name according to dataset
if datasets_name == "text_dataset_input_output_masking":
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
"input_field_name": "input",
"output_field_name": "output",
}

# Modify dataset_text_field and template according to dataset
formatted_dataset_field = "formatted_data_field"
if datasets_name == "apply_custom_data_template":
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
"dataset_text_field": formatted_dataset_field,
"template": template,
}

with tempfile.NamedTemporaryFile(
"w", delete=False, suffix=".yaml"
) as temp_yaml_file:
yaml.dump(yaml_content, temp_yaml_file)
temp_yaml_file_path = temp_yaml_file.name
data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
(train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)
assert isinstance(train_set, Dataset)
if datasets_name == "text_dataset_input_output_masking":
column_names = set(["input_ids", "attention_mask", "labels"])
assert set(train_set.column_names) == column_names
elif datasets_name == "pretokenized_dataset":
assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
elif datasets_name == "apply_custom_data_template":
assert formatted_dataset_field in set(train_set.column_names)
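
For context, the data config this test mutates has roughly the following shape, inferred from the keys the test touches (a hypothetical sketch; the real YAML lives at the DATA_CONFIG_*_YAML paths, and the handler name shown is an assumption):

# Hypothetical sketch of the data-config structure mutated above, inferred
# from the keys the test accesses. Only "data_paths", "name", and
# "data_handlers[0].arguments.fn_kwargs" are grounded in the test itself.
yaml_content = {
    "datasets": [
        {
            "name": "apply_custom_data_template",
            "data_paths": [
                "twitter_complaints_small.json",
                "twitter_complaints_small.json",
            ],
            "data_handlers": [
                {
                    "name": "apply_custom_data_template",  # assumed name
                    "arguments": {
                        "fn_kwargs": {
                            "dataset_text_field": "formatted_data_field",
                            "template": "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}",
                        }
                    },
                }
            ],
        }
    ]
}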


@pytest.mark.parametrize(
"data_config_path, list_data_path",
[
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
[TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_PARQUET],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[
TWITTER_COMPLAINTS_TOKENIZED_JSONL,
TWITTER_COMPLAINTS_TOKENIZED_ARROW,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
],
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
],
),
],
)
def test_process_dataconfig_multiple_files_varied_data_formats(
data_config_path, list_data_path
):
"""Ensure that datasets with multiple files with different formats raise assertion error when passed in config file."""
with open(data_config_path, "r") as f:
yaml_content = yaml.safe_load(f)
yaml_content["datasets"][0]["data_paths"] = list_data_path
datasets_name = yaml_content["datasets"][0]["name"]

# Modify input_field_name and output_field_name according to dataset
if datasets_name == "text_dataset_input_output_masking":
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
"input_field_name": "input",
"output_field_name": "output",
}

# Modify dataset_text_field and template according to dataset
formatted_dataset_field = "formatted_data_field"
if datasets_name == "apply_custom_data_template":
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
"dataset_text_field": formatted_dataset_field,
"template": template,
}

with tempfile.NamedTemporaryFile(
"w", delete=False, suffix=".yaml"
) as temp_yaml_file:
yaml.dump(yaml_content, temp_yaml_file)
temp_yaml_file_path = temp_yaml_file.name
data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
with pytest.raises(AssertionError):
(_, _, _) = _process_dataconfig_file(data_args, tokenizer)
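
The AssertionError presumably comes from a uniform-format guard in the data loader; a minimal sketch of that kind of check (assumed, not the repo's actual implementation) looks like:

# Assumed sketch of the kind of guard that raises here; not the actual
# code inside _process_dataconfig_file.
import os

def assert_uniform_format(data_paths):
    # All files in one dataset must share a single file extension.
    extensions = {os.path.splitext(p)[1].lower() for p in data_paths}
    assert len(extensions) == 1, f"dataset mixes file formats: {extensions}"

assert_uniform_format(["a.json", "b.parquet"])  # raises AssertionError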


@pytest.mark.parametrize(
"data_config_path, list_data_path",
[
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
[TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[
TWITTER_COMPLAINTS_TOKENIZED_JSON,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
],
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
[TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_JSON],
),
],
)
Comment from a Contributor:

Can this be combined with the above test?

Reply from the Collaborator (Author):

Yes, this is added! Thank you!

def test_process_dataconfig_multiple_files_varied_types(
data_config_path, list_data_path
):
"""Ensure that datasets with multiple files with different formats raise assertion error when passed in config file."""
with open(data_config_path, "r") as f:
yaml_content = yaml.safe_load(f)
yaml_content["datasets"][0]["data_paths"] = list_data_path
datasets_name = yaml_content["datasets"][0]["name"]

# Modify input_field_name and output_field_name according to dataset
if datasets_name == "text_dataset_input_output_masking":
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
"input_field_name": "input",
"output_field_name": "output",
}

# Modify dataset_text_field and template according to dataset
formatted_dataset_field = "formatted_data_field"
if datasets_name == "apply_custom_data_template":
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
"dataset_text_field": formatted_dataset_field,
"template": template,
}

with tempfile.NamedTemporaryFile(
"w", delete=False, suffix=".yaml"
) as temp_yaml_file:
yaml.dump(yaml_content, temp_yaml_file)
temp_yaml_file_path = temp_yaml_file.name
data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
with pytest.raises(datasets.exceptions.DatasetGenerationCastError):
(_, _, _) = _process_dataconfig_file(data_args, tokenizer)
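
Here the files share a format but not a column schema: for example, TWITTER_COMPLAINTS_DATA_JSON carries raw Tweet text / text_label columns while TWITTER_COMPLAINTS_TOKENIZED_JSON carries input_ids / labels, so the datasets library cannot cast both files into one Arrow schema. A standalone sketch of the failure mode (assumed, not from this PR) would look like:

# Assumed standalone reproduction: two JSON files with disjoint columns
# cannot be cast to a single Arrow schema by the datasets library.
import json
import tempfile

import datasets
from datasets import load_dataset

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f1:
    json.dump([{"Tweet text": "hi", "text_label": "complaint"}], f1)
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f2:
    json.dump([{"input_ids": [0, 1], "labels": [0, 1]}], f2)

try:
    load_dataset("json", data_files=[f1.name, f2.name], split="train")
except datasets.exceptions.DatasetGenerationCastError as err:
    # Expected: the second file's columns cannot be cast to the schema
    # inferred from the first file.
    print(f"incompatible schemas: {err}")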


@pytest.mark.parametrize(
"data_args",
[