From 4168c8717a3ebb76e8d21e893ba823111db06ac0 Mon Sep 17 00:00:00 2001
From: Dushyant Behl
Date: Tue, 10 Dec 2024 21:32:27 +0530
Subject: [PATCH] Perform dataset mixing via sampling probabilities in data config (#408)

Code to perform dataset sampling via sampling probabilities in data

Signed-off-by: Dushyant Behl
---
 .../predefined_data_configs/__init__.py      |   9 +-
 .../apply_custom_template.yaml               |   2 +-
 .../multiple_datasets_with_sampling.yaml     |  41 ++++++
 .../tokenize_and_apply_input_masking.yaml    |   4 +-
 tests/data/test_data_preprocessing_utils.py  | 129 ++++++++++++++++--
 tuning/data/data_config.py                   |  39 ++++--
 tuning/data/data_processors.py               |  81 ++++++++---
 7 files changed, 255 insertions(+), 50 deletions(-)
 create mode 100644 tests/artifacts/predefined_data_configs/multiple_datasets_with_sampling.yaml

diff --git a/tests/artifacts/predefined_data_configs/__init__.py b/tests/artifacts/predefined_data_configs/__init__.py
index f9b766be6..c199406c6 100644
--- a/tests/artifacts/predefined_data_configs/__init__.py
+++ b/tests/artifacts/predefined_data_configs/__init__.py
@@ -19,12 +19,15 @@
 
 ### Constants used for data
 PREDEFINED_DATA_CONFIGS = os.path.join(os.path.dirname(__file__))
-APPLY_CUSTOM_TEMPLATE_YAML = os.path.join(
+DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML = os.path.join(
     PREDEFINED_DATA_CONFIGS, "apply_custom_template.yaml"
 )
-PRETOKENIZE_JSON_DATA_YAML = os.path.join(
+DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML = os.path.join(
     PREDEFINED_DATA_CONFIGS, "pretokenized_json_data.yaml"
 )
-TOKENIZE_AND_APPLY_INPUT_MASKING_YAML = os.path.join(
+DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML = os.path.join(
     PREDEFINED_DATA_CONFIGS, "tokenize_and_apply_input_masking.yaml"
 )
+DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML = os.path.join(
+    PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_sampling.yaml"
+)
diff --git a/tests/artifacts/predefined_data_configs/apply_custom_template.yaml b/tests/artifacts/predefined_data_configs/apply_custom_template.yaml
index 4aab0d76a..c41797624 100644
--- a/tests/artifacts/predefined_data_configs/apply_custom_template.yaml
+++ b/tests/artifacts/predefined_data_configs/apply_custom_template.yaml
@@ -11,4 +11,4 @@ datasets:
           batched: false
           fn_kwargs:
             dataset_text_field: "dataset_text_field"
-            dataset_template: "dataset_template"
\ No newline at end of file
+            template: "dataset_template"
\ No newline at end of file
diff --git a/tests/artifacts/predefined_data_configs/multiple_datasets_with_sampling.yaml b/tests/artifacts/predefined_data_configs/multiple_datasets_with_sampling.yaml
new file mode 100644
index 000000000..3bfbb701a
--- /dev/null
+++ b/tests/artifacts/predefined_data_configs/multiple_datasets_with_sampling.yaml
@@ -0,0 +1,41 @@
+dataprocessor:
+  type: default
+  sampling_stopping_strategy: first_exhausted
+  sampling_seed: 66
+datasets:
+  - name: dataset_1
+    sampling: 0.3
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize_and_apply_input_masking
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            input_field_name: input
+            output_field_name: output
+  - name: dataset_2
+    sampling: 0.4
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize_and_apply_input_masking
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            input_field_name: input
+            output_field_name: output
+  - name: dataset_3
+    sampling: 0.3
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize_and_apply_input_masking
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            input_field_name: input
+            output_field_name: output
diff --git a/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml b/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml
index d8fc16eec..b66b01d55 100644
--- a/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml
+++ b/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml
@@ -10,5 +10,5 @@ datasets:
           remove_columns: all
           batched: false
           fn_kwargs:
-            input_field: "INPUT"
-            output_field: "OUTPUT"
\ No newline at end of file
+            input_field_name: "INPUT"
+            output_field_name: "OUTPUT"
\ No newline at end of file
diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py
index 6e7dacde8..c34204f4f 100644
--- a/tests/data/test_data_preprocessing_utils.py
+++ b/tests/data/test_data_preprocessing_utils.py
@@ -26,9 +26,10 @@
 
 # First Party
 from tests.artifacts.predefined_data_configs import (
-    APPLY_CUSTOM_TEMPLATE_YAML,
-    PRETOKENIZE_JSON_DATA_YAML,
-    TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+    DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
+    DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+    DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
+    DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
 )
 from tests.artifacts.testdata import (
     MODEL_NAME,
@@ -428,22 +429,22 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
 @pytest.mark.parametrize(
     "data_config_path, data_path",
     [
-        (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
-        (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
-        (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
-        (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
-        (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
-        (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
-        (
-            TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+        (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
+        (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
+        (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
+        (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
+        (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
+        (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
+        (
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
         ),
         (
-            TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
         ),
         (
-            TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
         ),
     ],
@@ -709,3 +710,105 @@ def test_process_dataset_configs(datafile, column_names, datasetconfigname):
     with open(datafile, "r") as file:
         data = json.load(file)
     assert len(train_dataset) == len(data)
+
+
+@pytest.mark.parametrize(
+    "datafiles, sampling, datasetconfigname",
+    [
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            ],
+            [0.3, None, 0.3],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            ],
+            [0.3, 0.5, 0.3],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+    ],
+)
+def test_process_dataset_configs_with_sampling_error(
+    datafiles, sampling, datasetconfigname
+):
+
+    data_args = configs.DataArguments()
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    TRAIN_ARGS = configs.TrainingArguments(
+        packing=False,
+        max_seq_length=1024,
+        output_dir="tmp",  # Not needed but positional
+    )
+
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".yaml"
+    ) as temp_yaml_file:
+        with open(datasetconfigname, "r") as f:
+            data = yaml.safe_load(f)
+            datasets = data["datasets"]
+            for i in range(len(datasets)):
+                d = datasets[i]
+                d["data_paths"][0] = datafiles[i]
+                d["sampling"] = sampling[i]
+            yaml.dump(data, temp_yaml_file)
+        data_args.data_config_path = temp_yaml_file.name
+
+    with pytest.raises(ValueError):
+        (_, _, _, _, _, _) = process_dataargs(
+            data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
+        )
+
+
+@pytest.mark.parametrize(
+    "datafiles, datasetconfigname",
+    [
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            ],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+    ],
+)
+def test_process_dataset_configs_with_sampling(datafiles, datasetconfigname):
+
+    data_args = configs.DataArguments()
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    TRAIN_ARGS = configs.TrainingArguments(
+        packing=False,
+        max_seq_length=1024,
+        output_dir="tmp",  # Not needed but positional
+    )
+
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".yaml"
+    ) as temp_yaml_file:
+        with open(datasetconfigname, "r") as f:
+            data = yaml.safe_load(f)
+            datasets = data["datasets"]
+            for i in range(len(datasets)):
+                d = datasets[i]
+                d["data_paths"][0] = datafiles[i]
+            yaml.dump(data, temp_yaml_file)
+        data_args.data_config_path = temp_yaml_file.name
+
+    (train_set, eval_set, _, _, _, _) = process_dataargs(
+        data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
+    )
+
+    assert isinstance(train_set, Dataset)
+    if eval_set:
+        assert isinstance(eval_set, Dataset)
+
+    assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
+    if eval_set:
+        assert set(["input_ids", "labels"]).issubset(set(eval_set.column_names))
diff --git a/tuning/data/data_config.py b/tuning/data/data_config.py
index 7e3ccd83b..4da83d720 100644
--- a/tuning/data/data_config.py
+++ b/tuning/data/data_config.py
@@ -32,13 +32,16 @@ class DataHandlerConfig:
 class DataSetConfig:
     name: str
     data_paths: List[str]
-    sampling: Optional[Dict] = None
+    sampling: Optional[float] = None
     data_handlers: Optional[List[DataHandlerConfig]] = None
 
 
 @dataclass
 class DataPreProcessorConfig:
     type: Optional[str] = "default"
+    sampling_stopping_strategy: Optional[str] = "all_exhausted"
+    # Default seed is not None to ensure reproducibility
+    sampling_seed: Optional[float] = 42
 
 
 @dataclass
@@ -84,17 +87,12 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig:
             )
             p = _p
         c.data_paths.append(p)
-    if "sampling" in kwargs:
-        sampling_kwargs = kwargs["sampling"]
-        assert isinstance(
-            dict, sampling_kwargs
-        ), "sampling arguments should be of the type dict"
-        if "ratio" in sampling_kwargs:
-            ratio = sampling_kwargs["ratio"]
-            assert isinstance(ratio, float) and (
-                0 <= ratio <= 1.0
-            ), f"sampling ratio: {ratio} should be float and in range [0.0,1.0]"
-        c.sampling = sampling_kwargs
+    if "sampling" in kwargs and kwargs["sampling"] is not None:
+        ratio = kwargs["sampling"]
+        assert isinstance(ratio, float) and (
+            0 <= ratio <= 1.0
+        ), f"sampling ratio: {ratio} should be float and in range [0.0,1.0]"
+        c.sampling = ratio
     if "data_handlers" in kwargs:
         c.data_handlers = []
         for handler in kwargs["data_handlers"]:
@@ -106,6 +104,23 @@ def _validate_dataprocessor_config(dataprocessor_config) -> DataPreProcessorConf
     kwargs = dataprocessor_config
     c = DataPreProcessorConfig()
     assert isinstance(kwargs, dict), "dataprocessor in data_config needs to be a dict"
+    if "type" in kwargs:
+        assert isinstance(kwargs["type"], str), "dataprocessor type must be a string"
+        c.type = kwargs["type"]
+    if "sampling_stopping_strategy" in kwargs:
+        strategy = kwargs["sampling_stopping_strategy"]
+        assert isinstance(
+            strategy, str
+        ), "dataset sampling stopping strategy must be a string"
+        assert strategy in [
+            "first_exhausted",
+            "all_exhausted",
+        ], "allowed sampling stopping strategies are all_exhausted(default) or first_exhausted"
+        c.sampling_stopping_strategy = strategy
+    if "sampling_seed" in kwargs:
+        seed = kwargs["sampling_seed"]
+        assert isinstance(seed, int), "sampling seed should be int"
+        c.sampling_seed = seed
     return c
 
 
diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py
index c3f38e3f1..e92b7b684 100644
--- a/tuning/data/data_processors.py
+++ b/tuning/data/data_processors.py
@@ -93,12 +93,35 @@ def load_dataset(
     def _process_dataset_configs(
         self, dataset_configs: List[DataSetConfig], **extra_kwargs
     ) -> Union[Dataset, IterableDataset]:
-        train_dataset = None
-        final_datasets = None
+        splitName = "train"  # default
+        all_datasetdicts = []
+        sampling_probabilities = []
+
+        # quick check to see if we are sampling and if we need to throw error.
+        sampling_probabilities = [d.sampling for d in dataset_configs if d.sampling]
+
+        if len(sampling_probabilities) > 0:
+            if len(sampling_probabilities) != len(dataset_configs):
+                raise ValueError(
+                    "Sampling probabilities should be provided for all datasets"
+                )
+            if sum(p for p in sampling_probabilities) != 1:
+                raise ValueError("Sampling probabilities don't sum to 1")
+            sample_datasets = True
+            logging.info(
+                "Sampling ratios are specified; given datasets will be interleaved."
+            )
+        else:
+            logging.info(
+                "Sampling is not specified; if multiple datasets are provided,"
+                " the given datasets will be concatenated."
+            )
+            sample_datasets = False
 
         logging.info("Starting DataPreProcessor...")
-        # Iterate over the multiple datasets provided to us
+        # Now Iterate over the multiple datasets provided to us to process
         for d in dataset_configs:
             logging.info("Loading %s", d.name)
 
@@ -115,9 +138,6 @@ def _process_dataset_configs(
             else:
                 raw_datasets = raw_dataset
 
-            if d.sampling:
-                logging.warning("Sampling multiple datasets is not supported yet")
-
             if d.data_handlers:
                 # Execute the datahandlers
                 for data_handler in d.data_handlers:
                     handler_name: str = data_handler.name
@@ -153,19 +173,42 @@
 
                     raw_datasets = raw_datasets.map(handler, **kwargs)
 
-            if final_datasets is None:
-                final_datasets = raw_datasets
-            else:
-                for k in raw_datasets.keys():
-                    if k in final_datasets:
-                        final_datasets[k] = datasets.concatenate_datasets(
-                            [final_datasets[k], raw_datasets[k]]
-                        )
-                    else:
-                        final_datasets[k] = raw_datasets[k]
-
-        if "train" in final_datasets:
-            train_dataset = final_datasets["train"]
+            # Append the processed datasets to the final dict
+            all_datasetdicts.append(raw_datasets)
+
+        # This is a dict of { split: list[datasets] }
+        final_datasets = {}
+        for d in all_datasetdicts:
+            for k, v in d.items():
+                if k not in final_datasets:
+                    final_datasets[k] = [v]
+                else:
+                    final_datasets[k].append(v)
+
+        if sample_datasets:
+            strategy = self.processor_config.sampling_stopping_strategy
+            seed = self.processor_config.sampling_seed
+            logging.info(
+                "Interleaving datasets: strategy[%s] seed[%d] probabilities[%s]",
+                strategy,
+                seed,
+                str(sampling_probabilities),
+            )
+            for k, v in final_datasets.items():
+                interleaved = datasets.interleave_datasets(
+                    datasets=v,
+                    probabilities=sampling_probabilities,
+                    stopping_strategy=strategy,
+                    seed=seed,
+                )
+                final_datasets[k] = interleaved
+        else:
+            for k, v in final_datasets.items():
+                final_datasets[k] = (
+                    v[0] if len(v) == 1 else datasets.concatenate_datasets(v)
+                )
+
+        train_dataset = final_datasets.get("train", None)
 
         return train_dataset
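
Note (illustrative, not part of the patch): the sampling path added to
_process_dataset_configs delegates to datasets.interleave_datasets from the
Hugging Face datasets library. The sketch below uses made-up in-memory
datasets (ds_a/ds_b/ds_c) rather than the repo's test artifacts, and shows how
the probabilities, seed, and stopping_strategy arguments used above interact.

    # Illustrative sketch only, not code from this PR.
    from datasets import Dataset, interleave_datasets

    ds_a = Dataset.from_dict({"text": [f"a{i}" for i in range(10)]})
    ds_b = Dataset.from_dict({"text": [f"b{i}" for i in range(10)]})
    ds_c = Dataset.from_dict({"text": [f"c{i}" for i in range(10)]})

    # Ratios mirror multiple_datasets_with_sampling.yaml; the check added in
    # data_processors.py requires them to sum to 1.
    probabilities = [0.3, 0.4, 0.3]

    mixed = interleave_datasets(
        [ds_a, ds_b, ds_c],
        probabilities=probabilities,
        seed=42,  # fixed seed, so the mixing order is reproducible
        stopping_strategy="first_exhausted",  # stop once any source is used up
    )
    print(len(mixed), mixed["text"][:8])

With "first_exhausted" (the strategy in the new YAML) the larger sources are
undersampled; with "all_exhausted" (the config default) the smaller sources are
oversampled until every dataset has been fully seen. Because examples are drawn
probabilistically rather than sliced, the realized proportions only approximate
the configured ratios for small datasets.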