diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 3db0c2ec0..dc3e3733e 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -295,9 +295,9 @@ def train( ( formatted_train_dataset, formatted_validation_dataset, - dataset_text_field, + data_args.dataset_text_field, data_collator, - max_seq_length, + train_args.max_seq_length, dataset_kwargs, ) = process_dataargs(data_args, tokenizer, train_args, additional_data_handlers) additional_metrics["data_preprocessing_time"] = ( @@ -327,7 +327,7 @@ def train( } additional_args = { - "dataset_text_field": dataset_text_field, + "dataset_text_field": data_args.dataset_text_field, "dataset_kwargs": dataset_kwargs, } training_args = SFTConfig(**transformer_kwargs, **additional_args)