Commit 485bef9

whyiug authored and liuyanyi committed
[Model] support input embeddings for qwen2vl (vllm-project#8856)
1 parent 56e8cea commit 485bef9

3 files changed, +136 -71 lines changed

docs/source/models/supported_models.rst

+1 -1

@@ -281,7 +281,7 @@ Multimodal Language Models
     -
   * - :code:`Qwen2VLForConditionalGeneration`
     - Qwen2-VL
-    - Image\ :sup:`+` / Video\ :sup:`+`
+    - Image\ :sup:`E+` / Video\ :sup:`+`
     - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
     -
   * - :code:`UltravoxModel`

docs/source/models/vlm.rst

+17

@@ -60,7 +60,24 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptT
     for o in outputs:
         generated_text = o.outputs[0].text
         print(generated_text)
+
+    # Inference with image embeddings as input, with additional parameters.
+    # Specifically, this is a trial run of Qwen2VL with the new input format, as the model uses additional parameters to compute positional encodings.
+    image_embeds = torch.load(...)  # torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
+    image_grid_thw = torch.load(...)  # torch.Tensor of shape (1, 3)
+    mm_data['image'] = {
+        "image_embeds": image_embeds,
+        "image_grid_thw": image_grid_thw,
+    }
+    outputs = llm.generate({
+        "prompt": prompt,
+        "multi_modal_data": mm_data,
+    })

+    for o in outputs:
+        generated_text = o.outputs[0].text
+        print(generated_text)
+
     # Batch inference
     image_1 = PIL.Image.open(...)
     image_2 = PIL.Image.open(...)
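Putting the snippet above together end to end, a minimal self-contained sketch of the embedding path is shown below; the model name, the Qwen2-VL placeholder tokens in the prompt, and the .pt file names are illustrative assumptions rather than part of this commit:

import torch
from vllm import LLM

llm = LLM(model="Qwen/Qwen2-VL-2B-Instruct")

# Prompt containing the image placeholder expected by the Qwen2-VL chat template
# (assumed placeholder tokens; adjust to your tokenizer/template).
prompt = ("<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
          "Describe the image.<|im_end|>\n<|im_start|>assistant\n")

# Precomputed image embeddings and grid, e.g. exported from an earlier run of the
# vision encoder. Per the docs above, the embeddings have shape
# (1, image_feature_size, hidden_size of the LM) and grid_thw has shape (1, 3).
# The file names are hypothetical.
image_embeds = torch.load("image_embeds.pt")
image_grid_thw = torch.load("image_grid_thw.pt")

outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": {
        "image": {
            "image_embeds": image_embeds,
            "image_grid_thw": image_grid_thw,
        },
    },
})
print(outputs[0].outputs[0].text)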

vllm/model_executor/models/qwen2_vl.py

+118 -70

@@ -23,8 +23,8 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 from functools import lru_cache, partial
-from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict,
-                    Union)
+from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
+                    Tuple, Type, TypedDict, Union)

 import torch
 import torch.nn as nn
@@ -76,19 +76,31 @@
 # === Vision Inputs === #


-class Qwen2VLImageInputs(TypedDict):
-    pixel_values: torch.Tensor
+class Qwen2VLImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: torch.Tensor
     """Shape:
     `(num_patches, num_channels * patch_size * patch_size)`
     """

     image_grid_thw: torch.Tensor
     """Shape: `(num_images, 3)`
-
     This should be in `(grid_t, grid_h, grid_w)` format.
     """


+class Qwen2VLImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
+                           Qwen2VLImageEmbeddingInputs]
+
+
 class Qwen2VLVideoInputs(TypedDict):
     pixel_values_videos: torch.Tensor
     """Shape:
@@ -567,6 +579,11 @@ def mm_input_mapper_for_qwen2_vl(
     data_type_key: str,
 ) -> MultiModalInputs:
     """Input mapper for Qwen2-VL."""
+    if data_type_key == "image" and isinstance(data, dict):
+        return MultiModalInputs({
+            "image_embeds": data.get("image_embeds"),
+            "image_grid_thw": data.get("image_grid_thw"),
+        })
     model_config = ctx.model_config
     image_processor = cached_get_image_processor(
         model_config.model, trust_remote_code=model_config.trust_remote_code)
@@ -739,6 +756,48 @@ def _get_llm_num_vision_tokens(
     return llm_num_vision_tokens


+def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
+                       data_type_key: str, image_processor: Any,
+                       prompt_token_ids: List[int]) -> List[int]:
+    """
+    Expand pad tokens for multi-modal inputs (e.g., images or videos).
+
+    Args:
+        inputs (list): The multi-modal inputs (e.g., images or videos).
+        token_id (int): The token ID used to represent the multi-modal input.
+        make_batched_fn (Callable): A function to batch the inputs.
+        data_type_key (str): The type of the multi-modal input.
+        image_processor (Any): The image processor used to process the inputs.
+        prompt_token_ids (List[int]): The list of token IDs in the prompt.
+
+    Returns:
+        List[int]: The list of token IDs for the multi-modal inputs.
+    """
+    indices = [
+        idx for idx, token in enumerate(prompt_token_ids) if token == token_id
+    ]
+    inputs = make_batched_fn(inputs)
+    assert len(indices) == len(inputs)
+
+    prompt_token_ids_with_data = []
+    for cnt, data in enumerate(inputs):
+        num_tokens = _get_llm_num_vision_tokens(
+            [data] if data_type_key == "image" else data,
+            data_type_key=data_type_key,
+            image_processor=image_processor,
+        )
+        if cnt == 0:
+            end_idx = indices[cnt]
+            non_data_tokens = prompt_token_ids[:end_idx]
+        else:
+            non_data_tokens = prompt_token_ids[indices[cnt - 1] +
+                                               1:indices[cnt]]
+        prompt_token_ids_with_data.extend(non_data_tokens)
+        prompt_token_ids_with_data.extend(token_id for _ in range(num_tokens))
+    prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:])
+    return prompt_token_ids_with_data
+
+
 def input_processor_for_qwen2_vl(ctx: InputContext,
                                  llm_inputs: LLMInputs) -> LLMInputs:
     multi_modal_data = llm_inputs.get("multi_modal_data", None)
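The helper factors out the placeholder-expansion loop that previously appeared twice (once for images, once for videos). A toy, self-contained re-implementation with the per-item token counts supplied directly instead of being computed by _get_llm_num_vision_tokens; all values are made up for illustration:

from typing import List


def expand_pad_tokens_toy(prompt_token_ids: List[int], token_id: int,
                          tokens_per_item: List[int]) -> List[int]:
    # Positions of the multi-modal placeholder token in the prompt.
    indices = [i for i, t in enumerate(prompt_token_ids) if t == token_id]
    assert len(indices) == len(tokens_per_item)
    out: List[int] = []
    prev = 0
    for idx, n in zip(indices, tokens_per_item):
        out.extend(prompt_token_ids[prev:idx])  # tokens before this placeholder
        out.extend([token_id] * n)              # expanded placeholder run
        prev = idx + 1
    out.extend(prompt_token_ids[prev:])         # tail after the last placeholder
    return out


# Two image placeholders (id 99) expanded to 3 and 2 tokens respectively.
print(expand_pad_tokens_toy([1, 99, 2, 99, 3], token_id=99, tokens_per_item=[3, 2]))
# -> [1, 99, 99, 99, 2, 99, 99, 3]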
@@ -775,62 +834,38 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
     )["input_ids"]

     # Expand image pad tokens.
+
     if image_inputs is not None:
-        image_indices = [
-            idx for idx, token in enumerate(prompt_token_ids)
-            if token == hf_config.image_token_id
-        ]
-        image_inputs = make_batched_images(image_inputs)
-        assert len(image_indices) == len(image_inputs)
-
-        prompt_token_ids_with_image = []
-        for image_cnt, image in enumerate(image_inputs):
-            num_image_tokens = _get_llm_num_vision_tokens(
-                [image],
-                data_type_key="image",
-                image_processor=image_processor,
-            )
-            if image_cnt == 0:
-                non_image_tokens = prompt_token_ids[:image_indices[image_cnt]]
-            else:
-                non_image_tokens = prompt_token_ids[image_indices[image_cnt -
-                                                                  1] +
-                                                    1:image_indices[image_cnt]]
-            prompt_token_ids_with_image.extend(non_image_tokens)
-            prompt_token_ids_with_image.extend(
-                hf_config.image_token_id for _ in range(num_image_tokens))
-        prompt_token_ids_with_image.extend(prompt_token_ids[image_indices[-1] +
-                                                            1:])
-        prompt_token_ids = prompt_token_ids_with_image
-
-    # Expand video pad tokens.
+        if isinstance(image_inputs, dict):
+            prompt_token_ids_with_image = []
+            image_indices = [
+                idx for idx, token in enumerate(prompt_token_ids)
+                if token == hf_config.image_token_id
+            ]
+            image_cnt = len(image_indices)
+            embed_dim = image_inputs.get('image_embeds').size(0)
+            assert embed_dim % image_cnt == 0
+            num_pad_tokens = embed_dim // image_cnt
+            for idx, token in enumerate(prompt_token_ids):
+                if idx in image_indices:
+                    prompt_token_ids_with_image.extend([token] *
+                                                       num_pad_tokens)
+                else:
+                    prompt_token_ids_with_image.append(token)
+            prompt_token_ids = prompt_token_ids_with_image
+        else:
+            prompt_token_ids = _expand_pad_tokens(image_inputs,
+                                                  hf_config.image_token_id,
+                                                  make_batched_images, "image",
+                                                  image_processor,
+                                                  prompt_token_ids)
+
     if video_inputs is not None:
-        video_indices = [
-            idx for idx, token in enumerate(prompt_token_ids)
-            if token == hf_config.video_token_id
-        ]
-        video_inputs = make_batched_videos(video_inputs)
-        assert len(video_indices) == len(video_inputs)
-
-        prompt_token_ids_with_video = []
-        for video_cnt, video in enumerate(video_inputs):
-            num_video_tokens = _get_llm_num_vision_tokens(
-                video,
-                data_type_key="video",
-                image_processor=image_processor,
-            )
-            if video_cnt == 0:
-                non_video_tokens = prompt_token_ids[:video_indices[video_cnt]]
-            else:
-                non_video_tokens = prompt_token_ids[video_indices[video_cnt -
-                                                                  1] +
-                                                    1:video_indices[video_cnt]]
-            prompt_token_ids_with_video.extend(non_video_tokens)
-            prompt_token_ids_with_video.extend(
-                hf_config.video_token_id for _ in range(num_video_tokens))
-        prompt_token_ids_with_video.extend(prompt_token_ids[video_indices[-1] +
-                                                            1:])
-        prompt_token_ids = prompt_token_ids_with_video
+        prompt_token_ids = _expand_pad_tokens(video_inputs,
+                                              hf_config.video_token_id,
+                                              make_batched_videos, "video",
+                                              image_processor,
+                                              prompt_token_ids)

     return LLMInputs(
         prompt_token_ids=prompt_token_ids,
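In the new embedding branch, the number of placeholder repetitions is plain arithmetic: the row count of image_embeds divided evenly across the image placeholders found in the prompt (hence the divisibility assert). A tiny standalone check, with an arbitrary token id and tensor sizes:

import torch

prompt_token_ids = [101, 99, 102, 99, 103]  # two image placeholders with id 99
image_embeds = torch.zeros(512, 1536)       # 512 embedding rows across both images
image_indices = [i for i, t in enumerate(prompt_token_ids) if t == 99]
assert image_embeds.size(0) % len(image_indices) == 0
num_pad_tokens = image_embeds.size(0) // len(image_indices)
print(num_pad_tokens)  # -> 256 placeholder tokens per image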
@@ -910,22 +945,32 @@ def _validate_and_reshape_mm_tensor(self,
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Qwen2VLImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
         image_grid_thw = kwargs.pop("image_grid_thw", None)

-        if pixel_values is None:
+        if pixel_values is None and image_embeds is None:
             return None

-        pixel_values = self._validate_and_reshape_mm_tensor(
-            pixel_values, "image pixel values")
-        image_grid_thw = self._validate_and_reshape_mm_tensor(
-            image_grid_thw, "image grid_thw")
+        if pixel_values is not None:
+            pixel_values = self._validate_and_reshape_mm_tensor(
+                pixel_values, "image pixel values")
+            image_grid_thw = self._validate_and_reshape_mm_tensor(
+                image_grid_thw, "image grid_thw")

-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of image pixel values. "
-                             f"Got type: {type(pixel_values)}")
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of image pixel values. "
+                                 f"Got type: {type(pixel_values)}")

-        return Qwen2VLImageInputs(pixel_values=pixel_values,
-                                  image_grid_thw=image_grid_thw)
+            return Qwen2VLImagePixelInputs(type="pixel_values",
+                                           data=pixel_values,
+                                           image_grid_thw=image_grid_thw)
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+            return Qwen2VLImageEmbeddingInputs(type="image_embeds",
+                                               data=image_embeds)

     def _parse_and_validate_video_input(
             self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]:
@@ -947,7 +992,10 @@ def _parse_and_validate_video_input(

     def _process_image_input(self,
                              image_input: Qwen2VLImageInputs) -> torch.Tensor:
-        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+        if image_input["type"] == "image_embeds":
+            return image_input["data"].type(self.visual.dtype)
+
+        pixel_values = image_input["data"].type(self.visual.dtype)
         image_embeds = self.visual(pixel_values,
                                    grid_thw=image_input["image_grid_thw"])
         return image_embeds
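The net effect of the last hunk: precomputed embeddings bypass the vision tower entirely, while raw pixel values still go through it. A stripped-down sketch of that control flow with a stand-in module; FakeVisual and its 1536-wide output are placeholders, not the real Qwen2-VL vision transformer:

import torch
import torch.nn as nn


class FakeVisual(nn.Module):
    dtype = torch.float32

    def forward(self, pixel_values: torch.Tensor,
                grid_thw: torch.Tensor) -> torch.Tensor:
        # Stand-in for the real vision encoder: map patches to LM-width features.
        return torch.zeros(pixel_values.size(0), 1536)


def process_image_input(visual: FakeVisual, image_input: dict) -> torch.Tensor:
    if image_input["type"] == "image_embeds":
        return image_input["data"].type(visual.dtype)  # skip the vision tower
    return visual(image_input["data"].type(visual.dtype),
                  grid_thw=image_input["image_grid_thw"])


embeds = process_image_input(
    FakeVisual(), {"type": "image_embeds", "data": torch.zeros(16, 1536)})
print(embeds.shape)  # torch.Size([16, 1536])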

0 commit comments
