Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

test: Add support and unit tests for handling of multiple files passed as a pattern in data_config #416

Closed
Closed
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ indent-string=' '
max-line-length=100

# Maximum number of lines in a module.
max-module-lines=1200
max-module-lines=1400

# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
Expand Down
10 changes: 10 additions & 0 deletions tests/artifacts/testdata/jsonl/twitter_complaints_small_2.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{"Tweet text":"@NortonSupport Thanks much.","ID":10,"Label":2,"text_label":"no complaint","output":"### Text: @NortonSupport Thanks much.\n\n### Label: no complaint"}
{"Tweet text":"@VerizonSupport all of a sudden I can't connect to my primary wireless network but guest one works","ID":11,"Label":2,"text_label":"no complaint","output":"### Text: @VerizonSupport all of a sudden I can't connect to my primary wireless network but guest one works\n\n### Label: no complaint"}
{"Tweet text":"Aaaahhhhh!!!! My @Razer @PlayOverwatch d.va meka headset came in!!! I didn't even know it had shipped!!! So excited\u2026 https:\/\/t.co\/4gXy9xED8d","ID":12,"Label":2,"text_label":"no complaint","output":"### Text: Aaaahhhhh!!!! My @Razer @PlayOverwatch d.va meka headset came in!!! I didn't even know it had shipped!!! So excited\u2026 https:\/\/t.co\/4gXy9xED8d\n\n### Label: no complaint"}
{"Tweet text":"@Lin_Manuel @jmessinaphoto @VAMNit Omg a little squish!!!!! Enjoy and congrats!!!! I miss mine being so young! \ufffd\ufffd\ufffd\ufffd\ufffd\ufffd","ID":13,"Label":2,"text_label":"no complaint","output":"### Text: @Lin_Manuel @jmessinaphoto @VAMNit Omg a little squish!!!!! Enjoy and congrats!!!! I miss mine being so young! \ufffd\ufffd\ufffd\ufffd\ufffd\ufffd\n\n### Label: no complaint"}
{"Tweet text":"@IanJamesPoulter What's your secret to poaching eggs? Mine NEVER look that good.","ID":14,"Label":2,"text_label":"no complaint","output":"### Text: @IanJamesPoulter What's your secret to poaching eggs? Mine NEVER look that good.\n\n### Label: no complaint"}
{"Tweet text":"@AWSSupport When will be able Kinesis Firehose compatible with Elasticsearch 6.0? Thank you!","ID":15,"Label":2,"text_label":"no complaint","output":"### Text: @AWSSupport When will be able Kinesis Firehose compatible with Elasticsearch 6.0? Thank you!\n\n### Label: no complaint"}
{"Tweet text":"@NCIS_CBS https:\/\/t.co\/eeVL9Eu3bE","ID":16,"Label":2,"text_label":"no complaint","output":"### Text: @NCIS_CBS https:\/\/t.co\/eeVL9Eu3bE\n\n### Label: no complaint"}
{"Tweet text":"@msetchell Via the settings? That\u2019s how I do it on master T\u2019s","ID":17,"Label":2,"text_label":"no complaint","output":"### Text: @msetchell Via the settings? That\u2019s how I do it on master T\u2019s\n\n### Label: no complaint"}
{"Tweet text":"Today at work there was a low flying duck heading toward a crowd of people, and I yelled \"watch out! and I'm very disappointed with myself.","ID":18,"Label":2,"text_label":"no complaint","output":"### Text: Today at work there was a low flying duck heading toward a crowd of people, and I yelled \"watch out! and I'm very disappointed with myself.\n\n### Label: no complaint"}
{"Tweet text":"@NortonSupport @NortonOnline What the hell is a dm 5-10 days to get money back bank account now overdrawn thanks guys","ID":19,"Label":1,"text_label":"complaint","output":"### Text: @NortonSupport @NortonOnline What the hell is a dm 5-10 days to get money back bank account now overdrawn thanks guys\n\n### Label: complaint"}
50 changes: 50 additions & 0 deletions tests/data/test_data_preprocessing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# limitations under the License.

# Standard
import glob
import json
import os
import tempfile

# Third Party
Expand Down Expand Up @@ -613,6 +615,54 @@ def test_process_dataconfig_multiple_files(data_config_path, data_path_list):
assert formatted_dataset_field in set(train_set.column_names)


@pytest.mark.parametrize(
"data_config_path, data_path",
[
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
os.path.join(
os.path.dirname(TWITTER_COMPLAINTS_DATA_JSONL), "*small*.jsonl"
),
),
],
)
def test_process_dataconfig_multiple_files_with_globbing(data_config_path, data_path):
"""Ensure that datasets files matching globbing pattern are formatted and validated correctly based on the arguments passed in config file."""
with open(data_config_path, "r") as f:
yaml_content = yaml.safe_load(f)

PATTERN_TWITTER_COMPLAINTS_DATA_JSONL = data_path.replace(
"twitter_complaints_small.jsonl", "*small*.jsonl"
)
yaml_content["datasets"][0]["data_paths"][0] = PATTERN_TWITTER_COMPLAINTS_DATA_JSONL

# Modify dataset_text_field and template according to dataset
formatted_dataset_field = "formatted_data_field"
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
"dataset_text_field": formatted_dataset_field,
"template": template,
}

with tempfile.NamedTemporaryFile(
"w", delete=False, suffix=".yaml"
) as temp_yaml_file:
yaml.dump(yaml_content, temp_yaml_file)
temp_yaml_file_path = temp_yaml_file.name
data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
(train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)
assert isinstance(train_set, Dataset)
assert formatted_dataset_field in set(train_set.column_names)

data_len = sum(
sum(1 for _ in open(file, "r")) # Count lines in each JSONL file
for file in glob.glob(PATTERN_TWITTER_COMPLAINTS_DATA_JSONL)
)
assert len(train_set) == data_len


@pytest.mark.parametrize(
"datafiles, datasetconfigname",
[
Expand Down
1 change: 0 additions & 1 deletion tuning/data/data_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig:
c.data_paths = []
for p in data_paths:
assert isinstance(p, str), f"path {p} should be of the type string"
assert os.path.exists(p), f"data_paths {p} does not exist"
if not os.path.isabs(p):
_p = os.path.abspath(p)
logging.warning(
Expand Down
4 changes: 2 additions & 2 deletions tuning/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ def get_loader_for_filepath(file_path: str) -> str:
return "text"
if ext in (".json", ".jsonl"):
return "json"
if ext in (".arrow"):
if ext in (".arrow",):
return "arrow"
if ext in (".parquet"):
if ext in (".parquet",):
return "parquet"
return ext

Expand Down
Loading