Commit 72c1353

[Model] Support multiple images for qwen-vl (vllm-project#8247)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 7e7dbf8 commit 72c1353

4 files changed: +343 −65

docs/source/models/supported_models.rst (+1 −1)

@@ -254,7 +254,7 @@ Multimodal Language Models
     -
   * - :code:`QWenLMHeadModel`
     - Qwen-VL
-    - Image\ :sup:`E`
+    - Image\ :sup:`E+`
     - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
     -
   * - :code:`Qwen2VLForConditionalGeneration`
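In this table, :sup:`E` marks support for image-embedding inputs and the added `+` marks support for multiple image inputs per text prompt. A minimal sketch of what the new capability enables (the URLs are placeholders; `fetch_image` is the same helper the example script below uses):

from vllm import LLM
from vllm.multimodal.utils import fetch_image

# Placeholder URLs; any reachable images would do.
urls = ["https://example.com/a.jpg", "https://example.com/b.jpg"]

# limit_mm_per_prompt must be raised above the default of one image.
llm = LLM(model="Qwen/Qwen-VL-Chat",
          trust_remote_code=True,
          limit_mm_per_prompt={"image": len(urls)})

# Qwen-VL expects one "Picture i: <img></img>" placeholder per image.
prompt = ("Picture 1: <img></img>\nPicture 2: <img></img>\n"
          "Describe each image.")
outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": {"image": [fetch_image(u) for u in urls]},
})
print(outputs[0].outputs[0].text)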

examples/offline_inference_vision_language_multi_image.py (+60 −24)
@@ -19,7 +19,39 @@
 ]
 
 
-def load_phi3v(question, image_urls: List[str]):
+def load_qwenvl_chat(question: str, image_urls: List[str]):
+    model_name = "Qwen/Qwen-VL-Chat"
+    llm = LLM(
+        model=model_name,
+        trust_remote_code=True,
+        max_num_seqs=5,
+        limit_mm_per_prompt={"image": len(image_urls)},
+    )
+    placeholders = "".join(f"Picture {i}: <img></img>\n"
+                           for i, _ in enumerate(image_urls, start=1))
+
+    # This model does not have a chat_template attribute on its tokenizer,
+    # so we need to explicitly pass it. We use ChatML since it's used in the
+    # generation utils of the model:
+    # https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265
+    tokenizer = AutoTokenizer.from_pretrained(model_name,
+                                              trust_remote_code=True)
+
+    # Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating
+    chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"  # noqa: E501
+
+    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
+    prompt = tokenizer.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True,
+                                           chat_template=chat_template)
+
+    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
+    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
+    return llm, prompt, stop_token_ids, None, chat_template
+
+
+def load_phi3v(question: str, image_urls: List[str]):
     llm = LLM(
         model="microsoft/Phi-3.5-vision-instruct",
         trust_remote_code=True,
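For orientation, this is roughly the string `apply_chat_template` produces above for two images (hand-rendered from the ChatML template; the question text is a placeholder):

# Hand-rendered sketch of the prompt for two images. The template wraps the
# user content in <|im_start|>/<|im_end|> and, because add_generation_prompt
# is True, appends the assistant header at the end.
question = "Describe each image."  # placeholder
rendered = ("<|im_start|>user\n"
            "Picture 1: <img></img>\n"
            "Picture 2: <img></img>\n"
            f"\n{question}<|im_end|>\n"
            "<|im_start|>assistant\n")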
@@ -30,10 +62,10 @@ def load_phi3v(question, image_urls: List[str]):
                              for i, _ in enumerate(image_urls, start=1))
     prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
     stop_token_ids = None
-    return llm, prompt, stop_token_ids, None
+    return llm, prompt, stop_token_ids, None, None
 
 
-def load_internvl(question, image_urls: List[str]):
+def load_internvl(question: str, image_urls: List[str]):
     model_name = "OpenGVLab/InternVL2-2B"
 
     llm = LLM(
@@ -61,7 +93,7 @@ def load_internvl(question, image_urls: List[str]):
     stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
     stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
 
-    return llm, prompt, stop_token_ids, None
+    return llm, prompt, stop_token_ids, None, None
 
 
 def load_qwen2_vl(question, image_urls: List[str]):
@@ -111,18 +143,19 @@ def load_qwen2_vl(question, image_urls: List[str]):
     else:
         image_data, _ = process_vision_info(messages)
 
-    return llm, prompt, stop_token_ids, image_data
+    return llm, prompt, stop_token_ids, image_data, None
 
 
 model_example_map = {
     "phi3_v": load_phi3v,
     "internvl_chat": load_internvl,
     "qwen2_vl": load_qwen2_vl,
+    "qwen_vl_chat": load_qwenvl_chat,
 }
 
 
 def run_generate(model, question: str, image_urls: List[str]):
-    llm, prompt, stop_token_ids, image_data = model_example_map[model](
+    llm, prompt, stop_token_ids, image_data, _ = model_example_map[model](
         question, image_urls)
     if image_data is None:
         image_data = [fetch_image(url) for url in image_urls]
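Every loader now returns a 5-tuple `(llm, prompt, stop_token_ids, image_data, chat_template)`; the final element is consumed only by `run_chat`. A hypothetical new entry under this convention would look like the sketch below (the model id and placeholder format are invented for illustration):

def load_my_model(question: str, image_urls: List[str]):
    # Hypothetical loader following the 5-tuple convention.
    llm = LLM(model="org/my-multi-image-model",  # invented model id
              limit_mm_per_prompt={"image": len(image_urls)})
    placeholders = "\n".join(f"<image {i}>"  # invented placeholder format
                             for i, _ in enumerate(image_urls, start=1))
    prompt = f"{placeholders}\n{question}"
    # image_data=None makes run_generate fall back to fetch_image per URL;
    # chat_template=None makes llm.chat use the tokenizer's own template.
    return llm, prompt, None, None, None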
@@ -146,29 +179,32 @@ def run_generate(model, question: str, image_urls: List[str]):
 
 
 def run_chat(model: str, question: str, image_urls: List[str]):
-    llm, _, stop_token_ids, _ = model_example_map[model](question, image_urls)
+    llm, _, stop_token_ids, _, chat_template = model_example_map[model](
+        question, image_urls)
 
     sampling_params = SamplingParams(temperature=0.0,
                                      max_tokens=128,
                                      stop_token_ids=stop_token_ids)
-
-    outputs = llm.chat([{
-        "role":
-        "user",
-        "content": [
-            {
-                "type": "text",
-                "text": question,
-            },
-            *({
-                "type": "image_url",
-                "image_url": {
-                    "url": image_url
+    outputs = llm.chat(
+        [{
+            "role":
+            "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": question,
                 },
-            } for image_url in image_urls),
-        ],
-    }],
-                       sampling_params=sampling_params)
+                *({
+                    "type": "image_url",
+                    "image_url": {
+                        "url": image_url
+                    },
+                } for image_url in image_urls),
+            ],
+        }],
+        sampling_params=sampling_params,
+        chat_template=chat_template,
+    )
 
     for o in outputs:
         generated_text = o.outputs[0].text
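With `qwen_vl_chat` registered in `model_example_map`, the new multi-image path can be exercised as in the sketch below (the URLs are placeholders, and each call builds its own LLM instance, so run one per process):

IMAGE_URLS = [
    "https://example.com/image_1.jpg",  # placeholder
    "https://example.com/image_2.jpg",  # placeholder
]
QUESTION = "What is the content of each image?"

# Raw-prompt path: the loader builds the ChatML prompt itself and
# run_generate fetches the images and passes them as multi_modal_data.
run_generate("qwen_vl_chat", QUESTION, IMAGE_URLS)

# Chat path: llm.chat applies the explicitly supplied ChatML chat_template.
# run_chat("qwen_vl_chat", QUESTION, IMAGE_URLS)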
