Explicit dataset tokenizer text kwarg (#1031)
## Purpose ##
* Allow VLM processors to be used to tokenize datasets with prompt keys

## Postrequisites ##
* #1030

## Changes ##
* Use `text` argument name for tokenizing the prompt column

## Testing ##
* w.r.t. tokenizers, using the `text` kwarg follows the precedent set by
[PreTrainedTokenizerBase](https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L2790)
* w.r.t. processors, most processors accept a `text` kwarg

Below are all the models I know to be compatible with this change; I'm
assuming that most other processors follow the same standard:
1. [llama](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py#L233)
2. [pixtral](https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/processing_pixtral.py#L160)
3. [phi3_vision](https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L321)
4. [mllama](https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/processing_mllama.py#L232)
5. [qwen2_vl](https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/processing_qwen2_vl.py#L71)
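The mismatch can be illustrated without downloading any models. Below is a minimal sketch with hypothetical stand-in functions (not the real `transformers` classes): tokenizers take `text` as their first positional parameter, while some VLM processors appear to put `images` first (consistent with qwen2_vl and pixtral failing without this change), so a positional prompt can bind to the wrong parameter.

```python
# Hypothetical stand-ins illustrating the signature mismatch;
# these are NOT the real transformers call signatures.

def tokenizer_call(text=None, max_length=None, truncation=False):
    # Tokenizers take `text` as their first positional parameter
    return {"input_ids": [ord(c) for c in text]}

def vlm_processor_call(images=None, text=None, max_length=None, truncation=False):
    # Some VLM processors take `images` first, so a positional prompt
    # would bind to `images` instead of `text`
    if isinstance(images, str):
        raise TypeError("prompt was bound to `images`, not `text`")
    return {"input_ids": [ord(c) for c in text]}

# The keyword form works for both call signatures
assert tokenizer_call(text="hi") == vlm_processor_call(text="hi")
```

Passing the prompt positionally would raise in the processor case, which is why the change pins the argument down with `text=prompt`.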

Example of using a VLM processor to tokenize a dataset with a prompt key:
```python
from transformers import AutoProcessor

from llmcompressor.transformers import DataTrainingArguments, TextGenerationDataset

models_to_test = [
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Qwen/Qwen2-VL-2B-Instruct",  # fails without changes
    "mgoin/pixtral-12b",  # fails without changes
]

for model_id in models_to_test:
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

    data_args = DataTrainingArguments(
        dataset="ultrachat-200k",
        splits={"calibration": "test_sft[:1]"},
    )

    dataset = TextGenerationDataset.load_from_registry(
        data_args.dataset,
        data_args=data_args,
        split=data_args.splits["calibration"],
        processor=processor,
    )(add_labels=False)
```

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
kylesayrs authored Jan 10, 2025
1 parent c119add commit 668e1be
1 changed file, 1 addition and 1 deletion: `src/llmcompressor/transformers/finetune/data/base.py`
```diff
@@ -253,7 +253,7 @@ def tokenize(self, data: LazyRow) -> Dict[str, Any]:
         # store unpadded prompt so we can mask out correct number of elements in labels
         if prompt is not None:
             data[self.PROMPT_KEY] = self.processor(
-                prompt,
+                text=prompt,
                 max_length=self.max_seq_length,
                 truncation=True,
             )["input_ids"]
```
