From cad3a2daea92e4d24485b6322996fcae25ccfb23 Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Sun, 8 Dec 2024 16:02:00 +0530 Subject: [PATCH 1/3] Expose additional data handlers as an argument to the train function. Signed-off-by: Dushyant Behl --- tuning/data/data_processors.py | 26 +++++++++------- tuning/data/setup_dataprocessor.py | 49 +++++++++++++++++++++--------- tuning/sft_trainer.py | 8 +++-- 3 files changed, 55 insertions(+), 28 deletions(-) diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py index e92b7b684..d266258d1 100644 --- a/tuning/data/data_processors.py +++ b/tuning/data/data_processors.py @@ -13,7 +13,7 @@ # limitations under the License. # Standard -from typing import Dict, List, Union +from typing import Callable, Dict, List, Union import logging import os @@ -35,7 +35,7 @@ class DataPreProcessor: tokenizer = None data_config: DataConfig = None processor_config: DataPreProcessorConfig = None - registered_handlers: Dict[str, callable] = None + registered_handlers: Dict[str, Callable] = None def __init__( self, processor_config: DataPreProcessorConfig, tokenizer: AutoTokenizer @@ -46,9 +46,20 @@ def __init__( # Initialize other objects self.registered_handlers = {} - def register_data_handler(self, name: str, func: callable): + def register_data_handler(self, name: str, func: Callable): + assert isinstance(name, str), "Handler name should be of str type" + assert callable(func), "Handler should be a callable routine" self.registered_handlers[name] = func + def register_data_handlers(self, handlers: Dict[str, Callable]): + if handlers is None: + return + assert isinstance( + handlers, Dict + ), "Handlers should be of type Dict[str:Callable]" + for k, v in handlers.items(): + self.register_data_handler(name=k, func=v) + def load_dataset( self, datasetconfig: DataSetConfig, @@ -238,13 +249,6 @@ def process_dataset_configs( return train_dataset -def autoregister_available_handlers(processor: DataPreProcessor): - if processor is None: - return - for name, func in AVAILABLE_DATA_HANDLERS.items(): - processor.register_data_handler(name=name, func=func) - - def get_datapreprocessor( processor_config: DataPreProcessorConfig, tokenizer: AutoTokenizer ) -> DataPreProcessor: @@ -252,5 +256,5 @@ def get_datapreprocessor( processor_config=processor_config, tokenizer=tokenizer, ) - autoregister_available_handlers(processor) + processor.register_data_handlers(AVAILABLE_DATA_HANDLERS) return processor diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py index 5db8e0aee..22df51920 100644 --- a/tuning/data/setup_dataprocessor.py +++ b/tuning/data/setup_dataprocessor.py @@ -13,7 +13,7 @@ # limitations under the License. # Standard -from typing import Union +from typing import Callable, Dict, Union import logging # Third Party @@ -55,11 +55,16 @@ def is_pretokenized_dataset(data: Union[str, Dataset, IterableDataset]): # TODO: For now assume only training dataset is passed via data config file. 
# This is very limited but is done to keep first implementation minimal -def _process_dataconfig_file(data_args: DataArguments, tokenizer: AutoTokenizer): +def _process_dataconfig_file( + data_args: DataArguments, + tokenizer: AutoTokenizer, + additional_data_handlers: Dict[str, Callable] = None, +): data_config = load_and_validate_data_config(data_args.data_config_path) processor = get_datapreprocessor( processor_config=data_config.dataprocessor, tokenizer=tokenizer ) + processor.register_data_handlers(additional_data_handlers) train_dataset = processor.process_dataset_configs(data_config.datasets) return (train_dataset, None, data_args.dataset_text_field) @@ -179,6 +184,7 @@ def _process_raw_data_args( tokenizer: AutoTokenizer, packing: bool, max_seq_length: int, + additional_data_handlers: Dict[str, Callable] = None, ): # Create a data processor with default processor config @@ -186,7 +192,7 @@ def _process_raw_data_args( data_processor = get_datapreprocessor( processor_config=default_processor_config, tokenizer=tokenizer ) - + data_processor.register_data_handlers(additional_data_handlers) assert isinstance( data_args.training_data_path, str ), "Training data path has to be set and str" @@ -259,7 +265,10 @@ def _process_raw_data_args( # If no data config file is specified, process the remaining data arguments # to determine the use case based on their presence, as explained in _process_raw_data_args. def process_dataargs( - data_args: DataArguments, tokenizer: AutoTokenizer, train_args: TrainingArguments + data_args: DataArguments, + tokenizer: AutoTokenizer, + train_args: TrainingArguments, + additional_data_handlers: Dict[str, Callable] = None, ): """ Args: @@ -268,11 +277,17 @@ def process_dataargs( train_args: TrainingArguments Training arguments passed to the library Used for packing and max_seq_length + additional_data_handlers: A Dict of [str, callable] data handlers + which need to be registered with the data preprocessor Returns: Tuple(Dataset, Dataset, str, DataCollator, int, Dict) - tuple containing train_dataset, eval_dataset, dataset_text_field, - data_collator, max_seq_length and dataset_kwargs - + tuple containing + train_dataset (Dataset/IterableDataset), + eval_dataset (Dataset/IterableDataset), + dataset_text_field (str), + data_collator (DataCollator) + max_seq_length(int) and + dataset_kwargs (Dict) """ max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length) @@ -290,26 +305,32 @@ def process_dataargs( if data_args.data_config_path: train_dataset, eval_dataset, dataset_text_field = _process_dataconfig_file( - data_args, tokenizer + data_args, tokenizer, additional_data_handlers ) else: train_dataset, eval_dataset, dataset_text_field = _process_raw_data_args( - data_args, tokenizer, train_args.packing, max_seq_length + data_args, + tokenizer, + train_args.packing, + max_seq_length, + additional_data_handlers, ) + # Note: This check should not be removed. + # Its important to recompute this post handling to + # check if we already tokenized the dataset or not. + is_tokenized_dataset = is_pretokenized_dataset(train_dataset or eval_dataset) + data_collator = get_data_collator( train_args.packing, data_args.response_template, tokenizer, - # Note: This check should not be removed. - # Its important to recompute this post handling to - # check if we already tokenized the dataset or not. 
- is_pretokenized_dataset(train_dataset), + is_tokenized_dataset, max_seq_length, ) dataset_kwargs = {} - if is_pretokenized_dataset(train_dataset or eval_dataset): + if is_tokenized_dataset: dataset_kwargs["skip_prepare_dataset"] = True return ( diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index c02d73781..2ad55c06b 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -13,7 +13,7 @@ # limitations under the License. # Standard -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import dataclasses import json import logging @@ -85,6 +85,7 @@ def train( attention_and_distributed_packing_config: Optional[ AttentionAndDistributedPackingConfig ] = None, + additional_data_handlers: Optional[Dict[str, Callable]] = None, ) -> tuple[SFTTrainer, dict]: """Call the SFTTrainer @@ -113,7 +114,8 @@ def train( Should be used in combination with quantized_lora_config. Also currently fused_lora and fast_kernels must used together (may change in future). \ attention_and_distributed_packing_config: Used for padding-free attention and multipack. - + additional_data_handlers: Dict [str:Callable] of any extra data handlers \ + to be registered with the data preprocessor Returns: Tuple: Instance of SFTTrainer , some metadata in a dict Metadata contains information on number of added tokens while tuning. @@ -297,7 +299,7 @@ def train( data_collator, max_seq_length, dataset_kwargs, - ) = process_dataargs(data_args, tokenizer, train_args) + ) = process_dataargs(data_args, tokenizer, train_args, additional_data_handlers) additional_metrics["data_preprocessing_time"] = ( time.time() - data_preprocessing_time ) From 6310c5dae73cfa7caac2603f661191022c2434df Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Thu, 12 Dec 2024 20:53:39 +0530 Subject: [PATCH 2/3] add unit tests for additional data handlers Signed-off-by: Dushyant Behl --- tests/test_sft_trainer.py | 105 +++++++++++++++++++++++++++++ tuning/data/data_processors.py | 22 ++++-- tuning/data/setup_dataprocessor.py | 10 +-- 3 files changed, 126 insertions(+), 11 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 69ccbf4fa..06045cbda 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -15,7 +15,10 @@ """Unit Tests for SFT Trainer. """ +# pylint: disable=too-many-lines + # Standard +from dataclasses import asdict import copy import json import os @@ -46,6 +49,13 @@ from tuning import sft_trainer from tuning.config import configs, peft_config from tuning.config.tracker_configs import FileLoggingTrackerConfig +from tuning.data.data_config import ( + DataConfig, + DataHandlerConfig, + DataPreProcessorConfig, + DataSetConfig, +) +from tuning.data.data_handlers import apply_dataset_formatting MODEL_ARGS = configs.ModelArguments( model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32" @@ -1124,3 +1134,98 @@ def test_pretokenized_dataset_wrong_format(): # is essentially swallowing a KeyError here. 
with pytest.raises(ValueError): sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) + + +########################################################################### +### Tests for checking different cases for the argument additional_handlers +### The argument `additional_handlers` in train::sft_trainer.py is used to pass +### extra data handlers which should be a Dict[str,callable] + +### Test for checking if bad additional_handlers argument +### (which is not Dict[str,callable]) throws an error +@pytest.mark.parametrize( + "additional_handlers", + [ + "thisisnotokay", + [], + {lambda x: {"x": x}: "notokayeither"}, + {"thisisfine": "thisisnot"}, + ], +) +def test_run_with_bad_additional_data_handlers(additional_handlers): + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + with pytest.raises( + ValueError, match="Handlers should be of type Dict, str to callable" + ): + sft_trainer.train( + MODEL_ARGS, + DATA_ARGS, + train_args, + PEFT_PT_ARGS, + additional_data_handlers=additional_handlers, + ) + + +### Test for checking if additional_handlers=None should work +def test_run_with_additional_data_handlers_as_none(): + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + sft_trainer.train( + MODEL_ARGS, + DATA_ARGS, + train_args, + PEFT_PT_ARGS, + additional_data_handlers=None, + ) + + +### Test for checking if a good additional_handlers argument +### can take a data handler and can successfully run a e2e training. +def test_run_by_passing_additional_data_handlers(): + + # This is my test handler + TEST_HANDLER = "my_test_handler" + + def test_handler(element, tokenizer, **kwargs): + return apply_dataset_formatting(element, tokenizer, "custom_formatted_field") + + # This data config calls for data handler to be applied to dataset + preprocessor_config = DataPreProcessorConfig() + handler_config = DataHandlerConfig(name="my_test_handler", arguments=None) + dataaset_config = DataSetConfig( + name="test_dataset", + data_paths=TWITTER_COMPLAINTS_DATA_JSON, + data_handlers=[handler_config], + ) + data_config = DataConfig( + dataprocessor=preprocessor_config, datasets=[dataaset_config] + ) + + # dump the data config to a file, also test if json data config works + with tempfile.NamedTemporaryFile( + "w", delete=False, suffix=".json" + ) as temp_data_file: + data_config_raw = json.dumps(asdict(data_config)) + temp_data_file.write(data_config_raw) + data_config_path = temp_data_file.name + + # now launch sft trainer after registering data handler + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + data_args = copy.deepcopy(DATA_ARGS) + data_args.data_config_path = data_config_path + data_args.dataset_text_field = "custom_formatted_field" + + sft_trainer.train( + MODEL_ARGS, + DATA_ARGS, + train_args, + PEFT_PT_ARGS, + additional_data_handlers={TEST_HANDLER: test_handler}, + ) diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py index d266258d1..2475bff92 100644 --- a/tuning/data/data_processors.py +++ b/tuning/data/data_processors.py @@ -46,17 +46,23 @@ def __init__( # Initialize other objects self.registered_handlers = {} + # Auto register available data handlers + for k, v in AVAILABLE_DATA_HANDLERS.items(): + self.registered_handlers[k] = v + def register_data_handler(self, name: str, func: Callable): - assert isinstance(name, str), 
"Handler name should be of str type" - assert callable(func), "Handler should be a callable routine" + if not isinstance(name, str) or not callable(func): + raise ValueError("Handlers should be of type Dict, str to callable") + if name in self.registered_handlers: + logging.warning("Handler name %s existed is being overwritten", name) self.registered_handlers[name] = func + logging.info("Registered new handler %s", name) def register_data_handlers(self, handlers: Dict[str, Callable]): if handlers is None: return - assert isinstance( - handlers, Dict - ), "Handlers should be of type Dict[str:Callable]" + if not isinstance(handlers, Dict): + raise ValueError("Handlers should be of type Dict, str to callable") for k, v in handlers.items(): self.register_data_handler(name=k, func=v) @@ -250,11 +256,13 @@ def process_dataset_configs( def get_datapreprocessor( - processor_config: DataPreProcessorConfig, tokenizer: AutoTokenizer + processor_config: DataPreProcessorConfig, + tokenizer: AutoTokenizer, + additional_data_handlers: Dict[str, Callable] = None, ) -> DataPreProcessor: processor = DataPreProcessor( processor_config=processor_config, tokenizer=tokenizer, ) - processor.register_data_handlers(AVAILABLE_DATA_HANDLERS) + processor.register_data_handlers(additional_data_handlers) return processor diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py index 22df51920..f9be9a23e 100644 --- a/tuning/data/setup_dataprocessor.py +++ b/tuning/data/setup_dataprocessor.py @@ -62,9 +62,10 @@ def _process_dataconfig_file( ): data_config = load_and_validate_data_config(data_args.data_config_path) processor = get_datapreprocessor( - processor_config=data_config.dataprocessor, tokenizer=tokenizer + processor_config=data_config.dataprocessor, + tokenizer=tokenizer, + additional_data_handlers=additional_data_handlers, ) - processor.register_data_handlers(additional_data_handlers) train_dataset = processor.process_dataset_configs(data_config.datasets) return (train_dataset, None, data_args.dataset_text_field) @@ -190,9 +191,10 @@ def _process_raw_data_args( # Create a data processor with default processor config default_processor_config = DataPreProcessorConfig() data_processor = get_datapreprocessor( - processor_config=default_processor_config, tokenizer=tokenizer + processor_config=default_processor_config, + tokenizer=tokenizer, + additional_data_handlers=additional_data_handlers, ) - data_processor.register_data_handlers(additional_data_handlers) assert isinstance( data_args.training_data_path, str ), "Training data path has to be set and str" From 00285d960ed4fa943f137c4a4821e1ad2971add2 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Thu, 12 Dec 2024 16:27:26 -0500 Subject: [PATCH 3/3] PR Changes Signed-off-by: Abhishek --- tests/test_sft_trainer.py | 14 ++++++++------ tuning/data/data_processors.py | 4 +++- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 06045cbda..5dbc1144c 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -1141,8 +1141,7 @@ def test_pretokenized_dataset_wrong_format(): ### The argument `additional_handlers` in train::sft_trainer.py is used to pass ### extra data handlers which should be a Dict[str,callable] -### Test for checking if bad additional_handlers argument -### (which is not Dict[str,callable]) throws an error + @pytest.mark.parametrize( "additional_handlers", [ @@ -1153,6 +1152,8 @@ def test_pretokenized_dataset_wrong_format(): ], ) def 
test_run_with_bad_additional_data_handlers(additional_handlers): + """Ensure that bad additional_handlers argument (which is not Dict[str,callable]) + throws an error""" with tempfile.TemporaryDirectory() as tempdir: train_args = copy.deepcopy(TRAIN_ARGS) train_args.output_dir = tempdir @@ -1169,8 +1170,8 @@ def test_run_with_bad_additional_data_handlers(additional_handlers): ) -### Test for checking if additional_handlers=None should work def test_run_with_additional_data_handlers_as_none(): + """Ensure that additional_handlers as None should work.""" with tempfile.TemporaryDirectory() as tempdir: train_args = copy.deepcopy(TRAIN_ARGS) train_args.output_dir = tempdir @@ -1182,12 +1183,12 @@ def test_run_with_additional_data_handlers_as_none(): PEFT_PT_ARGS, additional_data_handlers=None, ) + _validate_training(tempdir) -### Test for checking if a good additional_handlers argument -### can take a data handler and can successfully run a e2e training. def test_run_by_passing_additional_data_handlers(): - + """Ensure that good additional_handlers argument can take a + data handler and can successfully run a e2e training.""" # This is my test handler TEST_HANDLER = "my_test_handler" @@ -1229,3 +1230,4 @@ def test_handler(element, tokenizer, **kwargs): PEFT_PT_ARGS, additional_data_handlers={TEST_HANDLER: test_handler}, ) + _validate_training(tempdir) diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py index 2475bff92..33a368314 100644 --- a/tuning/data/data_processors.py +++ b/tuning/data/data_processors.py @@ -54,7 +54,9 @@ def register_data_handler(self, name: str, func: Callable): if not isinstance(name, str) or not callable(func): raise ValueError("Handlers should be of type Dict, str to callable") if name in self.registered_handlers: - logging.warning("Handler name %s existed is being overwritten", name) + logging.warning( + "Handler name '%s' already exists and will be overwritten", name + ) self.registered_handlers[name] = func logging.info("Registered new handler %s", name)
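
Usage note (not part of the patches above): the new `additional_data_handlers` keyword lets a caller register a custom preprocessing routine by name and reference it from a data config, as exercised by the unit test `test_run_by_passing_additional_data_handlers`. The sketch below mirrors that test; the model path, dataset path, output directory, handler name "my_handler", and the exact fields passed to `configs.ModelArguments` / `configs.DataArguments` / `configs.TrainingArguments` are illustrative assumptions, not values taken from this series.

    # Minimal usage sketch, assuming the argument/config classes are built the
    # same way the repository's tests build them. Placeholder values are marked.
    import json
    from dataclasses import asdict

    from tuning import sft_trainer
    from tuning.config import configs
    from tuning.data.data_config import (
        DataConfig,
        DataHandlerConfig,
        DataPreProcessorConfig,
        DataSetConfig,
    )
    from tuning.data.data_handlers import apply_dataset_formatting


    def my_handler(element, tokenizer, **kwargs):
        # Delegate to a built-in handler so each record ends up with a single
        # formatted text field the trainer can consume.
        return apply_dataset_formatting(element, tokenizer, "custom_formatted_field")


    # A data config that asks the preprocessor to run "my_handler" on the dataset.
    data_config = DataConfig(
        dataprocessor=DataPreProcessorConfig(),
        datasets=[
            DataSetConfig(
                name="my_dataset",
                data_paths=["<path/to/train_data.json>"],  # placeholder path
                data_handlers=[DataHandlerConfig(name="my_handler", arguments=None)],
            )
        ],
    )
    with open("data_config.json", "w") as f:
        json.dump(asdict(data_config), f)

    model_args = configs.ModelArguments(model_name_or_path="<base-model>")  # placeholder
    data_args = configs.DataArguments(
        data_config_path="data_config.json",
        dataset_text_field="custom_formatted_field",
    )
    train_args = configs.TrainingArguments(output_dir="<output-dir>")  # placeholder

    # The new keyword registers the handler with the data preprocessor before
    # the data config above is processed, so the config can refer to it by name.
    sft_trainer.train(
        model_args,
        data_args,
        train_args,
        additional_data_handlers={"my_handler": my_handler},
    )

Passing `additional_data_handlers=None` (or omitting it) leaves only the built-in handlers registered, and a handler name that collides with an existing one is overwritten with a warning, per the registration logic added in PATCH 2/3 and 3/3.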