feat: Expose additional data handlers as an argument in train #409

Merged
107 changes: 107 additions & 0 deletions tests/test_sft_trainer.py
@@ -15,7 +15,10 @@
"""Unit Tests for SFT Trainer.
"""

# pylint: disable=too-many-lines

# Standard
from dataclasses import asdict
import copy
import json
import os
@@ -46,6 +49,13 @@
from tuning import sft_trainer
from tuning.config import configs, peft_config
from tuning.config.tracker_configs import FileLoggingTrackerConfig
from tuning.data.data_config import (
DataConfig,
DataHandlerConfig,
DataPreProcessorConfig,
DataSetConfig,
)
from tuning.data.data_handlers import apply_dataset_formatting

MODEL_ARGS = configs.ModelArguments(
model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32"
@@ -1124,3 +1134,100 @@ def test_pretokenized_dataset_wrong_format():
# is essentially swallowing a KeyError here.
with pytest.raises(ValueError):
sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS)


###########################################################################
### Tests covering different cases for the `additional_data_handlers` argument.
### The `additional_data_handlers` argument of `sft_trainer.train` is used to pass
### extra data handlers, which must be a Dict[str, Callable].
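For orientation, a minimal sketch of the contract these tests exercise; the handler name and body below are illustrative, not part of the change:

```python
# Keys are handler names (str); values are callables that receive a dataset
# element, the tokenizer, and arbitrary keyword arguments.
def my_handler(element, tokenizer, **kwargs):
    return element  # a real handler would transform or reformat the element

valid_handlers = {"my_handler": my_handler}    # Dict[str, Callable] -> accepted
invalid_handlers = ["my_handler", my_handler]  # not a Dict -> raises ValueError
```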


@pytest.mark.parametrize(
"additional_handlers",
[
"thisisnotokay",
[],
{lambda x: {"x": x}: "notokayeither"},
{"thisisfine": "thisisnot"},
],
)
def test_run_with_bad_additional_data_handlers(additional_handlers):
Collaborator: nit: docstrings for all added test cases

@dushyantbehl (Contributor Author, Dec 12, 2024): Added comments. Thanks @willmj

"""Ensure that bad additional_handlers argument (which is not Dict[str,callable])
throws an error"""
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

with pytest.raises(
ValueError, match="Handlers should be of type Dict, str to callable"
):
sft_trainer.train(
MODEL_ARGS,
DATA_ARGS,
train_args,
PEFT_PT_ARGS,
additional_data_handlers=additional_handlers,
)


def test_run_with_additional_data_handlers_as_none():
"""Ensure that additional_handlers as None should work."""
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

sft_trainer.train(
MODEL_ARGS,
DATA_ARGS,
train_args,
PEFT_PT_ARGS,
additional_data_handlers=None,
)
Collaborator: We can just add this line after this: `_validate_training(tempdir)`

_validate_training(tempdir)


def test_run_by_passing_additional_data_handlers():
"""Ensure that good additional_handlers argument can take a
data handler and can successfully run a e2e training."""
# This is my test handler
TEST_HANDLER = "my_test_handler"

def test_handler(element, tokenizer, **kwargs):
return apply_dataset_formatting(element, tokenizer, "custom_formatted_field")

# This data config requests that the data handler be applied to the dataset
preprocessor_config = DataPreProcessorConfig()
handler_config = DataHandlerConfig(name=TEST_HANDLER, arguments=None)
dataset_config = DataSetConfig(
name="test_dataset",
data_paths=TWITTER_COMPLAINTS_DATA_JSON,
data_handlers=[handler_config],
)
data_config = DataConfig(
dataprocessor=preprocessor_config, datasets=[dataset_config]
)

# dump the data config to a file, also test if json data config works
with tempfile.NamedTemporaryFile(
"w", delete=False, suffix=".json"
) as temp_data_file:
data_config_raw = json.dumps(asdict(data_config))
temp_data_file.write(data_config_raw)
data_config_path = temp_data_file.name

# now launch sft trainer after registering data handler
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
data_args = copy.deepcopy(DATA_ARGS)
data_args.data_config_path = data_config_path
data_args.dataset_text_field = "custom_formatted_field"

sft_trainer.train(
MODEL_ARGS,
data_args,
train_args,
PEFT_PT_ARGS,
additional_data_handlers={TEST_HANDLER: test_handler},
)
Collaborator: Here also, we can just add this line after this: `_validate_training(tempdir)`

_validate_training(tempdir)
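For reference, a hand-written JSON data config equivalent to the one the test serializes would look roughly like the sketch below. The field names follow the dataclasses used above (DataConfig, DataPreProcessorConfig, DataSetConfig, DataHandlerConfig); the data path is a placeholder, whether data_paths serializes as a single path or a list follows the DataSetConfig definition, and the exact serialization of preprocessor defaults may differ:

```python
# Illustrative only; the handler "name" in the config must match the key passed
# via additional_data_handlers so the preprocessor can resolve it.
data_config_json = """
{
  "dataprocessor": {},
  "datasets": [
    {
      "name": "test_dataset",
      "data_paths": ["<path/to/twitter_complaints.json>"],
      "data_handlers": [
        { "name": "my_test_handler", "arguments": null }
      ]
    }
  ]
}
"""
```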
38 changes: 26 additions & 12 deletions tuning/data/data_processors.py
@@ -13,7 +13,7 @@
# limitations under the License.

# Standard
from typing import Dict, List, Union
from typing import Callable, Dict, List, Union
import logging
import os

@@ -35,7 +35,7 @@ class DataPreProcessor:
tokenizer = None
data_config: DataConfig = None
processor_config: DataPreProcessorConfig = None
registered_handlers: Dict[str, callable] = None
registered_handlers: Dict[str, Callable] = None

def __init__(
self, processor_config: DataPreProcessorConfig, tokenizer: AutoTokenizer
@@ -46,8 +46,27 @@ def __init__(
# Initialize other objects
self.registered_handlers = {}

def register_data_handler(self, name: str, func: callable):
# Auto register available data handlers
for k, v in AVAILABLE_DATA_HANDLERS.items():
self.registered_handlers[k] = v

def register_data_handler(self, name: str, func: Callable):
if not isinstance(name, str) or not callable(func):
raise ValueError("Handlers should be of type Dict, str to callable")
if name in self.registered_handlers:
logging.warning(
"Handler name '%s' already exists and will be overwritten", name
)
self.registered_handlers[name] = func
logging.info("Registered new handler %s", name)

def register_data_handlers(self, handlers: Dict[str, Callable]):
if handlers is None:
return
if not isinstance(handlers, Dict):
raise ValueError("Handlers should be of type Dict, str to callable")
for k, v in handlers.items():
self.register_data_handler(name=k, func=v)

def load_dataset(
self,
@@ -238,19 +257,14 @@ def process_dataset_configs(
return train_dataset


def autoregister_available_handlers(processor: DataPreProcessor):
if processor is None:
return
for name, func in AVAILABLE_DATA_HANDLERS.items():
processor.register_data_handler(name=name, func=func)


def get_datapreprocessor(
processor_config: DataPreProcessorConfig, tokenizer: AutoTokenizer
processor_config: DataPreProcessorConfig,
tokenizer: AutoTokenizer,
additional_data_handlers: Dict[str, Callable] = None,
) -> DataPreProcessor:
processor = DataPreProcessor(
processor_config=processor_config,
tokenizer=tokenizer,
)
autoregister_available_handlers(processor)
processor.register_data_handlers(additional_data_handlers)
return processor
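A minimal usage sketch of the new registration path in this file, assuming a placeholder model name for the tokenizer; get_datapreprocessor, DataPreProcessorConfig, and the validation/overwrite behaviour come from the diff above:

```python
from transformers import AutoTokenizer

from tuning.data.data_config import DataPreProcessorConfig
from tuning.data.data_processors import get_datapreprocessor


def my_handler(element, tokenizer, **kwargs):
    # Illustrative no-op handler; a real handler would reformat the element.
    return element


tokenizer = AutoTokenizer.from_pretrained("<model-name>")  # placeholder model
processor = get_datapreprocessor(
    processor_config=DataPreProcessorConfig(),
    tokenizer=tokenizer,
    additional_data_handlers={"my_handler": my_handler},
)

# Built-in handlers are auto-registered in __init__; custom handlers are added
# on top. Re-using an existing name logs a warning and overwrites the handler,
# and passing anything other than a Dict[str, Callable] raises ValueError.
```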
55 changes: 39 additions & 16 deletions tuning/data/setup_dataprocessor.py
@@ -13,7 +13,7 @@
# limitations under the License.

# Standard
from typing import Union
from typing import Callable, Dict, Union
import logging

# Third Party
@@ -55,10 +55,16 @@ def is_pretokenized_dataset(data: Union[str, Dataset, IterableDataset]):

# TODO: For now assume only training dataset is passed via data config file.
# This is very limited but is done to keep first implementation minimal
def _process_dataconfig_file(data_args: DataArguments, tokenizer: AutoTokenizer):
def _process_dataconfig_file(
data_args: DataArguments,
tokenizer: AutoTokenizer,
additional_data_handlers: Dict[str, Callable] = None,
):
data_config = load_and_validate_data_config(data_args.data_config_path)
processor = get_datapreprocessor(
processor_config=data_config.dataprocessor, tokenizer=tokenizer
processor_config=data_config.dataprocessor,
tokenizer=tokenizer,
additional_data_handlers=additional_data_handlers,
)
train_dataset = processor.process_dataset_configs(data_config.datasets)

@@ -179,14 +185,16 @@ def _process_raw_data_args(
tokenizer: AutoTokenizer,
packing: bool,
max_seq_length: int,
additional_data_handlers: Dict[str, Callable] = None,
):

# Create a data processor with default processor config
default_processor_config = DataPreProcessorConfig()
data_processor = get_datapreprocessor(
processor_config=default_processor_config, tokenizer=tokenizer
processor_config=default_processor_config,
tokenizer=tokenizer,
additional_data_handlers=additional_data_handlers,
)

assert isinstance(
data_args.training_data_path, str
), "Training data path has to be set and str"
@@ -259,7 +267,10 @@ def _process_raw_data_args(
# If no data config file is specified, process the remaining data arguments
# to determine the use case based on their presence, as explained in _process_raw_data_args.
def process_dataargs(
data_args: DataArguments, tokenizer: AutoTokenizer, train_args: TrainingArguments
data_args: DataArguments,
tokenizer: AutoTokenizer,
train_args: TrainingArguments,
additional_data_handlers: Dict[str, Callable] = None,
):
"""
Args:
@@ -268,11 +279,17 @@ def process_dataargs(
train_args: TrainingArguments
Training arguments passed to the library
Used for packing and max_seq_length
additional_data_handlers: A Dict[str, Callable] of extra data handlers
to be registered with the data preprocessor
Returns:
Tuple(Dataset, Dataset, str, DataCollator, int, Dict)
tuple containing train_dataset, eval_dataset, dataset_text_field,
data_collator, max_seq_length and dataset_kwargs

tuple containing
train_dataset (Dataset/IterableDataset),
eval_dataset (Dataset/IterableDataset),
dataset_text_field (str),
data_collator (DataCollator),
max_seq_length (int) and
dataset_kwargs (Dict)
"""

max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length)
@@ -290,26 +307,32 @@

if data_args.data_config_path:
train_dataset, eval_dataset, dataset_text_field = _process_dataconfig_file(
data_args, tokenizer
data_args, tokenizer, additional_data_handlers
)
else:
train_dataset, eval_dataset, dataset_text_field = _process_raw_data_args(
data_args, tokenizer, train_args.packing, max_seq_length
data_args,
tokenizer,
train_args.packing,
max_seq_length,
additional_data_handlers,
)

# Note: This check should not be removed.
# It's important to recompute this after data handling to
# check whether the dataset has already been tokenized.
is_tokenized_dataset = is_pretokenized_dataset(train_dataset or eval_dataset)

data_collator = get_data_collator(
train_args.packing,
data_args.response_template,
tokenizer,
# Note: This check should not be removed.
# Its important to recompute this post handling to
# check if we already tokenized the dataset or not.
is_pretokenized_dataset(train_dataset),
is_tokenized_dataset,
max_seq_length,
)

dataset_kwargs = {}
if is_pretokenized_dataset(train_dataset or eval_dataset):
if is_tokenized_dataset:
dataset_kwargs["skip_prepare_dataset"] = True

return (
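Downstream, process_dataargs now threads the extra handlers through whichever path is taken (data config file or raw data arguments). A sketch of the call site, assuming pre-built DataArguments and TrainingArguments objects, a loaded tokenizer, and the hypothetical my_handler from earlier:

```python
from tuning.data.setup_dataprocessor import process_dataargs

(
    train_dataset,
    eval_dataset,
    dataset_text_field,
    data_collator,
    max_seq_length,
    dataset_kwargs,
) = process_dataargs(
    data_args,    # DataArguments, assumed configured elsewhere
    tokenizer,    # AutoTokenizer, assumed loaded elsewhere
    train_args,   # TrainingArguments, assumed configured elsewhere
    additional_data_handlers={"my_handler": my_handler},
)
```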
8 changes: 5 additions & 3 deletions tuning/sft_trainer.py
@@ -13,7 +13,7 @@
# limitations under the License.

# Standard
from typing import Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union
import dataclasses
import json
import logging
@@ -85,6 +85,7 @@ def train(
attention_and_distributed_packing_config: Optional[
AttentionAndDistributedPackingConfig
] = None,
additional_data_handlers: Optional[Dict[str, Callable]] = None,
) -> tuple[SFTTrainer, dict]:
"""Call the SFTTrainer

@@ -113,7 +114,8 @@ def train(
Should be used in combination with quantized_lora_config. Also currently
fused_lora and fast_kernels must used together (may change in future). \
attention_and_distributed_packing_config: Used for padding-free attention and multipack.

additional_data_handlers: Dict[str, Callable] of any extra data handlers \
to be registered with the data preprocessor
Returns:
Tuple: Instance of SFTTrainer , some metadata in a dict
Metadata contains information on number of added tokens while tuning.
@@ -297,7 +299,7 @@ def train(
data_collator,
max_seq_length,
dataset_kwargs,
) = process_dataargs(data_args, tokenizer, train_args)
) = process_dataargs(data_args, tokenizer, train_args, additional_data_handlers)
additional_metrics["data_preprocessing_time"] = (
time.time() - data_preprocessing_time
)
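Putting it together, the caller-facing change is the new keyword argument on train(). A minimal sketch, assuming the model, data, and training argument objects are built as elsewhere in this repo, with further optional configs (e.g. a PEFT config) omitted, and reusing the built-in apply_dataset_formatting handler under a custom name:

```python
from tuning import sft_trainer
from tuning.data.data_handlers import apply_dataset_formatting


def my_handler(element, tokenizer, **kwargs):
    # Illustrative: delegate to a built-in handler with a custom output field.
    return apply_dataset_formatting(element, tokenizer, "custom_formatted_field")


trainer, metadata = sft_trainer.train(
    model_args,   # configs.ModelArguments, assumed configured elsewhere
    data_args,    # data arguments, e.g. pointing at a data config file
    train_args,   # training arguments with output_dir etc.
    additional_data_handlers={"my_handler": my_handler},
)
```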