Commit 8c054b7

[Frontend] Clean up type annotations for mistral tokenizer (#8314)
1 parent 6234385 commit 8c054b7

File tree

6 files changed: +114 -59 lines changed

tests/async_engine/test_chat_template.py (+3, -2)

@@ -1,6 +1,7 @@
 import pytest

-from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
+from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
+                                         load_chat_template)
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.transformers_utils.tokenizer import get_tokenizer

@@ -87,7 +88,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
         add_generation_prompt=add_generation_prompt)

     # Call the function and get the result
-    result = apply_chat_template(
+    result = apply_hf_chat_template(
         tokenizer,
         conversation=mock_request.messages,
         chat_template=mock_request.chat_template or template_content,

vllm/entrypoints/chat_utils.py (+41, -20)

@@ -23,6 +23,7 @@
 # yapf: enable
 # pydantic needs the TypedDict from typing_extensions
 from pydantic import ConfigDict
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from typing_extensions import Required, TypeAlias, TypedDict

 from vllm.config import ModelConfig
@@ -31,7 +32,7 @@
 from vllm.multimodal.utils import (async_get_and_parse_audio,
                                    async_get_and_parse_image,
                                    get_and_parse_audio, get_and_parse_image)
-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer

 logger = init_logger(__name__)

@@ -379,6 +380,9 @@ def _parse_chat_message_content_parts(
             audio_url = _AudioParser(part)["audio_url"]

             mm_parser.parse_audio(audio_url["url"])
+        elif part_type == "refusal":
+            text = _RefusalParser(part)["refusal"]
+            texts.append(text)
         else:
             raise NotImplementedError(f"Unknown part type: {part_type}")

@@ -433,6 +437,21 @@ def _parse_chat_message_content(
     return result


+def _postprocess_messages(messages: List[ConversationMessage]) -> None:
+    # per the Transformers docs & maintainers, tool call arguments in
+    # assistant-role messages with tool_calls need to be dicts not JSON str -
+    # this is how tool-use chat templates will expect them moving forwards
+    # so, for messages that have tool_calls, parse the string (which we get
+    # from openAI format) to dict
+    for message in messages:
+        if (message["role"] == "assistant" and "tool_calls" in message
+                and isinstance(message["tool_calls"], list)):
+
+            for item in message["tool_calls"]:
+                item["function"]["arguments"] = json.loads(
+                    item["function"]["arguments"])
+
+
 def parse_chat_messages(
     messages: List[ChatCompletionMessageParam],
     model_config: ModelConfig,
@@ -446,6 +465,8 @@ def parse_chat_messages(

         conversation.extend(sub_messages)

+    _postprocess_messages(conversation)
+
     return conversation, mm_tracker.all_mm_data()


@@ -462,41 +483,41 @@ def parse_chat_messages_futures(

         conversation.extend(sub_messages)

+    _postprocess_messages(conversation)
+
     return conversation, mm_tracker.all_mm_data()


-def apply_chat_template(
-    tokenizer: AnyTokenizer,
+def apply_hf_chat_template(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     conversation: List[ConversationMessage],
     chat_template: Optional[str],
     *,
     tokenize: bool = False,  # Different from HF's default
     **kwargs: Any,
-) -> Union[str, List[int]]:
+) -> str:
     if chat_template is None and tokenizer.chat_template is None:
         raise ValueError(
             "As of transformers v4.44, default chat template is no longer "
             "allowed, so you must provide a chat template if the tokenizer "
             "does not define one.")

-    # per the Transformers docs & maintainers, tool call arguments in
-    # assistant-role messages with tool_calls need to be dicts not JSON str -
-    # this is how tool-use chat templates will expect them moving forwards
-    # so, for messages that have tool_calls, parse the string (which we get
-    # from openAI format) to dict
-    for message in conversation:
-        if (message["role"] == "assistant" and "tool_calls" in message
-                and isinstance(message["tool_calls"], list)):
+    return tokenizer.apply_chat_template(
+        conversation=conversation,  # type: ignore[arg-type]
+        chat_template=chat_template,
+        tokenize=tokenize,
+        **kwargs,
+    )

-            for i in range(len(message["tool_calls"])):
-                args: str = message["tool_calls"][i]["function"]["arguments"]
-                parsed_args: Dict = json.loads(args)
-                message["tool_calls"][i]["function"]["arguments"] = parsed_args

-    prompt = tokenizer.apply_chat_template(
-        conversation=conversation,
+def apply_mistral_chat_template(
+    tokenizer: MistralTokenizer,
+    messages: List[ChatCompletionMessageParam],
+    chat_template: Optional[str],
+    **kwargs: Any,
+) -> List[int]:
+    return tokenizer.apply_chat_template(
+        messages=messages,
         chat_template=chat_template,
-        tokenize=tokenize,
         **kwargs,
     )
-    return prompt
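
Note (not part of the diff): a minimal sketch of what the new _postprocess_messages helper does to an assistant message carrying tool calls; the tool name and arguments are made up for illustration.

import json

# An assistant message in OpenAI format: "arguments" arrives as a JSON string.
message = {
    "role": "assistant",
    "tool_calls": [{
        "type": "function",
        "function": {
            "name": "get_weather",          # hypothetical tool name
            "arguments": '{"city": "Paris"}',
        },
    }],
}

# _postprocess_messages parses each arguments string into a dict in place,
# which is the shape tool-use chat templates expect.
for item in message["tool_calls"]:
    item["function"]["arguments"] = json.loads(item["function"]["arguments"])

assert message["tool_calls"][0]["function"]["arguments"] == {"city": "Paris"}

With this commit the parsing happens once, in parse_chat_messages / parse_chat_messages_futures, instead of inside every template-application call.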

vllm/entrypoints/llm.py (+18, -8)

@@ -6,7 +6,8 @@
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
-                                         apply_chat_template,
+                                         apply_hf_chat_template,
+                                         apply_mistral_chat_template,
                                          parse_chat_messages)
 from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
@@ -19,7 +20,7 @@
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
-from vllm.transformers_utils.tokenizer import (AnyTokenizer,
+from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.usage.usage_lib import UsageContext
@@ -393,12 +394,21 @@ def chat(
         conversation, mm_data = parse_chat_messages(messages, model_config,
                                                     tokenizer)

-        prompt = apply_chat_template(
-            tokenizer,
-            conversation,
-            chat_template=chat_template,
-            add_generation_prompt=add_generation_prompt,
-        )
+        prompt: Union[str, List[int]]
+        if isinstance(tokenizer, MistralTokenizer):
+            prompt = apply_mistral_chat_template(
+                tokenizer,
+                messages=messages,
+                chat_template=chat_template,
+                add_generation_prompt=add_generation_prompt,
+            )
+        else:
+            prompt = apply_hf_chat_template(
+                tokenizer,
+                conversation=conversation,
+                chat_template=chat_template,
+                add_generation_prompt=add_generation_prompt,
+            )

         inputs: PromptInputs
         if is_list_of(prompt, int):
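
Usage sketch (not part of the diff; the model name and tokenizer_mode flag are assumptions about how a Mistral-format tokenizer is loaded): when the tokenizer is a MistralTokenizer, LLM.chat now routes through apply_mistral_chat_template, so the prompt it builds is a list of token ids rather than a rendered template string.

from vllm import LLM, SamplingParams

# Hypothetical setup: a Mistral model loaded with its native tokenizer format.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral")

# chat() dispatches on the tokenizer type: MistralTokenizer -> token ids via
# apply_mistral_chat_template, HF tokenizers -> a prompt string via
# apply_hf_chat_template.
outputs = llm.chat(
    messages=[{"role": "user", "content": "Tell me a joke."}],
    sampling_params=SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)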

vllm/entrypoints/openai/serving_chat.py (+30, -18)

@@ -11,7 +11,8 @@
 from vllm.config import ModelConfig
 from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
-                                         apply_chat_template,
+                                         apply_hf_chat_template,
+                                         apply_mistral_chat_template,
                                          load_chat_template,
                                          parse_chat_messages_futures)
 from vllm.entrypoints.logger import RequestLogger
@@ -35,7 +36,7 @@
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.utils import iterate_with_cancellation, random_uuid

 logger = init_logger(__name__)
@@ -121,15 +122,27 @@ async def create_chat_completion(
                 tool.model_dump() for tool in request.tools
             ]

-            prompt = apply_chat_template(
-                tokenizer,
-                conversation=conversation,
-                chat_template=request.chat_template or self.chat_template,
-                add_generation_prompt=request.add_generation_prompt,
-                tools=tool_dicts,
-                documents=request.documents,
-                **(request.chat_template_kwargs or {}),
-            )
+            prompt: Union[str, List[int]]
+            if isinstance(tokenizer, MistralTokenizer):
+                prompt = apply_mistral_chat_template(
+                    tokenizer,
+                    messages=request.messages,
+                    chat_template=request.chat_template or self.chat_template,
+                    add_generation_prompt=request.add_generation_prompt,
+                    tools=tool_dicts,
+                    documents=request.documents,
+                    **(request.chat_template_kwargs or {}),
+                )
+            else:
+                prompt = apply_hf_chat_template(
+                    tokenizer,
+                    conversation=conversation,
+                    chat_template=request.chat_template or self.chat_template,
+                    add_generation_prompt=request.add_generation_prompt,
+                    tools=tool_dicts,
+                    documents=request.documents,
+                    **(request.chat_template_kwargs or {}),
+                )
         except Exception as e:
             logger.error("Error in applying chat template from request: %s", e)
             return self.create_error_response(str(e))
@@ -307,11 +320,10 @@ async def chat_completion_stream_generator(
                     # Send response to echo the input portion of the
                     # last message
                     if request.echo:
-                        last_msg_content: Optional[str] = ""
-                        if conversation and conversation[-1].get(
-                                "content") and conversation[-1].get(
-                                    "role") == role:
-                            last_msg_content = conversation[-1]["content"]
+                        last_msg_content: str = ""
+                        if conversation and "content" in conversation[
+                                -1] and conversation[-1].get("role") == role:
+                            last_msg_content = conversation[-1]["content"] or ""

                         if last_msg_content:
                             for i in range(num_choices):
@@ -659,8 +671,8 @@ async def chat_completion_full_generator(

         if request.echo:
             last_msg_content = ""
-            if conversation and conversation[-1].get(
-                    "content") and conversation[-1].get("role") == role:
+            if conversation and "content" in conversation[-1] and conversation[
+                    -1].get("role") == role:
                 last_msg_content = conversation[-1]["content"] or ""

             for choice in choices:
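
The echo change above swaps a truthiness check on conversation[-1].get("content") for a key-presence check plus an `or ""` fallback; a small illustrative sketch (the message is made up):

# Illustrative last message: the "content" key exists but its value is None.
conversation = [{"role": "assistant", "content": None}]
role = "assistant"

# New behavior: check key presence rather than truthiness, coerce None to "".
last_msg_content: str = ""
if conversation and "content" in conversation[-1] and conversation[-1].get(
        "role") == role:
    last_msg_content = conversation[-1]["content"] or ""

assert last_msg_content == ""  # nothing is echoed for empty/None content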

vllm/entrypoints/openai/serving_tokenization.py (+18, -7)

@@ -2,7 +2,8 @@

 from vllm.config import ModelConfig
 from vllm.engine.protocol import AsyncEngineClient
-from vllm.entrypoints.chat_utils import (apply_chat_template,
+from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
+                                         apply_mistral_chat_template,
                                          load_chat_template,
                                          parse_chat_messages_futures)
 from vllm.entrypoints.logger import RequestLogger
@@ -18,6 +19,7 @@
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
 from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import MistralTokenizer
 from vllm.utils import random_uuid

 logger = init_logger(__name__)
@@ -66,6 +68,7 @@ async def create_tokenize(

         tokenizer = await self.async_engine_client.get_tokenizer(lora_request)

+        prompt: Union[str, List[int]]
         if isinstance(request, TokenizeChatRequest):
             model_config = self.model_config

@@ -77,12 +80,20 @@ async def create_tokenize(
                 logger.warning(
                     "Multi-modal inputs are ignored during tokenization")

-            prompt = apply_chat_template(
-                tokenizer,
-                conversation=conversation,
-                chat_template=self.chat_template,
-                add_generation_prompt=request.add_generation_prompt,
-            )
+            if isinstance(tokenizer, MistralTokenizer):
+                prompt = apply_mistral_chat_template(
+                    tokenizer,
+                    messages=request.messages,
+                    chat_template=self.chat_template,
+                    add_generation_prompt=request.add_generation_prompt,
+                )
+            else:
+                prompt = apply_hf_chat_template(
+                    tokenizer,
+                    conversation=conversation,
+                    chat_template=self.chat_template,
+                    add_generation_prompt=request.add_generation_prompt,
+                )
         else:
             prompt = request.prompt

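
Client-side sketch (not part of the diff; the endpoint path and payload shape are assumed from vLLM's OpenAI-compatible server and its TokenizeChatRequest schema): for chat-style tokenize requests against a model served with a Mistral tokenizer, the returned ids now come from apply_mistral_chat_template.

import requests

# Hypothetical local deployment of the OpenAI-compatible server.
resp = requests.post(
    "http://localhost:8000/tokenize",
    json={
        "model": "mistralai/Mistral-7B-Instruct-v0.3",
        "messages": [{"role": "user", "content": "Hello"}],
        "add_generation_prompt": True,
    },
)
# Expected shape per TokenizeResponse: {"count": ..., "max_model_len": ..., "tokens": [...]}
print(resp.json())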

vllm/transformers_utils/tokenizers/mistral.py (+4, -4)

@@ -16,7 +16,7 @@
                                        Tekkenizer)

 if TYPE_CHECKING:
-    from vllm.entrypoints.chat_utils import ConversationMessage
+    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam


 @dataclass
@@ -122,19 +122,19 @@ def get_added_vocab(self) -> List[str]:
         return []

     def encode(self, prompt: str) -> List[int]:
-        # `encode ` should only be used for prompt completion
+        # `encode` should only be used for prompt completion
         # it should never be used for chat_completion.
         # For chat completion use `apply_chat_template`
         return self.tokenizer.encode(prompt, bos=True, eos=False)

     def apply_chat_template(self,
-                            conversation: List["ConversationMessage"],
+                            messages: List["ChatCompletionMessageParam"],
                             tools: Optional[Dict[str, Any]] = None,
                             **kwargs) -> List[int]:
         assert tools is None, "`tools` are not yet supported."

         request = ChatCompletionRequest(
-            messages=conversation)  # type: ignore[type-var]
+            messages=messages)  # type: ignore[type-var]
         encoded = self.mistral.encode_chat_completion(request)

         # encode-decode to get clean prompt
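
With the annotation fix above, MistralTokenizer.apply_chat_template is typed to accept OpenAI-format message dicts (ChatCompletionMessageParam) rather than ConversationMessage. A minimal sketch, assuming MistralTokenizer.from_pretrained and an illustrative model name:

from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

# Assumed loading path: from_pretrained pulls the tokenizer file that ships
# with the model repo (model name is illustrative).
tokenizer = MistralTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")

# Plain OpenAI-format dicts go in; Mistral-encoded token ids come out.
token_ids = tokenizer.apply_chat_template(
    messages=[{"role": "user", "content": "Hello"}],
)
print(len(token_ids))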

0 commit comments