test: Add unit tests to test multiple files in single dataset #412

Merged
32 changes: 21 additions & 11 deletions tests/artifacts/testdata/__init__.py
@@ -19,37 +19,47 @@

### Constants used for data
DATA_DIR = os.path.join(os.path.dirname(__file__))
JSON_DATA_DIR = os.path.join(os.path.dirname(__file__), "json")
Contributor @dushyantbehl commented on Dec 12, 2024:

Thanks for doing this segregation.

JSONL_DATA_DIR = os.path.join(os.path.dirname(__file__), "jsonl")
ARROW_DATA_DIR = os.path.join(os.path.dirname(__file__), "arrow")
PARQUET_DATA_DIR = os.path.join(os.path.dirname(__file__), "parquet")
TWITTER_COMPLAINTS_DATA_JSON = os.path.join(DATA_DIR, "twitter_complaints_small.json")
TWITTER_COMPLAINTS_DATA_JSONL = os.path.join(DATA_DIR, "twitter_complaints_small.jsonl")
TWITTER_COMPLAINTS_DATA_ARROW = os.path.join(DATA_DIR, "twitter_complaints_small.arrow")

TWITTER_COMPLAINTS_DATA_JSON = os.path.join(
JSON_DATA_DIR, "twitter_complaints_small.json"
)
TWITTER_COMPLAINTS_DATA_JSONL = os.path.join(
JSONL_DATA_DIR, "twitter_complaints_small.jsonl"
)
TWITTER_COMPLAINTS_DATA_ARROW = os.path.join(
ARROW_DATA_DIR, "twitter_complaints_small.arrow"
)
TWITTER_COMPLAINTS_DATA_PARQUET = os.path.join(
PARQUET_DATA_DIR, "twitter_complaints_small.parquet"
)
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON = os.path.join(
DATA_DIR, "twitter_complaints_input_output.json"
JSON_DATA_DIR, "twitter_complaints_input_output.json"
)
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL = os.path.join(
DATA_DIR, "twitter_complaints_input_output.jsonl"
JSONL_DATA_DIR, "twitter_complaints_input_output.jsonl"
)
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW = os.path.join(
DATA_DIR, "twitter_complaints_input_output.arrow"
ARROW_DATA_DIR, "twitter_complaints_input_output.arrow"
)
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET = os.path.join(
PARQUET_DATA_DIR, "twitter_complaints_input_output.parquet"
)
TWITTER_COMPLAINTS_TOKENIZED_JSON = os.path.join(
DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json"
JSON_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json"
)
TWITTER_COMPLAINTS_TOKENIZED_JSONL = os.path.join(
DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl"
JSONL_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl"
)
TWITTER_COMPLAINTS_TOKENIZED_ARROW = os.path.join(
DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow"
ARROW_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow"
)
TWITTER_COMPLAINTS_TOKENIZED_PARQUET = os.path.join(
PARQUET_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.parquet"
)
EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json")
MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json")
EMPTY_DATA = os.path.join(JSON_DATA_DIR, "empty_data.json")
MALFORMATTED_DATA = os.path.join(JSON_DATA_DIR, "malformatted_data.json")
MODEL_NAME = "Maykeye/TinyLLama-v0"
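Since this change moves the fixtures into per-format subdirectories (json/, jsonl/, arrow/, parquet/), a quick way to sanity-check the new layout is to load one of the relocated files directly. The sketch below is not part of the diff; it assumes the repository root as the working directory and the Hugging Face datasets library installed.

# Minimal sketch (not part of this PR) for sanity-checking the relocated fixtures.
import os

from datasets import load_dataset

DATA_DIR = "tests/artifacts/testdata"
JSON_DATA_DIR = os.path.join(DATA_DIR, "json")
TWITTER_COMPLAINTS_DATA_JSON = os.path.join(
    JSON_DATA_DIR, "twitter_complaints_small.json"
)

# The "json" builder covers both the json/ and jsonl/ fixtures; the parquet/
# and arrow/ fixtures would use the "parquet" and "arrow" builders instead.
ds = load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA_JSON, split="train")
print(ds.column_names)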
310 changes: 262 additions & 48 deletions tests/data/test_data_preprocessing_utils.py
@@ -432,9 +432,11 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_ARROW),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
(DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_ARROW),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
@@ -447,6 +449,10 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
),
],
)
def test_process_dataconfig_file(data_config_path, data_path):
@@ -491,6 +497,262 @@ def test_process_dataconfig_file(data_config_path, data_path):
assert formatted_dataset_field in set(train_set.column_names)


@pytest.mark.parametrize(
"data_config_path, data_path_list",
[
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
[TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_JSON],
),
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
[
TWITTER_COMPLAINTS_DATA_JSONL,
TWITTER_COMPLAINTS_DATA_JSONL,
TWITTER_COMPLAINTS_DATA_JSONL,
],
),
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
[TWITTER_COMPLAINTS_DATA_PARQUET, TWITTER_COMPLAINTS_DATA_PARQUET],
),
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
[TWITTER_COMPLAINTS_DATA_ARROW, TWITTER_COMPLAINTS_DATA_ARROW],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[TWITTER_COMPLAINTS_TOKENIZED_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[TWITTER_COMPLAINTS_TOKENIZED_JSONL, TWITTER_COMPLAINTS_TOKENIZED_JSONL],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[
TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[TWITTER_COMPLAINTS_TOKENIZED_ARROW, TWITTER_COMPLAINTS_TOKENIZED_ARROW],
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
],
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
],
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
],
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
],
),
],
)
def test_process_dataconfig_multiple_files(data_config_path, data_path_list):
"""Ensure that datasets with multiple files are formatted and validated correctly based on the arguments passed in config file."""
with open(data_config_path, "r") as f:
yaml_content = yaml.safe_load(f)
yaml_content["datasets"][0]["data_paths"] = data_path_list
datasets_name = yaml_content["datasets"][0]["name"]

# Modify input_field_name and output_field_name according to dataset
if datasets_name == "text_dataset_input_output_masking":
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
"input_field_name": "input",
"output_field_name": "output",
}

# Modify dataset_text_field and template according to dataset
formatted_dataset_field = "formatted_data_field"
if datasets_name == "apply_custom_data_template":
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
"dataset_text_field": formatted_dataset_field,
"template": template,
}

with tempfile.NamedTemporaryFile(
"w", delete=False, suffix=".yaml"
) as temp_yaml_file:
yaml.dump(yaml_content, temp_yaml_file)
temp_yaml_file_path = temp_yaml_file.name
data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
(train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)
assert isinstance(train_set, Dataset)
if datasets_name == "text_dataset_input_output_masking":
column_names = set(["input_ids", "attention_mask", "labels"])
assert set(train_set.column_names) == column_names
elif datasets_name == "pretokenized_dataset":
assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
elif datasets_name == "apply_custom_data_template":
assert formatted_dataset_field in set(train_set.column_names)
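
For reference, the temporary YAML produced by the mutation above ends up shaped roughly like this for the apply_custom_data_template case. This is an illustrative reconstruction only: the real schema comes from the DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML fixture, so the handler name and any fields the test does not touch are assumptions.

# Illustrative reconstruction of the mutated config; the handler name below is
# an assumption, and only "data_paths" and "fn_kwargs" are set explicitly by the test.
import yaml

config = {
    "datasets": [
        {
            "name": "apply_custom_data_template",
            "data_paths": [
                "tests/artifacts/testdata/json/twitter_complaints_small.json",
                "tests/artifacts/testdata/json/twitter_complaints_small.json",
            ],
            "data_handlers": [
                {
                    "name": "apply_custom_data_formatting_template",  # assumed name
                    "arguments": {
                        "fn_kwargs": {
                            "dataset_text_field": "formatted_data_field",
                            "template": "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}",
                        }
                    },
                }
            ],
        }
    ]
}

print(yaml.safe_dump(config, sort_keys=False))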


@pytest.mark.parametrize(
"datafiles, datasetconfigname",
[
(
[
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
],
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
],
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
],
],
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
),
],
)
def test_process_dataconfig_multiple_datasets_datafiles_sampling(
datafiles, datasetconfigname
):
"""Ensure that multiple datasets with multiple files are formatted and validated correctly."""
with open(datasetconfigname, "r") as f:
yaml_content = yaml.safe_load(f)
yaml_content["datasets"][0]["data_paths"] = datafiles[0]
yaml_content["datasets"][1]["data_paths"] = datafiles[1]
yaml_content["datasets"][2]["data_paths"] = datafiles[2]

with tempfile.NamedTemporaryFile(
"w", delete=False, suffix=".yaml"
) as temp_yaml_file:
yaml.dump(yaml_content, temp_yaml_file)
temp_yaml_file_path = temp_yaml_file.name
data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
TRAIN_ARGS = configs.TrainingArguments(
packing=False,
max_seq_length=1024,
output_dir="tmp",
)
(train_set, eval_set, _, _, _, _) = process_dataargs(
data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
)

assert isinstance(train_set, Dataset)
if eval_set:
assert isinstance(eval_set, Dataset)

assert set(["input_ids", "attention_mask", "labels"]).issubset(
set(train_set.column_names)
)
if eval_set:
assert set(["input_ids", "attention_mask", "labels"]).issubset(
set(eval_set.column_names)
)
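
To make the fixture easier to picture: DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML declares three datasets, and the test points each one at its own list of files. The sketch below only illustrates that shape; the dataset names, the sampling field name, and the ratios are placeholders, not values taken from the actual fixture.

# Hypothetical shape of the multi-dataset sampling config after the test has
# replaced each dataset's data_paths. The "name" and "sampling" values are
# placeholders; the real fixture defines the sampling configuration itself.
multi_dataset_config = {
    "datasets": [
        {
            "name": "dataset_1",  # placeholder
            "sampling": 0.3,  # placeholder ratio
            "data_paths": [
                "tests/artifacts/testdata/parquet/twitter_complaints_input_output.parquet",
                "tests/artifacts/testdata/parquet/twitter_complaints_input_output.parquet",
            ],
        },
        {
            "name": "dataset_2",  # placeholder
            "sampling": 0.4,  # placeholder ratio
            "data_paths": [
                "tests/artifacts/testdata/json/twitter_complaints_input_output.json",
                "tests/artifacts/testdata/json/twitter_complaints_input_output.json",
            ],
        },
        {
            "name": "dataset_3",  # placeholder
            "sampling": 0.3,  # placeholder ratio
            "data_paths": [
                "tests/artifacts/testdata/jsonl/twitter_complaints_input_output.jsonl",
                "tests/artifacts/testdata/jsonl/twitter_complaints_input_output.jsonl",
            ],
        },
    ]
}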


@pytest.mark.parametrize(
"data_config_path, data_path_list",
[
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
[TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_PARQUET],
),
(
DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
[TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[
TWITTER_COMPLAINTS_TOKENIZED_JSONL,
TWITTER_COMPLAINTS_TOKENIZED_ARROW,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
],
),
(
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
[
TWITTER_COMPLAINTS_TOKENIZED_JSON,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
],
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
],
),
(
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
[TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_JSON],
),
],
)
def test_process_dataconfig_multiple_files_varied_data_formats(
data_config_path, data_path_list
):
"""Ensure that datasets with multiple files with different formats raise assertion error when passed in config file."""
with open(data_config_path, "r") as f:
yaml_content = yaml.safe_load(f)
yaml_content["datasets"][0]["data_paths"] = data_path_list
datasets_name = yaml_content["datasets"][0]["name"]

# Modify input_field_name and output_field_name according to dataset
if datasets_name == "text_dataset_input_output_masking":
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
"input_field_name": "input",
"output_field_name": "output",
}

# Modify dataset_text_field and template according to dataset
formatted_dataset_field = "formatted_data_field"
if datasets_name == "apply_custom_data_template":
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
"dataset_text_field": formatted_dataset_field,
"template": template,
}

with tempfile.NamedTemporaryFile(
"w", delete=False, suffix=".yaml"
) as temp_yaml_file:
yaml.dump(yaml_content, temp_yaml_file)
temp_yaml_file_path = temp_yaml_file.name
data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
with pytest.raises(
(AssertionError, datasets.exceptions.DatasetGenerationCastError)
):
(_, _, _) = _process_dataconfig_file(data_args, tokenizer)
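
The failure mode this test asserts on can be reproduced outside the data config path as well: pointing a single load at files whose schemas disagree makes the underlying Hugging Face load raise a cast error. A standalone sketch, assuming the test fixtures are present and the snippet is run from the repository root:

# Standalone sketch (not from the PR) of the expected failure: the two JSON
# fixtures expose different columns, so schema inference across files fails.
from datasets import load_dataset
from datasets.exceptions import DatasetGenerationCastError

from tests.artifacts.testdata import (
    TWITTER_COMPLAINTS_DATA_JSON,
    TWITTER_COMPLAINTS_TOKENIZED_JSON,
)

try:
    load_dataset(
        "json",
        data_files=[TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSON],
        split="train",
    )
except DatasetGenerationCastError as err:
    print("raised as expected:", type(err).__name__)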


@pytest.mark.parametrize(
"data_args",
[
@@ -764,51 +1026,3 @@ def test_process_dataset_configs_with_sampling_error(
(_, _, _, _, _, _) = process_dataargs(
data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
)


@pytest.mark.parametrize(
"datafiles, datasetconfigname",
[
(
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
],
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
),
],
)
def test_process_dataset_configs_with_sampling(datafiles, datasetconfigname):

data_args = configs.DataArguments()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
TRAIN_ARGS = configs.TrainingArguments(
packing=False,
max_seq_length=1024,
output_dir="tmp", # Not needed but positional
)

with tempfile.NamedTemporaryFile(
"w", delete=False, suffix=".yaml"
) as temp_yaml_file:
with open(datasetconfigname, "r") as f:
data = yaml.safe_load(f)
datasets = data["datasets"]
for i in range(len(datasets)):
d = datasets[i]
d["data_paths"][0] = datafiles[i]
yaml.dump(data, temp_yaml_file)
data_args.data_config_path = temp_yaml_file.name

(train_set, eval_set, _, _, _, _) = process_dataargs(
data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
)

assert isinstance(train_set, Dataset)
if eval_set:
assert isinstance(eval_set, Dataset)

assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
if eval_set:
assert set(["input_ids", "labels"]).issubset(set(eval_set.column_names))