diff --git a/tests/artifacts/testdata/__init__.py b/tests/artifacts/testdata/__init__.py index 8b6a7ea43..39895f6f1 100644 --- a/tests/artifacts/testdata/__init__.py +++ b/tests/artifacts/testdata/__init__.py @@ -22,6 +22,7 @@ PARQUET_DATA_DIR = os.path.join(os.path.dirname(__file__), "parquet") TWITTER_COMPLAINTS_DATA_JSON = os.path.join(DATA_DIR, "twitter_complaints_small.json") TWITTER_COMPLAINTS_DATA_JSONL = os.path.join(DATA_DIR, "twitter_complaints_small.jsonl") +TWITTER_COMPLAINTS_DATA_ARROW = os.path.join(DATA_DIR, "twitter_complaints_small.arrow") TWITTER_COMPLAINTS_DATA_PARQUET = os.path.join( PARQUET_DATA_DIR, "twitter_complaints_small.parquet" ) @@ -31,6 +32,9 @@ TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL = os.path.join( DATA_DIR, "twitter_complaints_input_output.jsonl" ) +TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW = os.path.join( + DATA_DIR, "twitter_complaints_input_output.arrow" +) TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET = os.path.join( PARQUET_DATA_DIR, "twitter_complaints_input_output.parquet" ) @@ -40,6 +44,9 @@ TWITTER_COMPLAINTS_TOKENIZED_JSONL = os.path.join( DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl" ) +TWITTER_COMPLAINTS_TOKENIZED_ARROW = os.path.join( + DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow" +) TWITTER_COMPLAINTS_TOKENIZED_PARQUET = os.path.join( PARQUET_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.parquet" ) diff --git a/tests/artifacts/testdata/twitter_complaints_input_output.arrow b/tests/artifacts/testdata/twitter_complaints_input_output.arrow new file mode 100644 index 000000000..602798d34 Binary files /dev/null and b/tests/artifacts/testdata/twitter_complaints_input_output.arrow differ diff --git a/tests/artifacts/testdata/twitter_complaints_small.arrow b/tests/artifacts/testdata/twitter_complaints_small.arrow new file mode 100644 index 000000000..b5bba53e2 Binary files /dev/null and b/tests/artifacts/testdata/twitter_complaints_small.arrow differ 
diff --git a/tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow b/tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow new file mode 100644 index 000000000..6afd36ddf Binary files /dev/null and b/tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow differ diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index a4ec5dbf7..6e7dacde8 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -32,12 +32,15 @@ ) from tests.artifacts.testdata import ( MODEL_NAME, + TWITTER_COMPLAINTS_DATA_ARROW, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_JSONL, TWITTER_COMPLAINTS_DATA_PARQUET, + TWITTER_COMPLAINTS_TOKENIZED_ARROW, TWITTER_COMPLAINTS_TOKENIZED_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSONL, TWITTER_COMPLAINTS_TOKENIZED_PARQUET, @@ -62,6 +65,10 @@ TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, set(["ID", "Label", "input", "output"]), ), + ( + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, + set(["ID", "Label", "input", "output", "sequence"]), + ), ( TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, set(["ID", "Label", "input", "output"]), @@ -80,6 +87,20 @@ ] ), ), + ( + TWITTER_COMPLAINTS_TOKENIZED_ARROW, + set( + [ + "Tweet text", + "ID", + "Label", + "text_label", + "output", + "input_ids", + "labels", + ] + ), + ), ( TWITTER_COMPLAINTS_TOKENIZED_PARQUET, set( @@ -98,6 +119,10 @@ TWITTER_COMPLAINTS_DATA_JSONL, set(["Tweet text", "ID", "Label", "text_label", "output"]), ), + ( + TWITTER_COMPLAINTS_DATA_ARROW, + set(["Tweet text", "ID", "Label", "text_label", "output"]), + ), ( TWITTER_COMPLAINTS_DATA_PARQUET, set(["Tweet text", "ID", "Label", "text_label", "output"]), @@ -123,6 +148,11 @@ def 
test_load_dataset_with_datafile(datafile, column_names): set(["ID", "Label", "input", "output"]), "text_dataset_input_output_masking", ), + ( + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, + set(["ID", "Label", "input", "output", "sequence"]), + "text_dataset_input_output_masking", + ), ( TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, set(["ID", "Label", "input", "output"]), @@ -163,6 +193,11 @@ def test_load_dataset_with_datafile(datafile, column_names): set(["Tweet text", "ID", "Label", "text_label", "output"]), "apply_custom_data_template", ), + ( + TWITTER_COMPLAINTS_DATA_ARROW, + set(["Tweet text", "ID", "Label", "text_label", "output"]), + "apply_custom_data_template", + ), ( TWITTER_COMPLAINTS_DATA_PARQUET, set(["Tweet text", "ID", "Label", "text_label", "output"]), @@ -593,6 +628,12 @@ def test_process_dataargs(data_args): training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL, ) ), + # ARROW pretokenized train datasets + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_TOKENIZED_ARROW, + ) + ), # PARQUET pretokenized train datasets ( configs.DataArguments( diff --git a/tuning/utils/utils.py b/tuning/utils/utils.py index 585011ae9..6eef6b2cf 100644 --- a/tuning/utils/utils.py +++ b/tuning/utils/utils.py @@ -31,6 +31,8 @@ def get_loader_for_filepath(file_path: str) -> str: return "text" if ext in (".json", ".jsonl"): return "json" + if ext in (".arrow",): + return "arrow" if ext in (".parquet"): return "parquet" return ext