@@ -23,8 +23,8 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 from functools import lru_cache, partial
-from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict,
-                    Union)
+from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
+                    Tuple, Type, TypedDict, Union)

 import torch
 import torch.nn as nn
@@ -76,19 +76,31 @@
 # === Vision Inputs === #


-class Qwen2VLImageInputs(TypedDict):
-    pixel_values: torch.Tensor
+class Qwen2VLImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: torch.Tensor
     """Shape:
     `(num_patches, num_channels * patch_size * patch_size)`
     """

     image_grid_thw: torch.Tensor
     """Shape: `(num_images, 3)`
-
     This should be in `(grid_t, grid_h, grid_w)` format.
     """


+class Qwen2VLImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
+                           Qwen2VLImageEmbeddingInputs]
+
+
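With this change, image inputs become a tagged union: `Qwen2VLImagePixelInputs` carries raw patches that still need the vision encoder, while `Qwen2VLImageEmbeddingInputs` carries features that already match the language model's hidden size. A minimal sketch of how such a union can be narrowed on its `type` tag downstream (the `describe` helper is hypothetical and not part of this file):

# Hypothetical helper illustrating narrowing on the "type" tag;
# `image_input` is assumed to be a Qwen2VLImageInputs as defined above.
def describe(image_input: Qwen2VLImageInputs) -> str:
    if image_input["type"] == "pixel_values":
        # Raw patches; the vision tower still has to be run on them.
        return f"patches with shape {tuple(image_input['data'].shape)}"
    # "image_embeds": precomputed features consumed directly by the LM.
    return f"embeddings with shape {tuple(image_input['data'].shape)}"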
 class Qwen2VLVideoInputs(TypedDict):
     pixel_values_videos: torch.Tensor
     """Shape:
@@ -567,6 +579,11 @@ def mm_input_mapper_for_qwen2_vl(
     data_type_key: str,
 ) -> MultiModalInputs:
     """Input mapper for Qwen2-VL."""
+    if data_type_key == "image" and isinstance(data, dict):
+        return MultiModalInputs({
+            "image_embeds": data.get("image_embeds"),
+            "image_grid_thw": data.get("image_grid_thw"),
+        })
     model_config = ctx.model_config
     image_processor = cached_get_image_processor(
         model_config.model, trust_remote_code=model_config.trust_remote_code)
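Given this branch, precomputed image features reach the mapper as a plain dict keyed by "image_embeds" and "image_grid_thw" instead of pixel values. A hedged sketch of what the image entry looks like on the way in (shapes follow the TypedDict docstrings above; `prompt`, `image_embeds`, and `image_grid_thw` are placeholder variables):

# Sketch: supplying precomputed features instead of pixel values.
llm_inputs = {
    "prompt": prompt,  # prompt text containing the image placeholder token(s)
    "multi_modal_data": {
        "image": {
            "image_embeds": image_embeds,      # precomputed vision features
            "image_grid_thw": image_grid_thw,  # per-image (t, h, w) grid info
        }
    },
}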
@@ -739,6 +756,48 @@ def _get_llm_num_vision_tokens(
     return llm_num_vision_tokens


+def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
+                       data_type_key: str, image_processor: Any,
+                       prompt_token_ids: List[int]) -> List[int]:
+    """
+    Expand pad tokens for multi-modal inputs (e.g., images or videos).
+
+    Args:
+        inputs (list): The multi-modal inputs (e.g., images or videos).
+        token_id (int): The token ID used to represent the multi-modal input.
+        make_batched_fn (Callable): A function to batch the inputs.
+        data_type_key (str): The type of the multi-modal input.
+        image_processor (Any): The image processor used to process the inputs.
+        prompt_token_ids (List[int]): The list of token IDs in the prompt.
+
+    Returns:
+        List[int]: The list of token IDs for the multi-modal inputs.
+    """
+    indices = [
+        idx for idx, token in enumerate(prompt_token_ids) if token == token_id
+    ]
+    inputs = make_batched_fn(inputs)
+    assert len(indices) == len(inputs)
+
+    prompt_token_ids_with_data = []
+    for cnt, data in enumerate(inputs):
+        num_tokens = _get_llm_num_vision_tokens(
+            [data] if data_type_key == "image" else data,
+            data_type_key=data_type_key,
+            image_processor=image_processor,
+        )
+        if cnt == 0:
+            end_idx = indices[cnt]
+            non_data_tokens = prompt_token_ids[:end_idx]
+        else:
+            non_data_tokens = prompt_token_ids[indices[cnt - 1] +
+                                               1:indices[cnt]]
+        prompt_token_ids_with_data.extend(non_data_tokens)
+        prompt_token_ids_with_data.extend(token_id for _ in range(num_tokens))
+    prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:])
+    return prompt_token_ids_with_data
+
+
 def input_processor_for_qwen2_vl(ctx: InputContext,
                                  llm_inputs: LLMInputs) -> LLMInputs:
     multi_modal_data = llm_inputs.get("multi_modal_data", None)
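For intuition, here is a standalone toy version of the same expansion, with vision-token counts supplied directly instead of computed via `_get_llm_num_vision_tokens` (all names and numbers below are illustrative only):

from typing import List

def expand_pad_tokens_toy(prompt_token_ids: List[int], token_id: int,
                          num_tokens_per_item: List[int]) -> List[int]:
    # Positions of the placeholder token, one per multi-modal item.
    indices = [i for i, t in enumerate(prompt_token_ids) if t == token_id]
    assert len(indices) == len(num_tokens_per_item)
    out: List[int] = []
    prev = 0
    for idx, n in zip(indices, num_tokens_per_item):
        out.extend(prompt_token_ids[prev:idx])  # tokens before the placeholder
        out.extend([token_id] * n)              # placeholder replicated n times
        prev = idx + 1
    out.extend(prompt_token_ids[indices[-1] + 1:])  # trailing tokens
    return out

# Two placeholders (token 9) expanding to 4 and 6 vision tokens respectively:
assert expand_pad_tokens_toy([101, 9, 102, 9, 103], 9, [4, 6]) == \
    [101] + [9] * 4 + [102] + [9] * 6 + [103]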
@@ -775,62 +834,38 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
     )["input_ids"]

     # Expand image pad tokens.
+
     if image_inputs is not None:
-        image_indices = [
-            idx for idx, token in enumerate(prompt_token_ids)
-            if token == hf_config.image_token_id
-        ]
-        image_inputs = make_batched_images(image_inputs)
-        assert len(image_indices) == len(image_inputs)
-
-        prompt_token_ids_with_image = []
-        for image_cnt, image in enumerate(image_inputs):
-            num_image_tokens = _get_llm_num_vision_tokens(
-                [image],
-                data_type_key="image",
-                image_processor=image_processor,
-            )
-            if image_cnt == 0:
-                non_image_tokens = prompt_token_ids[:image_indices[image_cnt]]
-            else:
-                non_image_tokens = prompt_token_ids[image_indices[image_cnt -
-                                                                  1] +
-                                                    1:image_indices[image_cnt]]
-            prompt_token_ids_with_image.extend(non_image_tokens)
-            prompt_token_ids_with_image.extend(
-                hf_config.image_token_id for _ in range(num_image_tokens))
-        prompt_token_ids_with_image.extend(prompt_token_ids[image_indices[-1] +
-                                                            1:])
-        prompt_token_ids = prompt_token_ids_with_image
-
-    # Expand video pad tokens.
+        if isinstance(image_inputs, dict):
+            prompt_token_ids_with_image = []
+            image_indices = [
+                idx for idx, token in enumerate(prompt_token_ids)
+                if token == hf_config.image_token_id
+            ]
+            image_cnt = len(image_indices)
+            embed_dim = image_inputs.get('image_embeds').size(0)
+            assert embed_dim % image_cnt == 0
+            num_pad_tokens = embed_dim // image_cnt
+            for idx, token in enumerate(prompt_token_ids):
+                if idx in image_indices:
+                    prompt_token_ids_with_image.extend([token] *
+                                                       num_pad_tokens)
+                else:
+                    prompt_token_ids_with_image.append(token)
+            prompt_token_ids = prompt_token_ids_with_image
+        else:
+            prompt_token_ids = _expand_pad_tokens(image_inputs,
+                                                  hf_config.image_token_id,
+                                                  make_batched_images, "image",
+                                                  image_processor,
+                                                  prompt_token_ids)
+
     if video_inputs is not None:
-        video_indices = [
-            idx for idx, token in enumerate(prompt_token_ids)
-            if token == hf_config.video_token_id
-        ]
-        video_inputs = make_batched_videos(video_inputs)
-        assert len(video_indices) == len(video_inputs)
-
-        prompt_token_ids_with_video = []
-        for video_cnt, video in enumerate(video_inputs):
-            num_video_tokens = _get_llm_num_vision_tokens(
-                video,
-                data_type_key="video",
-                image_processor=image_processor,
-            )
-            if video_cnt == 0:
-                non_video_tokens = prompt_token_ids[:video_indices[video_cnt]]
-            else:
-                non_video_tokens = prompt_token_ids[video_indices[video_cnt -
-                                                                  1] +
-                                                    1:video_indices[video_cnt]]
-            prompt_token_ids_with_video.extend(non_video_tokens)
-            prompt_token_ids_with_video.extend(
-                hf_config.video_token_id for _ in range(num_video_tokens))
-        prompt_token_ids_with_video.extend(prompt_token_ids[video_indices[-1] +
-                                                            1:])
-        prompt_token_ids = prompt_token_ids_with_video
+        prompt_token_ids = _expand_pad_tokens(video_inputs,
+                                              hf_config.video_token_id,
+                                              make_batched_videos, "video",
+                                              image_processor,
+                                              prompt_token_ids)

     return LLMInputs(
         prompt_token_ids=prompt_token_ids,
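In the embedding branch above, the per-image pad count is derived purely from tensor sizes, which implicitly assumes every image in the prompt contributes the same number of embedding rows. A small worked example of that arithmetic (numbers are illustrative):

# Two image placeholders in the prompt and 2048 total embedding rows:
image_cnt = 2                       # len(image_indices)
embed_dim = 2048                    # image_embeds.size(0), rows over all images
assert embed_dim % image_cnt == 0   # rows must split evenly across images
num_pad_tokens = embed_dim // image_cnt  # -> 1024 placeholder tokens per image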
@@ -910,22 +945,32 @@ def _validate_and_reshape_mm_tensor(self,
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Qwen2VLImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
         image_grid_thw = kwargs.pop("image_grid_thw", None)

-        if pixel_values is None:
+        if pixel_values is None and image_embeds is None:
             return None

-        pixel_values = self._validate_and_reshape_mm_tensor(
-            pixel_values, "image pixel values")
-        image_grid_thw = self._validate_and_reshape_mm_tensor(
-            image_grid_thw, "image grid_thw")
+        if pixel_values is not None:
+            pixel_values = self._validate_and_reshape_mm_tensor(
+                pixel_values, "image pixel values")
+            image_grid_thw = self._validate_and_reshape_mm_tensor(
+                image_grid_thw, "image grid_thw")

-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of image pixel values. "
-                             f"Got type: {type(pixel_values)}")
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of image pixel values. "
+                                 f"Got type: {type(pixel_values)}")

-        return Qwen2VLImageInputs(pixel_values=pixel_values,
-                                  image_grid_thw=image_grid_thw)
+            return Qwen2VLImagePixelInputs(type="pixel_values",
+                                           data=pixel_values,
+                                           image_grid_thw=image_grid_thw)
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+            return Qwen2VLImageEmbeddingInputs(type="image_embeds",
+                                               data=image_embeds)

     def _parse_and_validate_video_input(
             self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]:
@@ -947,7 +992,10 @@ def _parse_and_validate_video_input(

     def _process_image_input(self,
                              image_input: Qwen2VLImageInputs) -> torch.Tensor:
-        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+        if image_input["type"] == "image_embeds":
+            return image_input["data"].type(self.visual.dtype)
+
+        pixel_values = image_input["data"].type(self.visual.dtype)
         image_embeds = self.visual(pixel_values,
                                    grid_thw=image_input["image_grid_thw"])
         return image_embeds