diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 69ccbf4fa..5dbc1144c 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -15,7 +15,10 @@
 """Unit Tests for SFT Trainer.
 """
 
+# pylint: disable=too-many-lines
+
 # Standard
+from dataclasses import asdict
 import copy
 import json
 import os
@@ -46,6 +49,13 @@ from tuning import sft_trainer
 from tuning.config import configs, peft_config
 from tuning.config.tracker_configs import FileLoggingTrackerConfig
+from tuning.data.data_config import (
+    DataConfig,
+    DataHandlerConfig,
+    DataPreProcessorConfig,
+    DataSetConfig,
+)
+from tuning.data.data_handlers import apply_dataset_formatting
 
 MODEL_ARGS = configs.ModelArguments(
     model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32"
 )
@@ -1124,3 +1134,100 @@ def test_pretokenized_dataset_wrong_format():
     # is essentially swallowing a KeyError here.
     with pytest.raises(ValueError):
         sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS)
+
+
+###########################################################################
+### Tests covering different cases for the argument `additional_data_handlers`.
+### The argument `additional_data_handlers` of sft_trainer.train() is used to
+### pass extra data handlers and must be a Dict[str, Callable].
+
+
+@pytest.mark.parametrize(
+    "additional_handlers",
+    [
+        "thisisnotokay",
+        [],
+        {lambda x: {"x": x}: "notokayeither"},
+        {"thisisfine": "thisisnot"},
+    ],
+)
+def test_run_with_bad_additional_data_handlers(additional_handlers):
+    """Ensure that a bad additional_data_handlers argument (i.e. one that is
+    not a Dict[str, Callable]) raises an error."""
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        with pytest.raises(
+            ValueError, match="Handlers should be of type Dict, str to callable"
+        ):
+            sft_trainer.train(
+                MODEL_ARGS,
+                DATA_ARGS,
+                train_args,
+                PEFT_PT_ARGS,
+                additional_data_handlers=additional_handlers,
+            )
+
+
+def test_run_with_additional_data_handlers_as_none():
+    """Ensure that passing additional_data_handlers=None works."""
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        sft_trainer.train(
+            MODEL_ARGS,
+            DATA_ARGS,
+            train_args,
+            PEFT_PT_ARGS,
+            additional_data_handlers=None,
+        )
+        _validate_training(tempdir)
+
+
+def test_run_by_passing_additional_data_handlers():
+    """Ensure that a valid additional_data_handlers argument can register a
+    custom data handler and successfully run an e2e training."""
+    # Name under which the custom test handler is registered
+    TEST_HANDLER = "my_test_handler"
+
+    def test_handler(element, tokenizer, **kwargs):
+        return apply_dataset_formatting(element, tokenizer, "custom_formatted_field")
+
+    # This data config requests that the custom data handler be applied to the dataset
+    preprocessor_config = DataPreProcessorConfig()
+    handler_config = DataHandlerConfig(name=TEST_HANDLER, arguments=None)
+    dataset_config = DataSetConfig(
+        name="test_dataset",
+        data_paths=TWITTER_COMPLAINTS_DATA_JSON,
+        data_handlers=[handler_config],
+    )
+    data_config = DataConfig(
+        dataprocessor=preprocessor_config, datasets=[dataset_config]
+    )
+
+    # Dump the data config to a file; this also exercises the JSON data config path
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".json"
+    ) as temp_data_file:
+        data_config_raw = json.dumps(asdict(data_config))
+        temp_data_file.write(data_config_raw)
+        data_config_path = temp_data_file.name
+
+    # Now launch the SFT trainer after registering the data handler
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+        data_args = copy.deepcopy(DATA_ARGS)
+        data_args.data_config_path = data_config_path
+        data_args.dataset_text_field = "custom_formatted_field"
+
+        sft_trainer.train(
+            MODEL_ARGS,
+            data_args,
+            train_args,
+            PEFT_PT_ARGS,
+            additional_data_handlers={TEST_HANDLER: test_handler},
+        )
+        _validate_training(tempdir)
diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py
index e92b7b684..33a368314 100644
--- a/tuning/data/data_processors.py
+++ b/tuning/data/data_processors.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 # Standard
-from typing import Dict, List, Union
+from typing import Callable, Dict, List, Union
 import logging
 import os
 
@@ -35,7 +35,7 @@ class DataPreProcessor:
     tokenizer = None
     data_config: DataConfig = None
     processor_config: DataPreProcessorConfig = None
-    registered_handlers: Dict[str, callable] = None
+    registered_handlers: Dict[str, Callable] = None
 
     def __init__(
         self, processor_config: DataPreProcessorConfig, tokenizer: AutoTokenizer
@@ -46,8 +46,27 @@ def __init__(
         # Initialize other objects
         self.registered_handlers = {}
 
-    def register_data_handler(self, name: str, func: callable):
+        # Auto register available data handlers
+        for k, v in AVAILABLE_DATA_HANDLERS.items():
+            self.registered_handlers[k] = v
+
+    def register_data_handler(self, name: str, func: Callable):
+        if not isinstance(name, str) or not callable(func):
+            raise ValueError("Handlers should be of type Dict, str to callable")
+        if name in self.registered_handlers:
+            logging.warning(
+                "Handler name '%s' already exists and will be overwritten", name
+            )
         self.registered_handlers[name] = func
+        logging.info("Registered new handler %s", name)
+
+    def register_data_handlers(self, handlers: Dict[str, Callable]):
+        if handlers is None:
+            return
+        if not isinstance(handlers, Dict):
+            raise ValueError("Handlers should be of type Dict, str to callable")
+        for k, v in handlers.items():
+            self.register_data_handler(name=k, func=v)
 
     def load_dataset(
         self,
@@ -238,19 +257,14 @@ def process_dataset_configs(
         return train_dataset
 
 
-def autoregister_available_handlers(processor: DataPreProcessor):
-    if processor is None:
-        return
-    for name, func in AVAILABLE_DATA_HANDLERS.items():
-        processor.register_data_handler(name=name, func=func)
-
-
 def get_datapreprocessor(
-    processor_config: DataPreProcessorConfig, tokenizer: AutoTokenizer
+    processor_config: DataPreProcessorConfig,
+    tokenizer: AutoTokenizer,
+    additional_data_handlers: Dict[str, Callable] = None,
 ) -> DataPreProcessor:
     processor = DataPreProcessor(
         processor_config=processor_config,
         tokenizer=tokenizer,
     )
-    autoregister_available_handlers(processor)
+    processor.register_data_handlers(additional_data_handlers)
     return processor
diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py
index 5db8e0aee..f9be9a23e 100644
--- a/tuning/data/setup_dataprocessor.py
+++ b/tuning/data/setup_dataprocessor.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 # Standard
-from typing import Union
+from typing import Callable, Dict, Union
 import logging
 
 # Third Party
@@ -55,10 +55,16 @@ def is_pretokenized_dataset(data: Union[str, Dataset, IterableDataset]):
 
 # TODO: For now assume only training dataset is passed via data config file.
 # This is very limited but is done to keep first implementation minimal
-def _process_dataconfig_file(data_args: DataArguments, tokenizer: AutoTokenizer):
+def _process_dataconfig_file(
+    data_args: DataArguments,
+    tokenizer: AutoTokenizer,
+    additional_data_handlers: Dict[str, Callable] = None,
+):
     data_config = load_and_validate_data_config(data_args.data_config_path)
     processor = get_datapreprocessor(
-        processor_config=data_config.dataprocessor, tokenizer=tokenizer
+        processor_config=data_config.dataprocessor,
+        tokenizer=tokenizer,
+        additional_data_handlers=additional_data_handlers,
     )
 
     train_dataset = processor.process_dataset_configs(data_config.datasets)
@@ -179,14 +185,16 @@ def _process_raw_data_args(
     tokenizer: AutoTokenizer,
     packing: bool,
     max_seq_length: int,
+    additional_data_handlers: Dict[str, Callable] = None,
 ):
 
     # Create a data processor with default processor config
     default_processor_config = DataPreProcessorConfig()
     data_processor = get_datapreprocessor(
-        processor_config=default_processor_config, tokenizer=tokenizer
+        processor_config=default_processor_config,
+        tokenizer=tokenizer,
+        additional_data_handlers=additional_data_handlers,
     )
-
     assert isinstance(
         data_args.training_data_path, str
     ), "Training data path has to be set and str"
@@ -259,7 +267,10 @@
 # If no data config file is specified, process the remaining data arguments
 # to determine the use case based on their presence, as explained in _process_raw_data_args.
 def process_dataargs(
-    data_args: DataArguments, tokenizer: AutoTokenizer, train_args: TrainingArguments
+    data_args: DataArguments,
+    tokenizer: AutoTokenizer,
+    train_args: TrainingArguments,
+    additional_data_handlers: Dict[str, Callable] = None,
 ):
     """
     Args:
@@ -268,11 +279,17 @@
         train_args: TrainingArguments
             Training arguments passed to the library
             Used for packing and max_seq_length
+        additional_data_handlers: A Dict[str, Callable] of additional data handlers
+            to be registered with the data preprocessor
     Returns:
         Tuple(Dataset, Dataset, str, DataCollator, int, Dict)
-            tuple containing train_dataset, eval_dataset, dataset_text_field,
-                data_collator, max_seq_length and dataset_kwargs
-
+            tuple containing
+            train_dataset (Dataset/IterableDataset),
+            eval_dataset (Dataset/IterableDataset),
+            dataset_text_field (str),
+            data_collator (DataCollator),
+            max_seq_length (int) and
+            dataset_kwargs (Dict)
     """
 
     max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length)
@@ -290,26 +307,32 @@
 
     if data_args.data_config_path:
         train_dataset, eval_dataset, dataset_text_field = _process_dataconfig_file(
-            data_args, tokenizer
+            data_args, tokenizer, additional_data_handlers
        )
     else:
         train_dataset, eval_dataset, dataset_text_field = _process_raw_data_args(
-            data_args, tokenizer, train_args.packing, max_seq_length
+            data_args,
+            tokenizer,
+            train_args.packing,
+            max_seq_length,
+            additional_data_handlers,
         )
 
+    # Note: This check should not be removed.
+    # It's important to recompute this after data handling to check
+    # whether the dataset is already tokenized.
+    is_tokenized_dataset = is_pretokenized_dataset(train_dataset or eval_dataset)
+
     data_collator = get_data_collator(
         train_args.packing,
         data_args.response_template,
         tokenizer,
-        # Note: This check should not be removed.
-        # Its important to recompute this post handling to
-        # check if we already tokenized the dataset or not.
-        is_pretokenized_dataset(train_dataset),
+        is_tokenized_dataset,
         max_seq_length,
     )
 
     dataset_kwargs = {}
-    if is_pretokenized_dataset(train_dataset or eval_dataset):
+    if is_tokenized_dataset:
         dataset_kwargs["skip_prepare_dataset"] = True
 
     return (
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index c02d73781..2ad55c06b 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 # Standard
-from typing import Dict, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union
 import dataclasses
 import json
 import logging
@@ -85,6 +85,7 @@ def train(
     attention_and_distributed_packing_config: Optional[
         AttentionAndDistributedPackingConfig
     ] = None,
+    additional_data_handlers: Optional[Dict[str, Callable]] = None,
 ) -> tuple[SFTTrainer, dict]:
     """Call the SFTTrainer
 
@@ -113,7 +114,8 @@
            Should be used in combination with quantized_lora_config. Also currently fused_lora and fast_kernels must used together (may change in future). \
        attention_and_distributed_packing_config: Used for padding-free attention and multipack.
-
+        additional_data_handlers: Dict[str, Callable] of any extra data handlers \
+            to be registered with the data preprocessor
     Returns:
         Tuple: Instance of SFTTrainer , some metadata in a dict
         Metadata contains information on number of added tokens while tuning.
 
@@ -297,7 +299,7 @@
         data_collator,
         max_seq_length,
         dataset_kwargs,
-    ) = process_dataargs(data_args, tokenizer, train_args)
+    ) = process_dataargs(data_args, tokenizer, train_args, additional_data_handlers)
     additional_metrics["data_preprocessing_time"] = (
         time.time() - data_preprocessing_time
     )
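
For reference, below is a minimal usage sketch of the new additional_data_handlers argument introduced by this change. It assumes the configs.ModelArguments / DataArguments / TrainingArguments dataclasses used elsewhere in the test suite; the model name, file paths, handler name, and response template are illustrative placeholders, not values defined by this diff.

# Usage sketch (illustrative placeholders throughout).
from tuning import sft_trainer
from tuning.config import configs
from tuning.data.data_handlers import apply_dataset_formatting


def my_formatter(element, tokenizer, **kwargs):
    # Reuse the built-in formatting handler to render each element into a
    # single text field named "formatted_text".
    return apply_dataset_formatting(element, tokenizer, "formatted_text")


model_args = configs.ModelArguments(model_name_or_path="Maykeye/TinyLLama-v0")
data_args = configs.DataArguments(
    training_data_path="twitter_complaints.json",  # placeholder dataset path
    dataset_text_field="formatted_text",
    response_template="\n### Label:",  # placeholder template
)
train_args = configs.TrainingArguments(output_dir="out", num_train_epochs=1)

# The new keyword registers {name: callable} handlers with the data
# preprocessor before any dataset processing happens; passing None
# (the default) keeps only the built-in handlers.
sft_trainer.train(
    model_args,
    data_args,
    train_args,
    additional_data_handlers={"my_formatter": my_formatter},
)

As in the new test, a registered handler is actually invoked only when a data config names it via DataHandlerConfig(name=...); registering it simply makes it available to the preprocessor alongside the built-in handlers.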