diff --git a/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml b/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml index b66b01d55..ac7e07030 100644 --- a/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml +++ b/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml @@ -10,5 +10,5 @@ datasets: remove_columns: all batched: false fn_kwargs: - input_field_name: "INPUT" - output_field_name: "OUTPUT" \ No newline at end of file + input_field_name: input + output_field_name: output \ No newline at end of file diff --git a/tests/artifacts/testdata/__init__.py b/tests/artifacts/testdata/__init__.py index 762f88ab9..5b1853772 100644 --- a/tests/artifacts/testdata/__init__.py +++ b/tests/artifacts/testdata/__init__.py @@ -23,7 +23,7 @@ JSONL_DATA_DIR = os.path.join(os.path.dirname(__file__), "jsonl") ARROW_DATA_DIR = os.path.join(os.path.dirname(__file__), "arrow") PARQUET_DATA_DIR = os.path.join(os.path.dirname(__file__), "parquet") - +TWITTER_COMPLAINTS_DATA_DIR_JSON = os.path.join(DATA_DIR, "datafolder") TWITTER_COMPLAINTS_DATA_JSON = os.path.join( JSON_DATA_DIR, "twitter_complaints_small.json" ) diff --git a/tests/artifacts/testdata/datafolder/twitter_complaints_input_output_1.json b/tests/artifacts/testdata/datafolder/twitter_complaints_input_output_1.json new file mode 100644 index 000000000..2668241f8 --- /dev/null +++ b/tests/artifacts/testdata/datafolder/twitter_complaints_input_output_1.json @@ -0,0 +1,152 @@ +[ + { + "ID": 0, + "Label": 2, + "input": "@HMRCcustomers No this is my first job", + "output": "no complaint" + }, + { + "ID": 1, + "Label": 2, + "input": "@KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.", + "output": "no complaint" + }, + { + "ID": 2, + "Label": 1, + "input": "If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. 
Your next @Bose @BoseService", + "output": "complaint" + }, + { + "ID": 3, + "Label": 1, + "input": "@EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.", + "output": "complaint" + }, + { + "ID": 4, + "Label": 2, + "input": "Couples wallpaper, so cute. :) #BrothersAtHome", + "output": "no complaint" + }, + { + "ID": 5, + "Label": 2, + "input": "@mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https://t.co/WRtNsokblG", + "output": "no complaint" + }, + { + "ID": 6, + "Label": 2, + "input": "@Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?", + "output": "no complaint" + }, + { + "ID": 7, + "Label": 1, + "input": "@nationalgridus I have no water and the bill is current and paid. Can you do something about this?", + "output": "complaint" + }, + { + "ID": 8, + "Label": 1, + "input": "Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude/condescending. I'll take my $$ to @Sephora", + "output": "complaint" + }, + { + "ID": 9, + "Label": 2, + "input": "@JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd", + "output": "no complaint" + }, + { + "ID": 10, + "Label": 2, + "input": "@NortonSupport Thanks much.", + "output": "no complaint" + }, + { + "ID": 11, + "Label": 2, + "input": "@VerizonSupport all of a sudden I can't connect to my primary wireless network but guest one works", + "output": "no complaint" + }, + { + "ID": 12, + "Label": 2, + "input": "Aaaahhhhh!!!! My @Razer @PlayOverwatch d.va meka headset came in!!! I didn't even know it had shipped!!! So excited\u2026 https://t.co/4gXy9xED8d", + "output": "no complaint" + }, + { + "ID": 13, + "Label": 2, + "input": "@Lin_Manuel @jmessinaphoto @VAMNit Omg a little squish!!!!! Enjoy and congrats!!!! I miss mine being so young! 
\ufffd\ufffd\ufffd\ufffd\ufffd\ufffd", + "output": "no complaint" + }, + { + "ID": 14, + "Label": 2, + "input": "@IanJamesPoulter What's your secret to poaching eggs? Mine NEVER look that good.", + "output": "no complaint" + }, + { + "ID": 15, + "Label": 2, + "input": "@AWSSupport When will be able Kinesis Firehose compatible with Elasticsearch 6.0? Thank you!", + "output": "no complaint" + }, + { + "ID": 16, + "Label": 2, + "input": "@NCIS_CBS https://t.co/eeVL9Eu3bE", + "output": "no complaint" + }, + { + "ID": 17, + "Label": 2, + "input": "@msetchell Via the settings? That\u2019s how I do it on master T\u2019s", + "output": "no complaint" + }, + { + "ID": 18, + "Label": 2, + "input": "Today at work there was a low flying duck heading toward a crowd of people, and I yelled \"watch out! and I'm very disappointed with myself.", + "output": "no complaint" + }, + { + "ID": 19, + "Label": 1, + "input": "@NortonSupport @NortonOnline What the hell is a dm 5-10 days to get money back bank account now overdrawn thanks guys", + "output": "complaint" + }, + { + "ID": 20, + "Label": 1, + "input": "@united not happy with this delay from Newark to Manchester tonight :( only 30 mins free Wi-fi sucks ...", + "output": "complaint" + }, + { + "ID": 21, + "Label": 1, + "input": "@ZARA_Care I've been waiting on a reply to my tweets and DMs for days now?", + "output": "complaint" + }, + { + "ID": 22, + "Label": 2, + "input": "New Listing! 
Large 2 Family Home for Sale in #Passaic Park, #NJ #realestate #homesforsale Great Location!\u2026 https://t.co/IV4OrLXkMk", + "output": "no complaint" + }, + { + "ID": 23, + "Label": 1, + "input": "@SouthwestAir I love you but when sending me flight changes please don't use military time #ignoranceisbliss", + "output": "complaint" + }, + { + "ID": 24, + "Label": 2, + "input": "@JetBlue Completely understand but would prefer being on time to filling out forms....", + "output": "no complaint" + } +] \ No newline at end of file diff --git a/tests/artifacts/testdata/datafolder/twitter_complaints_input_output_2.json b/tests/artifacts/testdata/datafolder/twitter_complaints_input_output_2.json new file mode 100644 index 000000000..e93fed8e4 --- /dev/null +++ b/tests/artifacts/testdata/datafolder/twitter_complaints_input_output_2.json @@ -0,0 +1,152 @@ +[ + { + "ID": 25, + "Label": 2, + "input": "@nvidiacc I own two gtx 460 in sli. I want to try windows 8 dev preview. Which driver should I use. Can I use the windows 7 one.", + "output": "no complaint" + }, + { + "ID": 26, + "Label": 2, + "input": "Just posted a photo https://t.co/RShFwCjPHu", + "output": "no complaint" + }, + { + "ID": 27, + "Label": 2, + "input": "Love crescent rolls? Try adding pesto @PerdueChicken to them and you\u2019re going to love it! 
#Promotion #PerdueCrew -\u2026 https://t.co/KBHOfqCukH", + "output": "no complaint" + }, + { + "ID": 28, + "Label": 1, + "input": "@TopmanAskUs please just give me my money back.", + "output": "complaint" + }, + { + "ID": 29, + "Label": 2, + "input": "I just gave 5 stars to Tracee at @neimanmarcus for the great service I received!", + "output": "no complaint" + }, + { + "ID": 30, + "Label": 2, + "input": "@FitbitSupport when are you launching new clock faces for Indian market", + "output": "no complaint" + }, + { + "ID": 31, + "Label": 1, + "input": "@HPSupport my printer will not allow me to choose color instead it only prints monochrome #hppsdr #ijkhelp", + "output": "complaint" + }, + { + "ID": 32, + "Label": 1, + "input": "@DIRECTV can I get a monthly charge double refund when it sprinkles outside and we lose reception? #IamEmbarrasedForYou", + "output": "complaint" + }, + { + "ID": 33, + "Label": 1, + "input": "@AlfaRomeoCares Hi thanks for replying, could be my internet but link doesn't seem to be working", + "output": "complaint" + }, + { + "ID": 34, + "Label": 2, + "input": "Looks tasty! Going to share with everyone I know #FebrezeONE #sponsored https://t.co/4AQI53npei", + "output": "no complaint" + }, + { + "ID": 35, + "Label": 2, + "input": "@OnePlus_IN can OnePlus 5T do front camera portrait?", + "output": "no complaint" + }, + { + "ID": 36, + "Label": 1, + "input": "@sho_help @showtime your arrive is terrible streaming is stop and start every couple mins. 
Get it together it's xmas", + "output": "complaint" + }, + { + "ID": 37, + "Label": 2, + "input": "@KandraKPTV I just witnessed a huge building fire in Santa Monica California", + "output": "no complaint" + }, + { + "ID": 38, + "Label": 2, + "input": "@fernrocks most definitely the latter for me", + "output": "no complaint" + }, + { + "ID": 39, + "Label": 1, + "input": "@greateranglia Could I ask why the Area in front of BIC Station was not gritted withh all the snow.", + "output": "complaint" + }, + { + "ID": 40, + "Label": 2, + "input": "I'm earning points with #CricketRewards https://t.co/GfpGhqqnhE", + "output": "no complaint" + }, + { + "ID": 41, + "Label": 2, + "input": "@Schrapnel @comcast RIP me", + "output": "no complaint" + }, + { + "ID": 42, + "Label": 2, + "input": "The wait is finally over, just joined @SquareUK, hope to get started real soon!", + "output": "no complaint" + }, + { + "ID": 43, + "Label": 2, + "input": "@WholeFoods what's the best way to give feedback on a particular store to the regional/national office?", + "output": "no complaint" + }, + { + "ID": 44, + "Label": 2, + "input": "@DanielNewman I honestly would believe anything. People are...too much sometimes.", + "output": "no complaint" + }, + { + "ID": 45, + "Label": 2, + "input": "@asblough Yep! It should send you a notification with your driver\u2019s name and what time they\u2019ll be showing up!", + "output": "no complaint" + }, + { + "ID": 46, + "Label": 2, + "input": "@Wavy2Timez for real", + "output": "no complaint" + }, + { + "ID": 47, + "Label": 1, + "input": "@KenyaPower_Care no power in south b area... is it scheduled.", + "output": "complaint" + }, + { + "ID": 48, + "Label": 1, + "input": "Honda won't do anything about water leaking in brand new car. Frustrated! 
@HondaCustSvc @AmericanHonda", + "output": "complaint" + }, + { + "ID": 49, + "Label": 1, + "input": "@CBSNews @Dodge @ChryslerCares My driver side air bag has been recalled and replaced, but what about the passenger side?", + "output": "complaint" + } +] \ No newline at end of file diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index 5559ac8ec..d673bc2d1 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -34,6 +34,7 @@ from tests.artifacts.testdata import ( MODEL_NAME, TWITTER_COMPLAINTS_DATA_ARROW, + TWITTER_COMPLAINTS_DATA_DIR_JSON, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, @@ -247,6 +248,30 @@ def test_load_dataset_with_dataconfig_and_datafile(datafile, datasetconfigname): ) +@pytest.mark.parametrize( + "datafolder, column_names, datasetconfigname", + [ + ( + TWITTER_COMPLAINTS_DATA_DIR_JSON, + {"ID", "Label", "input", "output"}, + "text_dataset_input_output_masking", + ), + ], +) +def test_load_dataset_with_dataconfig_and_datafolder( + datafolder, column_names, datasetconfigname +): + """Ensure a folder passed as the sole data path is loaded as a dataset.""" + datasetconfig = DataSetConfig(name=datasetconfigname, data_paths=[datafolder]) + processor = get_datapreprocessor( + processor_config=DataPreProcessorConfig(), tokenizer=None + ) + load_dataset = processor.load_dataset( + datasetconfig=datasetconfig, splitName="train", datafile=None + ) + assert set(load_dataset.column_names) == column_names + + def test_load_dataset_without_dataconfig_and_datafile(): """Ensure that both datasetconfig and datafile cannot be None.""" processor = get_datapreprocessor( diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 8dcdf3087..75dfce269 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -37,12 +37,14 @@ from scripts.run_inference 
import TunedCausalLM from tests.artifacts.predefined_data_configs import ( DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, ) from tests.artifacts.testdata import ( EMPTY_DATA, MALFORMATTED_DATA, MODEL_NAME, TWITTER_COMPLAINTS_DATA_ARROW, + TWITTER_COMPLAINTS_DATA_DIR_JSON, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, @@ -785,6 +787,10 @@ def test_run_causallm_ft_pretokenized(dataset_path): @pytest.mark.parametrize( "datafiles, datasetconfigname", [ + ( + [TWITTER_COMPLAINTS_DATA_DIR_JSON], + DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, + ), ( [ TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py index 33a368314..91f09cac1 100644 --- a/tuning/data/data_processors.py +++ b/tuning/data/data_processors.py @@ -87,7 +87,20 @@ def load_dataset( elif datasetconfig: files = datasetconfig.data_paths name = datasetconfig.name - # simple check to make sure all files are of same type. + + # Check if single file was passed, if directory, load as directory + if len(files)==1: + if os.path.isdir(files[0]): + try: + return datasets.load_dataset(files[0], split=splitName, **kwargs) + except DatasetNotFoundError as e: + raise e + except FileNotFoundError as e: + raise ValueError( + f"data path is invalid [{', '.join(files)}]" + ) from e + + # If the paths are files, ensure all have the same extension extns = [get_extension(f) for f in files] assert extns.count(extns[0]) == len( extns