@@ -20,7 +20,7 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig, MultiModalConfig
+from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                          token_inputs)
@@ -30,6 +30,7 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -39,14 +40,15 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.utils import is_list_of
 
-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (flatten_bn, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
 
@@ -122,8 +124,8 @@ def __init__(
         # Strided linear layer.
         assert self._qkv_same_embed_dim, \
             'Visual Attention implementation only supports self-attention'
-        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
-        self.out_proj = nn.Linear(embed_dim, embed_dim)
+        self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim)
+        self.out_proj = ReplicatedLinear(embed_dim, embed_dim)
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
 
     def forward(
@@ -133,7 +135,7 @@ def forward(
     ) -> torch.Tensor:
         # query/key/value: [sq, b, h]
         sq, b, _ = x.size()
-        mixed_x_layer = self.in_proj(x)
+        mixed_x_layer, _ = self.in_proj(x)
 
         # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
         new_tensor_shape = mixed_x_layer.size()[:-1] + \
@@ -182,7 +184,7 @@ def forward(
             (self.hidden_size_per_partition,)
         context_layer = context_layer.view(*new_context_layer_shape)
 
-        output = self.out_proj(context_layer)
+        output, _ = self.out_proj(context_layer)
 
         return output
 
@@ -860,18 +862,15 @@ def dummy_data_for_qwen(
     return seq_data, mm_data
 
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
-class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
+class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
 
     def __init__(
         self,
         config: PretrainedConfig,
         multimodal_config: MultiModalConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -990,3 +989,91 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+
+
+class QWenLLM(QWenBaseModel):
+    packed_modules_mapping = {
+        "c_attn": ["c_attn"],
+        "gate_up_proj": [
+            "w2",
+            "w1",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "c_attn",
+        "gate_up_proj",
+        "c_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+
+class QWenVL(QWenBaseModel):
+    packed_modules_mapping = {
+        "c_attn": ["c_attn"],
+        "gate_up_proj": [
+            "w2",
+            "w1",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "c_attn",
+        "gate_up_proj",
+        "c_proj",
+        # visual module
+        "out_proj",
+        "in_proj",
+        "c_fc",
+        # resampler
+        "kv_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="transformer.h",
+            connector="transformer.visual.attn_pool",
+            tower_model="transformer.visual.transformer")
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
+class QWenLMHeadModel(QWenBaseModel):
+    """
+    QWenLMHeadModel is not only applicable to LLM but also to VL, which is not
+    conducive to the current integration logic of LoRA in vLLM. Therefore, it
+    is necessary to separate them.
+    """
+    # Ensure that the LoRA support check passes when the class is not
+    # initialized, but set all these attributes to empty.
+    packed_modules_mapping = {}
+    supported_lora_modules = []
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __new__(
+        cls,
+        config: PretrainedConfig,
+        multimodal_config: MultiModalConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ):
+        # Initialize VL
+        if hasattr(config, "visual"):
+            return QWenVL(config, multimodal_config, cache_config,
+                          quant_config, lora_config)
+        # Initialize LLM
+        else:
+            return QWenLLM(config, multimodal_config, cache_config,
+                           quant_config, lora_config)
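
For context, here is a minimal sketch (not part of the diff) of how a LoRA adapter could be applied at inference time once these model classes declare SupportsLoRA. It uses vLLM's offline LLM API; the adapter name, path, and engine flags such as max_lora_rank are placeholders that depend on the adapter actually being loaded, not values taken from this change.

# Sketch only: offline inference with a LoRA adapter on top of Qwen-7B.
# The adapter name/path are placeholders; the adapter must target modules
# listed in supported_lora_modules above (c_attn, gate_up_proj, c_proj).
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="Qwen/Qwen-7B",    # config has no `visual` field -> QWenLLM path
    trust_remote_code=True,  # Qwen ships its config/tokenizer as remote code
    enable_lora=True,        # turn on LoRA support in the engine
    max_lora_rank=64,        # must cover the rank of the adapter below
)

# Placeholder adapter; replace with a real LoRA checkpoint directory.
lora = LoRARequest("qwen-lora", 1, "/path/to/qwen-lora-adapter")

outputs = llm.generate(
    ["Give me a short introduction to large language models."],
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=lora,
)
print(outputs[0].outputs[0].text)

Because QWenLMHeadModel.__new__ dispatches on hasattr(config, "visual"), pointing model= at a Qwen-VL checkpoint instead would route construction through QWenVL, whose supported_lora_modules additionally expose the visual attention and resampler projections to LoRA.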