From 668e1be9f443ec41c5b7d21ab85cfd1c86b09347 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Fri, 10 Jan 2025 15:27:40 -0500
Subject: [PATCH] Explicit dataset tokenizer `text` kwarg (#1031)

## Purpose ##
* Allow VLM processors to be used to tokenize datasets with prompt keys

## Postrequisites ##
* https://github.com/vllm-project/llm-compressor/pull/1030

## Changes ##
* Use the `text` argument name when tokenizing the prompt column

## Testing ##
* w.r.t. tokenizers, using the `text` kwarg follows the precedent set by [PreTrainedTokenizerBase](https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L2790)
* w.r.t. processors, most processors accept a `text` kwarg

Below are all the models I know to be compatible with this change; I assume that most other processors follow the same standard:
1. [llama](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py#L233)
2. [pixtral](https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/processing_pixtral.py#L160)
3. [phi3_vision](https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L321)
4. [mllama](https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/processing_mllama.py#L232)
5. [qwen2_vl](https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/processing_qwen2_vl.py#L71)

Example of using a VLM processor to tokenize a dataset with a prompt key:
```python3
from transformers import AutoProcessor

from llmcompressor.transformers import DataTrainingArguments, TextGenerationDataset

models_to_test = [
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Qwen/Qwen2-VL-2B-Instruct",  # fails without changes
    "mgoin/pixtral-12b",  # fails without changes
]

for model_id in models_to_test:
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

    data_args = DataTrainingArguments(
        dataset="ultrachat-200k", splits={"calibration": "test_sft[:1]"}
    )
    dataset = TextGenerationDataset.load_from_registry(
        data_args.dataset,
        data_args=data_args,
        split=data_args.splits["calibration"],
        processor=processor,
    )(add_labels=False)
```

Signed-off-by: Kyle Sayers
---
 src/llmcompressor/transformers/finetune/data/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py
index ddc348b89..81a3fc95f 100644
--- a/src/llmcompressor/transformers/finetune/data/base.py
+++ b/src/llmcompressor/transformers/finetune/data/base.py
@@ -253,7 +253,7 @@ def tokenize(self, data: LazyRow) -> Dict[str, Any]:
         # store unpadded prompt so we can mask out correct number of elements in labels
         if prompt is not None:
             data[self.PROMPT_KEY] = self.processor(
-                prompt,
+                text=prompt,
                 max_length=self.max_seq_length,
                 truncation=True,
             )["input_ids"]
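
For context on why the positional call fails, below is a minimal, self-contained sketch. `ToyTokenizer` and `ToyVLMProcessor` are hypothetical stand-ins, not real transformers classes; they only mirror the argument orders of the signatures linked in the Testing section, where multimodal processors take `images` before `text`.

```python3
# Minimal sketch of the failure mode this patch fixes. ToyTokenizer and
# ToyVLMProcessor are hypothetical stand-ins (not transformers classes);
# they only mirror the parameter orders linked in the Testing section.


class ToyTokenizer:
    # text-only tokenizers take the prompt as their first parameter,
    # so a positional prompt happens to work
    def __call__(self, text=None, max_length=None, truncation=False):
        ids = [ord(c) for c in text]
        return {"input_ids": ids[:max_length] if truncation else ids}


class ToyVLMProcessor:
    # many VLM processors put `images` first, so a positional prompt
    # is silently bound to `images` instead of `text`
    def __call__(self, images=None, text=None, max_length=None, truncation=False):
        if text is None:
            raise ValueError("`text` is None; the prompt was bound to `images`")
        ids = [ord(c) for c in text]
        return {"input_ids": ids[:max_length] if truncation else ids}


prompt = "What is quantization?"

# the keyword form used by the patched tokenize() is unambiguous
# and works for both signatures
for processor in (ToyTokenizer(), ToyVLMProcessor()):
    print(processor(text=prompt, max_length=16, truncation=True)["input_ids"])

# the old positional form raises for the VLM processor
ToyVLMProcessor()(prompt, max_length=16, truncation=True)  # ValueError
```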