From e76f79884b398adc8f2a47679e0d0acab5b0595e Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 11 Dec 2024 10:21:05 -0500 Subject: [PATCH 1/7] feat: first pass at data folder functionality Signed-off-by: Will Johnson --- tests/artifacts/predefined_data_configs/__init__.py | 3 +++ tests/data/test_data_preprocessing_utils.py | 1 + tuning/data/data_processors.py | 12 ++++++++++++ 3 files changed, 16 insertions(+) diff --git a/tests/artifacts/predefined_data_configs/__init__.py b/tests/artifacts/predefined_data_configs/__init__.py index c199406c6..ffc8269b4 100644 --- a/tests/artifacts/predefined_data_configs/__init__.py +++ b/tests/artifacts/predefined_data_configs/__init__.py @@ -19,6 +19,9 @@ ### Constants used for data PREDEFINED_DATA_CONFIGS = os.path.join(os.path.dirname(__file__)) +DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_FOLDER_YAML = os.path.join( + PREDEFINED_DATA_CONFIGS, "apply_custom_template_folder.yaml" +) DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML = os.path.join( PREDEFINED_DATA_CONFIGS, "apply_custom_template.yaml" ) diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index c34204f4f..0eb53e5fb 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -27,6 +27,7 @@ # First Party from tests.artifacts.predefined_data_configs import ( DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, + DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_FOLDER_YAML, DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py index e92b7b684..cc08ad398 100644 --- a/tuning/data/data_processors.py +++ b/tuning/data/data_processors.py @@ -66,6 +66,18 @@ def load_dataset( files = [datafile] loader = get_loader_for_filepath(file_path=datafile) elif datasetconfig: + files = [] + for path in datasetconfig.data_paths: + if os.path.isdir(path): + # If the path 
is a folder, collect all files within it + folder_files = [ + os.path.join(path, file) + for file in os.listdir(path) + if os.path.isfile(os.path.join(path, file)) + ] + files.extend(folder_files) + else: + files.append(path) files = datasetconfig.data_paths name = datasetconfig.name # simple check to make sure all files are of same type. From 7f5451d3e3d606a88bc9a76f0239d77f6f7b8a9d Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 12 Dec 2024 09:43:08 -0500 Subject: [PATCH 2/7] fix: remove files set after condition Signed-off-by: Will Johnson --- tuning/data/data_processors.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py index cc08ad398..8b41be00f 100644 --- a/tuning/data/data_processors.py +++ b/tuning/data/data_processors.py @@ -78,7 +78,6 @@ def load_dataset( files.extend(folder_files) else: files.append(path) - files = datasetconfig.data_paths name = datasetconfig.name # simple check to make sure all files are of same type. 
extns = [get_extension(f) for f in files] From 7f9fff12b9f5e09374af9dccfe57cd7572b51f64 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 12 Dec 2024 09:44:30 -0500 Subject: [PATCH 3/7] fix: remove files set after condition Signed-off-by: Will Johnson --- tests/data/test_data_preprocessing_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index 0eb53e5fb..01946411d 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -26,8 +26,8 @@ # First Party from tests.artifacts.predefined_data_configs import ( - DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_FOLDER_YAML, + DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, From 922856db58165d4ba68f5fd69a3fa4a038a9116d Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Fri, 13 Dec 2024 09:52:37 -0500 Subject: [PATCH 4/7] test: for data folder Signed-off-by: Will Johnson --- .../predefined_data_configs/__init__.py | 3 - .../tokenize_and_apply_input_masking.yaml | 4 +- tests/artifacts/testdata/__init__.py | 1 + .../twitter_complaints_input_output_1.json | 152 ++++++++++++++++++ .../twitter_complaints_input_output_2.json | 152 ++++++++++++++++++ tests/data/test_data_preprocessing_utils.py | 26 ++- tests/test_sft_trainer.py | 55 +++++++ 7 files changed, 387 insertions(+), 6 deletions(-) create mode 100644 tests/artifacts/testdata/datafolder/twitter_complaints_input_output_1.json create mode 100644 tests/artifacts/testdata/datafolder/twitter_complaints_input_output_2.json diff --git a/tests/artifacts/predefined_data_configs/__init__.py b/tests/artifacts/predefined_data_configs/__init__.py index ffc8269b4..c199406c6 100644 --- a/tests/artifacts/predefined_data_configs/__init__.py +++ 
b/tests/artifacts/predefined_data_configs/__init__.py @@ -19,9 +19,6 @@ ### Constants used for data PREDEFINED_DATA_CONFIGS = os.path.join(os.path.dirname(__file__)) -DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_FOLDER_YAML = os.path.join( - PREDEFINED_DATA_CONFIGS, "apply_custom_template_folder.yaml" -) DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML = os.path.join( PREDEFINED_DATA_CONFIGS, "apply_custom_template.yaml" ) diff --git a/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml b/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml index b66b01d55..ac7e07030 100644 --- a/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml +++ b/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml @@ -10,5 +10,5 @@ datasets: remove_columns: all batched: false fn_kwargs: - input_field_name: "INPUT" - output_field_name: "OUTPUT" \ No newline at end of file + input_field_name: input + output_field_name: output \ No newline at end of file diff --git a/tests/artifacts/testdata/__init__.py b/tests/artifacts/testdata/__init__.py index 39895f6f1..99bad8c98 100644 --- a/tests/artifacts/testdata/__init__.py +++ b/tests/artifacts/testdata/__init__.py @@ -20,6 +20,7 @@ ### Constants used for data DATA_DIR = os.path.join(os.path.dirname(__file__)) PARQUET_DATA_DIR = os.path.join(os.path.dirname(__file__), "parquet") +TWITTER_COMPLAINTS_DATA_DIR_JSON = os.path.join(DATA_DIR, "datafolder") TWITTER_COMPLAINTS_DATA_JSON = os.path.join(DATA_DIR, "twitter_complaints_small.json") TWITTER_COMPLAINTS_DATA_JSONL = os.path.join(DATA_DIR, "twitter_complaints_small.jsonl") TWITTER_COMPLAINTS_DATA_ARROW = os.path.join(DATA_DIR, "twitter_complaints_small.arrow") diff --git a/tests/artifacts/testdata/datafolder/twitter_complaints_input_output_1.json b/tests/artifacts/testdata/datafolder/twitter_complaints_input_output_1.json new file mode 100644 index 000000000..2668241f8 --- /dev/null +++ 
b/tests/artifacts/testdata/datafolder/twitter_complaints_input_output_1.json @@ -0,0 +1,152 @@ +[ + { + "ID": 0, + "Label": 2, + "input": "@HMRCcustomers No this is my first job", + "output": "no complaint" + }, + { + "ID": 1, + "Label": 2, + "input": "@KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.", + "output": "no complaint" + }, + { + "ID": 2, + "Label": 1, + "input": "If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService", + "output": "complaint" + }, + { + "ID": 3, + "Label": 1, + "input": "@EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.", + "output": "complaint" + }, + { + "ID": 4, + "Label": 2, + "input": "Couples wallpaper, so cute. :) #BrothersAtHome", + "output": "no complaint" + }, + { + "ID": 5, + "Label": 2, + "input": "@mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https://t.co/WRtNsokblG", + "output": "no complaint" + }, + { + "ID": 6, + "Label": 2, + "input": "@Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?", + "output": "no complaint" + }, + { + "ID": 7, + "Label": 1, + "input": "@nationalgridus I have no water and the bill is current and paid. Can you do something about this?", + "output": "complaint" + }, + { + "ID": 8, + "Label": 1, + "input": "Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude/condescending. I'll take my $$ to @Sephora", + "output": "complaint" + }, + { + "ID": 9, + "Label": 2, + "input": "@JenniferTilly Merry Christmas to as well. 
You get more stunning every year \ufffd\ufffd", + "output": "no complaint" + }, + { + "ID": 10, + "Label": 2, + "input": "@NortonSupport Thanks much.", + "output": "no complaint" + }, + { + "ID": 11, + "Label": 2, + "input": "@VerizonSupport all of a sudden I can't connect to my primary wireless network but guest one works", + "output": "no complaint" + }, + { + "ID": 12, + "Label": 2, + "input": "Aaaahhhhh!!!! My @Razer @PlayOverwatch d.va meka headset came in!!! I didn't even know it had shipped!!! So excited\u2026 https://t.co/4gXy9xED8d", + "output": "no complaint" + }, + { + "ID": 13, + "Label": 2, + "input": "@Lin_Manuel @jmessinaphoto @VAMNit Omg a little squish!!!!! Enjoy and congrats!!!! I miss mine being so young! \ufffd\ufffd\ufffd\ufffd\ufffd\ufffd", + "output": "no complaint" + }, + { + "ID": 14, + "Label": 2, + "input": "@IanJamesPoulter What's your secret to poaching eggs? Mine NEVER look that good.", + "output": "no complaint" + }, + { + "ID": 15, + "Label": 2, + "input": "@AWSSupport When will be able Kinesis Firehose compatible with Elasticsearch 6.0? Thank you!", + "output": "no complaint" + }, + { + "ID": 16, + "Label": 2, + "input": "@NCIS_CBS https://t.co/eeVL9Eu3bE", + "output": "no complaint" + }, + { + "ID": 17, + "Label": 2, + "input": "@msetchell Via the settings? That\u2019s how I do it on master T\u2019s", + "output": "no complaint" + }, + { + "ID": 18, + "Label": 2, + "input": "Today at work there was a low flying duck heading toward a crowd of people, and I yelled \"watch out! 
and I'm very disappointed with myself.", + "output": "no complaint" + }, + { + "ID": 19, + "Label": 1, + "input": "@NortonSupport @NortonOnline What the hell is a dm 5-10 days to get money back bank account now overdrawn thanks guys", + "output": "complaint" + }, + { + "ID": 20, + "Label": 1, + "input": "@united not happy with this delay from Newark to Manchester tonight :( only 30 mins free Wi-fi sucks ...", + "output": "complaint" + }, + { + "ID": 21, + "Label": 1, + "input": "@ZARA_Care I've been waiting on a reply to my tweets and DMs for days now?", + "output": "complaint" + }, + { + "ID": 22, + "Label": 2, + "input": "New Listing! Large 2 Family Home for Sale in #Passaic Park, #NJ #realestate #homesforsale Great Location!\u2026 https://t.co/IV4OrLXkMk", + "output": "no complaint" + }, + { + "ID": 23, + "Label": 1, + "input": "@SouthwestAir I love you but when sending me flight changes please don't use military time #ignoranceisbliss", + "output": "complaint" + }, + { + "ID": 24, + "Label": 2, + "input": "@JetBlue Completely understand but would prefer being on time to filling out forms....", + "output": "no complaint" + } +] \ No newline at end of file diff --git a/tests/artifacts/testdata/datafolder/twitter_complaints_input_output_2.json b/tests/artifacts/testdata/datafolder/twitter_complaints_input_output_2.json new file mode 100644 index 000000000..e93fed8e4 --- /dev/null +++ b/tests/artifacts/testdata/datafolder/twitter_complaints_input_output_2.json @@ -0,0 +1,152 @@ +[ + { + "ID": 25, + "Label": 2, + "input": "@nvidiacc I own two gtx 460 in sli. I want to try windows 8 dev preview. Which driver should I use. Can I use the windows 7 one.", + "output": "no complaint" + }, + { + "ID": 26, + "Label": 2, + "input": "Just posted a photo https://t.co/RShFwCjPHu", + "output": "no complaint" + }, + { + "ID": 27, + "Label": 2, + "input": "Love crescent rolls? Try adding pesto @PerdueChicken to them and you\u2019re going to love it! 
#Promotion #PerdueCrew -\u2026 https://t.co/KBHOfqCukH", + "output": "no complaint" + }, + { + "ID": 28, + "Label": 1, + "input": "@TopmanAskUs please just give me my money back.", + "output": "complaint" + }, + { + "ID": 29, + "Label": 2, + "input": "I just gave 5 stars to Tracee at @neimanmarcus for the great service I received!", + "output": "no complaint" + }, + { + "ID": 30, + "Label": 2, + "input": "@FitbitSupport when are you launching new clock faces for Indian market", + "output": "no complaint" + }, + { + "ID": 31, + "Label": 1, + "input": "@HPSupport my printer will not allow me to choose color instead it only prints monochrome #hppsdr #ijkhelp", + "output": "complaint" + }, + { + "ID": 32, + "Label": 1, + "input": "@DIRECTV can I get a monthly charge double refund when it sprinkles outside and we lose reception? #IamEmbarrasedForYou", + "output": "complaint" + }, + { + "ID": 33, + "Label": 1, + "input": "@AlfaRomeoCares Hi thanks for replying, could be my internet but link doesn't seem to be working", + "output": "complaint" + }, + { + "ID": 34, + "Label": 2, + "input": "Looks tasty! Going to share with everyone I know #FebrezeONE #sponsored https://t.co/4AQI53npei", + "output": "no complaint" + }, + { + "ID": 35, + "Label": 2, + "input": "@OnePlus_IN can OnePlus 5T do front camera portrait?", + "output": "no complaint" + }, + { + "ID": 36, + "Label": 1, + "input": "@sho_help @showtime your arrive is terrible streaming is stop and start every couple mins. 
Get it together it's xmas", + "output": "complaint" + }, + { + "ID": 37, + "Label": 2, + "input": "@KandraKPTV I just witnessed a huge building fire in Santa Monica California", + "output": "no complaint" + }, + { + "ID": 38, + "Label": 2, + "input": "@fernrocks most definitely the latter for me", + "output": "no complaint" + }, + { + "ID": 39, + "Label": 1, + "input": "@greateranglia Could I ask why the Area in front of BIC Station was not gritted withh all the snow.", + "output": "complaint" + }, + { + "ID": 40, + "Label": 2, + "input": "I'm earning points with #CricketRewards https://t.co/GfpGhqqnhE", + "output": "no complaint" + }, + { + "ID": 41, + "Label": 2, + "input": "@Schrapnel @comcast RIP me", + "output": "no complaint" + }, + { + "ID": 42, + "Label": 2, + "input": "The wait is finally over, just joined @SquareUK, hope to get started real soon!", + "output": "no complaint" + }, + { + "ID": 43, + "Label": 2, + "input": "@WholeFoods what's the best way to give feedback on a particular store to the regional/national office?", + "output": "no complaint" + }, + { + "ID": 44, + "Label": 2, + "input": "@DanielNewman I honestly would believe anything. People are...too much sometimes.", + "output": "no complaint" + }, + { + "ID": 45, + "Label": 2, + "input": "@asblough Yep! It should send you a notification with your driver\u2019s name and what time they\u2019ll be showing up!", + "output": "no complaint" + }, + { + "ID": 46, + "Label": 2, + "input": "@Wavy2Timez for real", + "output": "no complaint" + }, + { + "ID": 47, + "Label": 1, + "input": "@KenyaPower_Care no power in south b area... is it scheduled.", + "output": "complaint" + }, + { + "ID": 48, + "Label": 1, + "input": "Honda won't do anything about water leaking in brand new car. Frustrated! 
@HondaCustSvc @AmericanHonda", + "output": "complaint" + }, + { + "ID": 49, + "Label": 1, + "input": "@CBSNews @Dodge @ChryslerCares My driver side air bag has been recalled and replaced, but what about the passenger side?", + "output": "complaint" + } +] \ No newline at end of file diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index 01946411d..c29d5df99 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -26,7 +26,6 @@ # First Party from tests.artifacts.predefined_data_configs import ( - DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_FOLDER_YAML, DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, @@ -35,6 +34,7 @@ from tests.artifacts.testdata import ( MODEL_NAME, TWITTER_COMPLAINTS_DATA_ARROW, + TWITTER_COMPLAINTS_DATA_DIR_JSON, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, @@ -248,6 +248,30 @@ def test_load_dataset_with_dataconfig_and_datafile(datafile, datasetconfigname): ) +@pytest.mark.parametrize( + "datafolder, column_names, datasetconfigname", + [ + ( + TWITTER_COMPLAINTS_DATA_DIR_JSON, + set(["ID", "Label", "input", "output"]), + "text_dataset_input_output_masking", + ), + ], +) +def test_load_dataset_with_dataconfig_and_datafolder( + datafolder, column_names, datasetconfigname +): + """Ensure a datasetconfig whose data_paths entry is a folder loads all files in it.""" + datasetconfig = DataSetConfig(name=datasetconfigname, data_paths=[datafolder]) + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None + ) + load_dataset = processor.load_dataset( + datasetconfig=datasetconfig, splitName="train", datafile=None + ) + assert set(load_dataset.column_names) == column_names + + def test_load_dataset_without_dataconfig_and_datafile(): """Ensure that both datasetconfig and datafile 
cannot be None.""" processor = get_datapreprocessor( diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 69ccbf4fa..907cfd054 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -27,14 +27,19 @@ import pytest import torch import transformers +import yaml # First Party from build.utils import serialize_args from scripts.run_inference import TunedCausalLM +from tests.artifacts.predefined_data_configs import ( + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, +) from tests.artifacts.testdata import ( EMPTY_DATA, MALFORMATTED_DATA, MODEL_NAME, + TWITTER_COMPLAINTS_DATA_DIR_JSON, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_JSONL, @@ -754,6 +759,56 @@ def test_run_causallm_ft_pretokenized(dataset_path): assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference +@pytest.mark.parametrize( + "datafiles, datasetconfigname", + [ + ( + TWITTER_COMPLAINTS_DATA_DIR_JSON, + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + ), + ], +) +def test_run_causallm_ft_and_inference_with_datafolder(datasetconfigname, datafiles): + """Check if we can finetune causallm models using a folder of data files passed via data config""" + with tempfile.TemporaryDirectory() as tempdir: + data_formatting_args = copy.deepcopy(DATA_ARGS) + + # set training_data_path and response_template to none + data_formatting_args.response_template = None + data_formatting_args.training_data_path = None + + # add data_paths in data_config file + with tempfile.NamedTemporaryFile( + "w", delete=False, suffix=".yaml" + ) as temp_yaml_file: + with open(datasetconfigname, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + datasets = data["datasets"] + for _, d in enumerate(datasets): + d["data_paths"] = [datafiles] + yaml.dump(data, temp_yaml_file) + data_formatting_args.data_config_path = temp_yaml_file.name + + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + 
sft_trainer.train(MODEL_ARGS, data_formatting_args, train_args) + + # validate full ft configs + _validate_training(tempdir) + _, checkpoint_path = _get_latest_checkpoint_trainer_state(tempdir) + + # Load the model + loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME) + + # Run inference on the text + output_inference = loaded_model.run( + "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50 + ) + assert len(output_inference) > 0 + assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference + + ############################# Helper functions ############################# def _test_run_causallm_ft(training_args, model_args, data_args, tempdir): train_args = copy.deepcopy(training_args) From 2fe3304b30f9117b8f91c90e4b95a72ac0889355 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Fri, 13 Dec 2024 12:45:34 -0500 Subject: [PATCH 5/7] fmt Signed-off-by: Will Johnson --- tests/test_sft_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 546442485..75dfce269 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -36,15 +36,15 @@ from build.utils import serialize_args from scripts.run_inference import TunedCausalLM from tests.artifacts.predefined_data_configs import ( - DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, ) from tests.artifacts.testdata import ( EMPTY_DATA, MALFORMATTED_DATA, MODEL_NAME, - TWITTER_COMPLAINTS_DATA_DIR_JSON, TWITTER_COMPLAINTS_DATA_ARROW, + TWITTER_COMPLAINTS_DATA_DIR_JSON, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, From a9b2644a2b3fedfb9a602c263bebcf55592d475a Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Fri, 13 Dec 2024 16:04:45 -0500 Subject: [PATCH 6/7] fix: loading folder Signed-off-by: Will 
Johnson --- tuning/data/data_processors.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py index f1fc125ac..cf839a512 100644 --- a/tuning/data/data_processors.py +++ b/tuning/data/data_processors.py @@ -85,25 +85,27 @@ def load_dataset( files = [datafile] loader = get_loader_for_filepath(file_path=datafile) elif datasetconfig: - files = [] - for path in datasetconfig.data_paths: - if os.path.isdir(path): - # If the path is a folder, collect all files within it - folder_files = [ - os.path.join(path, file) - for file in os.listdir(path) - if os.path.isfile(os.path.join(path, file)) - ] - files.extend(folder_files) - else: - files.append(path) + files = datasetconfig.data_paths name = datasetconfig.name - # simple check to make sure all files are of same type. + + # Check if the first path is a directory or a file with an extension + first_path = files[0] + if os.path.isdir(first_path): + try: + return datasets.load_dataset(first_path, split=splitName, **kwargs) + except DatasetNotFoundError as e: + raise e + except FileNotFoundError as e: + raise ValueError( + f"data path is invalid [{', '.join(files)}]" + ) from e + + # If the paths are files, ensure all have the same extension extns = [get_extension(f) for f in files] assert extns.count(extns[0]) == len( extns ), f"All files in the dataset {name} should have the same extension" - loader = get_loader_for_filepath(file_path=files[0]) + loader = get_loader_for_filepath(file_path=first_path) if loader in (None, ""): raise ValueError(f"data path is invalid [{', '.join(files)}]") From 865c5fac40d78e8c756440fa4a95a69413867c3b Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 18 Dec 2024 11:19:42 -0500 Subject: [PATCH 7/7] fix: assume only 1 dataset Signed-off-by: Will Johnson --- tuning/data/data_processors.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git 
a/tuning/data/data_processors.py b/tuning/data/data_processors.py index cf839a512..91f09cac1 100644 --- a/tuning/data/data_processors.py +++ b/tuning/data/data_processors.py @@ -88,24 +88,24 @@ def load_dataset( files = datasetconfig.data_paths name = datasetconfig.name - # Check if the first path is a directory or a file with an extension - first_path = files[0] - if os.path.isdir(first_path): - try: - return datasets.load_dataset(first_path, split=splitName, **kwargs) - except DatasetNotFoundError as e: - raise e - except FileNotFoundError as e: - raise ValueError( - f"data path is invalid [{', '.join(files)}]" - ) from e + # If a single path was passed and it is a directory, load it as a dataset directory + if len(files) == 1: + if os.path.isdir(files[0]): + try: + return datasets.load_dataset(files[0], split=splitName, **kwargs) + except DatasetNotFoundError as e: + raise e + except FileNotFoundError as e: + raise ValueError( + f"data path is invalid [{', '.join(files)}]" + ) from e # If the paths are files, ensure all have the same extension extns = [get_extension(f) for f in files] assert extns.count(extns[0]) == len( extns ), f"All files in the dataset {name} should have the same extension" - loader = get_loader_for_filepath(file_path=files[0]) + loader = get_loader_for_filepath(file_path=files[0]) if loader in (None, ""): raise ValueError(f"data path is invalid [{', '.join(files)}]")