Explicit dataset tokenizer `text` kwarg (#1031)
## Purpose ##
* Allow VLM processors to be used to tokenize datasets with prompt keys

## Postrequisites ##
* #1030

## Changes ##
* Use the `text` argument name for tokenizing the prompt column

## Testing ##
* w.r.t. tokenizers, using the `text` kwarg follows the precedent set by [PretrainedTokenizerBase](https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L2790)
* w.r.t. processors, most processors use the `text` kwarg

Below are all the models I know to be compatible with this change; I'm assuming that most other processors follow the same standard.

1. [llama](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py#L233)
2. [pixtral](https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/processing_pixtral.py#L160)
3. [phi3_vision](https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L321)
4. [mllama](https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/processing_mllama.py#L232)
5. [qwen2_vl](https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/processing_qwen2_vl.py#L71)

Example of using a VLM processor to tokenize a dataset with a prompt key:

```python3
from transformers import AutoProcessor

from llmcompressor.transformers import DataTrainingArguments, TextGenerationDataset

models_to_test = [
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Qwen/Qwen2-VL-2B-Instruct",  # fails without changes
    "mgoin/pixtral-12b",  # fails without changes
]

for model_id in models_to_test:
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

    data_args = DataTrainingArguments(
        dataset="ultrachat-200k",
        splits={"calibration": "test_sft[:1]"},
    )

    dataset = TextGenerationDataset.load_from_registry(
        data_args.dataset,
        data_args=data_args,
        split=data_args.splits["calibration"],
        processor=processor,
    )(add_labels=False)
```
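For reference, a minimal sketch of the keyword call pattern this change relies on: both a plain tokenizer and a multimodal processor accept the prompt through the `text` kwarg, so the dataset tokenization path can call either one uniformly. The model IDs, prompt, and `max_length` below are illustrative assumptions, not part of this PR.

```python3
from transformers import AutoProcessor, AutoTokenizer

# A text-only tokenizer and a multimodal processor, loaded the same way.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

prompt = "What is the capital of France?"

# Passing the prompt via the `text` keyword works for both objects;
# a positional call is only guaranteed for PretrainedTokenizerBase.
tokenized = tokenizer(text=prompt, truncation=True, max_length=512)
processed = processor(text=prompt, truncation=True, max_length=512)

print(tokenized["input_ids"])
print(processed["input_ids"])
```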
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>