
Commit 7a4df5f

[Model][LoRA]LoRA support added for Qwen (#9622)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent c5d7fb9 commit 7a4df5f

2 files changed: +101 -14 lines changed


vllm/lora/models.py

+3 -3

@@ -578,10 +578,10 @@ def _filter_unsupported_mm_module(self, module_name: str) -> bool:
         be filtered out.
         """
         if self.supports_mm:
-            prefix = module_name.split(".")[0]
             module_mapping: MultiModelKeys = self.model.get_mm_mapping()
-            return (prefix in module_mapping.connector
-                    or prefix in module_mapping.tower_model)
+            prefix_lst = module_mapping.connector + module_mapping.tower_model
+            return any(
+                [module_name.startswith(prefix) for prefix in prefix_lst])
         return False

     def _register_packed_modules(self, module_full_name: str) -> None:
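Note: the hunk above switches the multimodal LoRA filter from matching only the first dotted component of a module name to a full prefix match, so that nested prefixes such as the `transformer.visual.*` ones registered for Qwen-VL later in this commit are handled. A minimal standalone sketch of the new test (the module names below are illustrative examples, not taken from the diff):

```python
# Sketch of the prefix-based filtering, using the Qwen-VL prefixes that this
# commit registers via get_mm_mapping(). Module names are illustrative.
connector = ["transformer.visual.attn_pool"]
tower_model = ["transformer.visual.transformer"]
prefix_lst = connector + tower_model


def is_filtered(module_name: str) -> bool:
    # Filter out any module whose full dotted name starts with a connector or
    # tower-model prefix, rather than comparing only its first component.
    return any(module_name.startswith(prefix) for prefix in prefix_lst)


print(is_filtered("transformer.visual.attn_pool.kv_proj"))  # True  -> skipped for LoRA
print(is_filtered("transformer.h.0.attn.c_attn"))           # False -> eligible for LoRA
```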

vllm/model_executor/models/qwen.py

+98 -11

@@ -20,7 +20,7 @@
 from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig, MultiModalConfig
+from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                          token_inputs)

@@ -30,6 +30,7 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig

@@ -39,14 +40,15 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.utils import is_list_of

-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (flatten_bn, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)

@@ -122,8 +124,8 @@ def __init__(
         # Strided linear layer.
         assert self._qkv_same_embed_dim, \
             'Visual Attention implementation only supports self-attention'
-        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
-        self.out_proj = nn.Linear(embed_dim, embed_dim)
+        self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim)
+        self.out_proj = ReplicatedLinear(embed_dim, embed_dim)
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

     def forward(

@@ -133,7 +135,7 @@ def forward(
     ) -> torch.Tensor:
         # query/key/value: [sq, b, h]
         sq, b, _ = x.size()
-        mixed_x_layer = self.in_proj(x)
+        mixed_x_layer, _ = self.in_proj(x)

         # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
         new_tensor_shape = mixed_x_layer.size()[:-1] + \

@@ -182,7 +184,7 @@ def forward(
             (self.hidden_size_per_partition,)
         context_layer = context_layer.view(*new_context_layer_shape)

-        output = self.out_proj(context_layer)
+        output, _ = self.out_proj(context_layer)

         return output
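Note: the two call-site edits above follow from the layer swap: vLLM's linear layers, `ReplicatedLinear` included, return an `(output, output_bias)` tuple rather than the bare tensor that `nn.Linear` returns. A self-contained stand-in illustrating the calling convention (`TupleLinear` is a made-up class for this sketch, not vLLM code):

```python
import torch
import torch.nn as nn


class TupleLinear(nn.Module):
    """Toy stand-in mimicking the (output, output_bias) return convention of
    vLLM's replicated/parallel linear layers."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor):
        # Bias is already applied here, so the second element is None,
        # mirroring the default (no skip_bias_add) behaviour.
        return self.linear(x), None


in_proj = TupleLinear(16, 48)
mixed_x_layer, _ = in_proj(torch.randn(2, 16))  # mirrors `mixed_x_layer, _ = self.in_proj(x)`
```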

@@ -860,18 +862,15 @@ def dummy_data_for_qwen(
     return seq_data, mm_data


-@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
-class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
+class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):

     def __init__(
         self,
         config: PretrainedConfig,
         multimodal_config: MultiModalConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -990,3 +989,91 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+
+
+class QWenLLM(QWenBaseModel):
+    packed_modules_mapping = {
+        "c_attn": ["c_attn"],
+        "gate_up_proj": [
+            "w2",
+            "w1",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "c_attn",
+        "gate_up_proj",
+        "c_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+
+class QWenVL(QWenBaseModel):
+    packed_modules_mapping = {
+        "c_attn": ["c_attn"],
+        "gate_up_proj": [
+            "w2",
+            "w1",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "c_attn",
+        "gate_up_proj",
+        "c_proj",
+        # visual module
+        "out_proj",
+        "in_proj",
+        "c_fc",
+        # resampler
+        "kv_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="transformer.h",
+            connector="transformer.visual.attn_pool",
+            tower_model="transformer.visual.transformer")
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
+class QWenLMHeadModel(QWenBaseModel):
+    """
+    QWenLMHeadModel is not only applicable to LLM but also to VL, which is not
+    conducive to the current integration logic of LoRA in vLLM. Therefore, it
+    is necessary to separate them.
+    """
+    # Ensure that the LoRA support check passes when the class is not
+    # initialized, but set all these attributes to empty.
+    packed_modules_mapping = {}
+    supported_lora_modules = []
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __new__(
+        cls,
+        config: PretrainedConfig,
+        multimodal_config: MultiModalConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ):
+        # Initialize VL
+        if hasattr(config, "visual"):
+            return QWenVL(config, multimodal_config, cache_config,
+                          quant_config, lora_config)
+        # Initialize LLM
+        else:
+            return QWenLLM(config, multimodal_config, cache_config,
+                           quant_config, lora_config)
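Note: with the QWenLLM/QWenVL split and the `__new__` dispatch on `config.visual` above, a LoRA adapter can be requested for Qwen through vLLM's standard LoRA path. A hedged usage sketch, not part of this commit (the model name, adapter name, and adapter path are placeholders, and exact arguments may vary across vLLM versions):

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# A text-only Qwen checkpoint resolves to QWenLLM; a checkpoint whose config
# has a `visual` section (e.g. Qwen-VL) resolves to QWenVL instead.
llm = LLM(model="Qwen/Qwen-7B-Chat",  # placeholder model name
          trust_remote_code=True,
          enable_lora=True)

outputs = llm.generate(
    ["Give me a short introduction to large language models."],
    SamplingParams(temperature=0.0, max_tokens=64),
    # (adapter name, unique integer id, local adapter path) - path is a placeholder
    lora_request=LoRARequest("qwen-lora", 1, "/path/to/qwen_lora_adapter"),
)
print(outputs[0].outputs[0].text)
```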
