-
Notifications
You must be signed in to change notification settings - Fork 48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
test: Add unit tests to test multiple files in single dataset #412
Changes from 2 commits
ce82af1
3fe7425
e89002d
4ba1c04
3fce172
68a0f50
5905e23
6f13d9a
83d0127
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -432,9 +432,11 @@ def test_process_data_args_throws_error_where_needed(data_args, packing): | |
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON), | ||
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL), | ||
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET), | ||
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW), | ||
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON), | ||
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL), | ||
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET), | ||
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_ARROW), | ||
( | ||
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, | ||
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, | ||
|
@@ -447,6 +449,10 @@ def test_process_data_args_throws_error_where_needed(data_args, packing): | |
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, | ||
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, | ||
), | ||
( | ||
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, | ||
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, | ||
), | ||
], | ||
) | ||
def test_process_dataconfig_file(data_config_path, data_path): | ||
|
@@ -491,6 +497,234 @@ def test_process_dataconfig_file(data_config_path, data_path): | |
assert formatted_dataset_field in set(train_set.column_names) | ||
|
||
|
||
# NOTE(review): comments left on this test in the PR thread, condensed:
#   * a reviewer asked to rename a variable in the parametrize argument list
#     (the suggested name is truncated in the captured page);
#   * a reviewer suggested adding a case with three files;
#   * a reviewer suggested trimming this matrix to 3-4 scenarios, since the
#     remaining cases are similar and mixed formats are covered by a separate
#     test below.
@pytest.mark.parametrize(
    "data_config_path, list_data_path",
    [
        (
            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
            [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_JSON],
        ),
        (
            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
            [TWITTER_COMPLAINTS_DATA_JSONL, TWITTER_COMPLAINTS_DATA_JSONL],
        ),
        (
            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
            [TWITTER_COMPLAINTS_DATA_PARQUET, TWITTER_COMPLAINTS_DATA_PARQUET],
        ),
        (
            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
            [TWITTER_COMPLAINTS_DATA_ARROW, TWITTER_COMPLAINTS_DATA_ARROW],
        ),
        (
            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
            [TWITTER_COMPLAINTS_TOKENIZED_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON],
        ),
        (
            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
            [TWITTER_COMPLAINTS_TOKENIZED_JSONL, TWITTER_COMPLAINTS_TOKENIZED_JSONL],
        ),
        (
            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
            [
                TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
                TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
            ],
        ),
        (
            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
            [TWITTER_COMPLAINTS_TOKENIZED_ARROW, TWITTER_COMPLAINTS_TOKENIZED_ARROW],
        ),
        (
            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
            [
                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
            ],
        ),
        (
            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
            [
                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
            ],
        ),
        (
            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
            [
                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
            ],
        ),
        (
            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
            [
                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
            ],
        ),
    ],
)
def test_process_dataconfig_multiple_files(data_config_path, list_data_path):
    """Ensure a dataset made of several same-format files is processed correctly.

    The base YAML config at ``data_config_path`` is loaded, its first dataset
    entry is pointed at ``list_data_path``, and the handler arguments are
    overridden to match the fixture columns. The modified config is written to
    a temporary file and passed to ``_process_dataconfig_file``; the resulting
    train set's columns are then checked according to the dataset name.
    """
    # Local import so this test does not depend on the module's import block.
    import os

    with open(data_config_path, "r") as f:
        yaml_content = yaml.safe_load(f)
    dataset_config = yaml_content["datasets"][0]
    dataset_config["data_paths"] = list_data_path
    datasets_name = dataset_config["name"]

    # Modify input_field_name and output_field_name according to dataset
    if datasets_name == "text_dataset_input_output_masking":
        dataset_config["data_handlers"][0]["arguments"]["fn_kwargs"] = {
            "input_field_name": "input",
            "output_field_name": "output",
        }

    # Modify dataset_text_field and template according to dataset
    formatted_dataset_field = "formatted_data_field"
    if datasets_name == "apply_custom_data_template":
        template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
        dataset_config["data_handlers"][0]["arguments"]["fn_kwargs"] = {
            "dataset_text_field": formatted_dataset_field,
            "template": template,
        }

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Write the config inside a TemporaryDirectory so it is removed once the
    # config has been processed (the previous NamedTemporaryFile used
    # delete=False and was never cleaned up).
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_yaml_file_path = os.path.join(temp_dir, "data_config.yaml")
        with open(temp_yaml_file_path, "w") as temp_yaml_file:
            yaml.dump(yaml_content, temp_yaml_file)
        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
        (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)

    assert isinstance(train_set, Dataset)
    train_columns = set(train_set.column_names)
    if datasets_name == "text_dataset_input_output_masking":
        assert train_columns == {"input_ids", "attention_mask", "labels"}
    elif datasets_name == "pretokenized_dataset":
        assert {"input_ids", "labels"}.issubset(train_columns)
    elif datasets_name == "apply_custom_data_template":
        assert formatted_dataset_field in train_columns
|
||
|
||
@pytest.mark.parametrize(
    "data_config_path, list_data_path",
    [
        (
            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
            [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_PARQUET],
        ),
        (
            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
            [
                TWITTER_COMPLAINTS_TOKENIZED_JSONL,
                TWITTER_COMPLAINTS_TOKENIZED_ARROW,
                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
            ],
        ),
        (
            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
            [
                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
            ],
        ),
    ],
)
def test_process_dataconfig_multiple_files_varied_data_formats(
    data_config_path, list_data_path
):
    """Ensure mixing file formats within one dataset raises ``AssertionError``.

    The base YAML config at ``data_config_path`` is loaded, its first dataset
    entry is pointed at ``list_data_path`` (files with differing formats), and
    the modified config is written to a temporary file that is passed to
    ``_process_dataconfig_file``.
    """
    # Local import so this test does not depend on the module's import block.
    import os

    with open(data_config_path, "r") as f:
        yaml_content = yaml.safe_load(f)
    dataset_config = yaml_content["datasets"][0]
    dataset_config["data_paths"] = list_data_path
    datasets_name = dataset_config["name"]

    # Modify input_field_name and output_field_name according to dataset
    if datasets_name == "text_dataset_input_output_masking":
        dataset_config["data_handlers"][0]["arguments"]["fn_kwargs"] = {
            "input_field_name": "input",
            "output_field_name": "output",
        }

    # Modify dataset_text_field and template according to dataset
    formatted_dataset_field = "formatted_data_field"
    if datasets_name == "apply_custom_data_template":
        template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
        dataset_config["data_handlers"][0]["arguments"]["fn_kwargs"] = {
            "dataset_text_field": formatted_dataset_field,
            "template": template,
        }

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Write the config inside a TemporaryDirectory so it is removed even though
    # processing is expected to fail (the previous NamedTemporaryFile used
    # delete=False and was never cleaned up).
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_yaml_file_path = os.path.join(temp_dir, "data_config.yaml")
        with open(temp_yaml_file_path, "w") as temp_yaml_file:
            yaml.dump(yaml_content, temp_yaml_file)
        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
        with pytest.raises(AssertionError):
            _process_dataconfig_file(data_args, tokenizer)
|
||
|
||
# NOTE(review): a reviewer on the PR asked whether this test can be combined
# with the varied-data-formats test above; the two differ only in their input
# files and in the exception they expect.
@pytest.mark.parametrize(
    "data_config_path, list_data_path",
    [
        (
            DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
            [TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON],
        ),
        (
            DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
            [
                TWITTER_COMPLAINTS_TOKENIZED_JSON,
                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
            ],
        ),
        (
            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
            [TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_JSON],
        ),
    ],
)
def test_process_dataconfig_multiple_files_varied_types(
    data_config_path, list_data_path
):
    """Ensure mixing dataset types within one dataset raises a cast error.

    Every parametrized case combines JSON files of different kinds (plain,
    pre-tokenized, input/output) under a single dataset entry. Processing the
    resulting config is expected to raise
    ``datasets.exceptions.DatasetGenerationCastError``.
    """
    # Local import so this test does not depend on the module's import block.
    import os

    with open(data_config_path, "r") as f:
        yaml_content = yaml.safe_load(f)
    dataset_config = yaml_content["datasets"][0]
    dataset_config["data_paths"] = list_data_path
    datasets_name = dataset_config["name"]

    # Modify input_field_name and output_field_name according to dataset
    if datasets_name == "text_dataset_input_output_masking":
        dataset_config["data_handlers"][0]["arguments"]["fn_kwargs"] = {
            "input_field_name": "input",
            "output_field_name": "output",
        }

    # Modify dataset_text_field and template according to dataset
    formatted_dataset_field = "formatted_data_field"
    if datasets_name == "apply_custom_data_template":
        template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
        dataset_config["data_handlers"][0]["arguments"]["fn_kwargs"] = {
            "dataset_text_field": formatted_dataset_field,
            "template": template,
        }

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Write the config inside a TemporaryDirectory so it is removed even though
    # processing is expected to fail (the previous NamedTemporaryFile used
    # delete=False and was never cleaned up).
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_yaml_file_path = os.path.join(temp_dir, "data_config.yaml")
        with open(temp_yaml_file_path, "w") as temp_yaml_file:
            yaml.dump(yaml_content, temp_yaml_file)
        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
        with pytest.raises(datasets.exceptions.DatasetGenerationCastError):
            _process_dataconfig_file(data_args, tokenizer)
|
||
|
||
@pytest.mark.parametrize( | ||
"data_args", | ||
[ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for doing this segregation.