Skip to content

Commit

Permalink
test: add arrow datasets and arrow unit tests (#403)
Browse files Browse the repository at this point in the history
Signed-off-by: Will Johnson <mwjohnson728@gmail.com>
  • Loading branch information
willmj authored Dec 7, 2024
1 parent fbe6064 commit e6f7a22
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tests/artifacts/testdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
PARQUET_DATA_DIR = os.path.join(os.path.dirname(__file__), "parquet")
TWITTER_COMPLAINTS_DATA_JSON = os.path.join(DATA_DIR, "twitter_complaints_small.json")
TWITTER_COMPLAINTS_DATA_JSONL = os.path.join(DATA_DIR, "twitter_complaints_small.jsonl")
TWITTER_COMPLAINTS_DATA_ARROW = os.path.join(DATA_DIR, "twitter_complaints_small.arrow")
TWITTER_COMPLAINTS_DATA_PARQUET = os.path.join(
PARQUET_DATA_DIR, "twitter_complaints_small.parquet"
)
Expand All @@ -31,6 +32,9 @@
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL = os.path.join(
DATA_DIR, "twitter_complaints_input_output.jsonl"
)
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW = os.path.join(
DATA_DIR, "twitter_complaints_input_output.arrow"
)
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET = os.path.join(
PARQUET_DATA_DIR, "twitter_complaints_input_output.parquet"
)
Expand All @@ -40,6 +44,9 @@
TWITTER_COMPLAINTS_TOKENIZED_JSONL = os.path.join(
DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl"
)
TWITTER_COMPLAINTS_TOKENIZED_ARROW = os.path.join(
DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow"
)
TWITTER_COMPLAINTS_TOKENIZED_PARQUET = os.path.join(
PARQUET_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.parquet"
)
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
41 changes: 41 additions & 0 deletions tests/data/test_data_preprocessing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,15 @@
)
from tests.artifacts.testdata import (
MODEL_NAME,
TWITTER_COMPLAINTS_DATA_ARROW,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
TWITTER_COMPLAINTS_DATA_JSON,
TWITTER_COMPLAINTS_DATA_JSONL,
TWITTER_COMPLAINTS_DATA_PARQUET,
TWITTER_COMPLAINTS_TOKENIZED_ARROW,
TWITTER_COMPLAINTS_TOKENIZED_JSON,
TWITTER_COMPLAINTS_TOKENIZED_JSONL,
TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
Expand All @@ -62,6 +65,10 @@
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
set(["ID", "Label", "input", "output"]),
),
(
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
set(["ID", "Label", "input", "output", "sequence"]),
),
(
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
set(["ID", "Label", "input", "output"]),
Expand All @@ -80,6 +87,20 @@
]
),
),
(
TWITTER_COMPLAINTS_TOKENIZED_ARROW,
set(
[
"Tweet text",
"ID",
"Label",
"text_label",
"output",
"input_ids",
"labels",
]
),
),
(
TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
set(
Expand All @@ -98,6 +119,10 @@
TWITTER_COMPLAINTS_DATA_JSONL,
set(["Tweet text", "ID", "Label", "text_label", "output"]),
),
(
TWITTER_COMPLAINTS_DATA_ARROW,
set(["Tweet text", "ID", "Label", "text_label", "output"]),
),
(
TWITTER_COMPLAINTS_DATA_PARQUET,
set(["Tweet text", "ID", "Label", "text_label", "output"]),
Expand All @@ -123,6 +148,11 @@ def test_load_dataset_with_datafile(datafile, column_names):
set(["ID", "Label", "input", "output"]),
"text_dataset_input_output_masking",
),
(
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
set(["ID", "Label", "input", "output", "sequence"]),
"text_dataset_input_output_masking",
),
(
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
set(["ID", "Label", "input", "output"]),
Expand Down Expand Up @@ -163,6 +193,11 @@ def test_load_dataset_with_datafile(datafile, column_names):
set(["Tweet text", "ID", "Label", "text_label", "output"]),
"apply_custom_data_template",
),
(
TWITTER_COMPLAINTS_DATA_ARROW,
set(["Tweet text", "ID", "Label", "text_label", "output"]),
"apply_custom_data_template",
),
(
TWITTER_COMPLAINTS_DATA_PARQUET,
set(["Tweet text", "ID", "Label", "text_label", "output"]),
Expand Down Expand Up @@ -593,6 +628,12 @@ def test_process_dataargs(data_args):
training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL,
)
),
# ARROW pretokenized train datasets
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_TOKENIZED_ARROW,
)
),
# PARQUET pretokenized train datasets
(
configs.DataArguments(
Expand Down
2 changes: 2 additions & 0 deletions tuning/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def get_loader_for_filepath(file_path: str) -> str:
return "text"
if ext in (".json", ".jsonl"):
return "json"
if ext in (".arrow"):
return "arrow"
if ext in (".parquet"):
return "parquet"
return ext
Expand Down

0 comments on commit e6f7a22

Please sign in to comment.