From 6493ee42f25275aac135f09be03b5572addc684d Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 23 Oct 2024 16:19:05 +0000 Subject: [PATCH 1/3] Init Signed-off-by: Jee Jee Li --- vllm/model_executor/models/qwen.py | 115 +++++++++++++++++++++++++++-- 1 file changed, 108 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index cd3f7c1b6c4db..319c37937d2df 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -20,7 +20,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, MultiModalConfig +from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, token_inputs) @@ -39,6 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs @@ -46,7 +47,7 @@ from vllm.sequence import IntermediateTensors, SequenceData from vllm.utils import is_list_of -from .interfaces import SupportsMultiModal, SupportsPP +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (flatten_bn, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -860,11 +861,7 @@ def dummy_data_for_qwen( return seq_data, mm_data -@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen) -@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen) -@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen) -class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP): +class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): def __init__( self, @@ -872,6 +869,7 @@ def __init__( multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config = config @@ -990,3 +988,106 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + +class QWenLLM(QWenBaseModel): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + # LoRA specific attributes + supported_lora_modules = [ + # vision encoder + "fc1", + "fc2", + "out_proj", + # language model + "qkv_proj", # same name with vision encoder + "o_proj", + "gate_up_proj", + "down_proj", + # resampler + "kv_proj", + ] + + embedding_modules = {} + embedding_padding_modules = [] + + +class QWenVL(QWenBaseModel): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + # LoRA specific attributes + supported_lora_modules = [ + # vision encoder + "fc1", + "fc2", + "out_proj", + # language model + "qkv_proj", # same name with vision encoder + "o_proj", + "gate_up_proj", + "down_proj", + # 
resampler + "kv_proj", + ] + + embedding_modules = {} + embedding_padding_modules = [] + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field(language_model="llm", + connector="resampler", + tower_model="vpm") + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen) +@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen) +@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen) +class QWenLMHeadModel(QWenBaseModel): + """ + QWenLMHeadModel is not only applicable to LLM but also to VL, which is not + conducive to the current integration logic of LoRA in vLLM. Therefore, it + is necessary to separate them. + """ + # Ensure that the LoRA support check passes when the class is not + # initialized, but set all these attributes to empty. + packed_modules_mapping = {} + supported_lora_modules = [] + embedding_modules = {} + embedding_padding_modules = [] + + def __new__( + cls, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): + if multimodal_config is None: + return QWenLLM(config, multimodal_config, cache_config, + quant_config) + else: + return QWenVL(config, multimodal_config, cache_config, + quant_config) From 64629612d074b3dd3e1685fdc2da69ddaad55e60 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 28 Oct 2024 08:10:45 +0000 Subject: [PATCH 2/3] Complete QWenVL support LoRA Signed-off-by: Jee Jee Li --- vllm/lora/models.py | 25 ++++++- vllm/model_executor/models/minicpmv.py | 11 ++- vllm/model_executor/models/module_mapping.py | 42 +++++++++++ vllm/model_executor/models/qwen.py | 78 +++++++++----------- 4 files changed, 104 insertions(+), 52 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index aaadca9a4d16d..78b6ad2e51238 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -26,7 +26,8 @@ is_regex_target_modules, parse_fine_tuned_lora_name, replace_submodule) from vllm.model_executor.models import SupportsLoRA, supports_multimodal -from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.module_mapping import (ModelComposeMethod, + MultiModelKeys) from vllm.model_executor.models.utils import PPMissingLayer from vllm.utils import is_pin_memory_available @@ -577,11 +578,29 @@ def _filter_unsupported_mm_module(self, module_name: str) -> bool: language model. LoRA for other modules, such as the vision tower, will be filtered out. """ - if self.supports_mm: + module_mapping: MultiModelKeys = self.model.get_mm_mapping() + + def _verify_decoupled_model(): + """ + Suitable for MiniCPMV, InternVL, etc. + """ prefix = module_name.split(".")[0] - module_mapping: MultiModelKeys = self.model.get_mm_mapping() return (prefix in module_mapping.connector or prefix in module_mapping.tower_model) + + def _verify_coupled_model(): + """ + Suitable for QWenVL, GLM4V, etc. 
+ """ + prefix_lst = module_mapping.connector + module_mapping.tower_model + return any( + [module_name.startswith(prefix) for prefix in prefix_lst]) + + if self.supports_mm: + if module_mapping.compose_type == ModelComposeMethod.Decoupled: + return _verify_decoupled_model() + else: + return _verify_coupled_model() return False def _register_packed_modules(self, module_full_name: str) -> None: diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index ca7c2be5a038e..1bedfa07b92cd 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -48,7 +48,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.minicpm import MiniCPMModel -from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.module_mapping import (ModelComposeMethod, + MultiModelKeys) from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.models.utils import LLMWrapper from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -635,9 +636,11 @@ def get_mm_mapping(self) -> MultiModelKeys: """ Get the module prefix in multimodal models """ - return MultiModelKeys.from_string_field(language_model="llm", - connector="resampler", - tower_model="vpm") + return MultiModelKeys.from_string_field( + language_model="llm", + connector="resampler", + tower_model="vpm", + compose_type=ModelComposeMethod.Decoupled) def init_llm( self, diff --git a/vllm/model_executor/models/module_mapping.py b/vllm/model_executor/models/module_mapping.py index a9102a6073a2f..269cdc640c8df 100644 --- a/vllm/model_executor/models/module_mapping.py +++ b/vllm/model_executor/models/module_mapping.py @@ -2,9 +2,46 @@ # https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py from dataclasses import dataclass, field +from enum import IntEnum from typing import List, Union +class ModelComposeMethod(IntEnum): + """ + `ModelComposeMethod` distinguishes between two architectural patterns in + multi-modal models, focusing on how vision model, language model, and + projector are implemented: + 1. Decoupled Architecture (like mllama, InternVL, miniCPMV), complete + independent implementation with its own layers, for example: + ``` + InternVLChatModel + ├── vision_model (visual encoder) + │ ├── embeddings + │ └── encoder + ├── language_model (language model) + │ ├── tok_embeddings + │ └── layers + └── mlp1 (projector) + ``` + 2. 
Coupled Architecture (like QWenVL, GLM4V), Integrated as a sub-module + with shared architectural patterns , for example: + + ``` + QWenVL + └── transformer + ├── wte + ├── h (language model) + ├── ln_f + └── visual (visual encoder) + ├── conv1 + ├── transformer + └── attn_pool (projector) + ``` + """ + Decoupled = 0 + Coupled = 1 + + @dataclass class ModelKeys: model_type: str = None @@ -41,6 +78,8 @@ class ModelKeys: output: str = None + compose_type: str = None + @dataclass class MultiModelKeys(ModelKeys): @@ -55,7 +94,9 @@ def from_string_field(language_model: Union[str, List[str]] = None, connector: Union[str, List[str]] = None, tower_model: Union[str, List[str]] = None, generator: Union[str, List[str]] = None, + compose_type: str = None, **kwargs) -> 'MultiModelKeys': + assert compose_type, "compose_type is not allowed to be None" def to_list(value): if value is None: @@ -66,4 +107,5 @@ def to_list(value): connector=to_list(connector), tower_model=to_list(tower_model), generator=to_list(generator), + compose_type=compose_type, **kwargs) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 319c37937d2df..04fdb27f42141 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -39,7 +40,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.module_mapping import (ModelComposeMethod, + MultiModelKeys) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs @@ -123,8 +125,8 @@ def __init__( # Strided linear layer. 
assert self._qkv_same_embed_dim, \ 'Visual Attention implementation only supports self-attention' - self.in_proj = nn.Linear(embed_dim, 3 * embed_dim) - self.out_proj = nn.Linear(embed_dim, embed_dim) + self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim) + self.out_proj = ReplicatedLinear(embed_dim, embed_dim) self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) def forward( @@ -134,7 +136,7 @@ def forward( ) -> torch.Tensor: # query/key/value: [sq, b, h] sq, b, _ = x.size() - mixed_x_layer = self.in_proj(x) + mixed_x_layer, _ = self.in_proj(x) # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] new_tensor_shape = mixed_x_layer.size()[:-1] + \ @@ -183,7 +185,7 @@ def forward( (self.hidden_size_per_partition,) context_layer = context_layer.view(*new_context_layer_shape) - output = self.out_proj(context_layer) + output, _ = self.out_proj(context_layer) return output @@ -992,29 +994,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): class QWenLLM(QWenBaseModel): packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], + "c_attn": ["c_attn"], "gate_up_proj": [ - "gate_proj", - "up_proj", + "w2", + "w1", ], } # LoRA specific attributes supported_lora_modules = [ - # vision encoder - "fc1", - "fc2", - "out_proj", - # language model - "qkv_proj", # same name with vision encoder - "o_proj", + "c_attn", "gate_up_proj", - "down_proj", - # resampler - "kv_proj", + "c_proj", ] embedding_modules = {} @@ -1023,27 +1013,21 @@ class QWenLLM(QWenBaseModel): class QWenVL(QWenBaseModel): packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], + "c_attn": ["c_attn"], "gate_up_proj": [ - "gate_proj", - "up_proj", + "w2", + "w1", ], } # LoRA specific attributes supported_lora_modules = [ - # vision encoder - "fc1", - "fc2", - "out_proj", - # language model - "qkv_proj", # same name with vision encoder - "o_proj", + "c_attn", "gate_up_proj", - "down_proj", + "c_proj", + # visual module + "out_proj", + "in_proj", + "c_fc", # resampler "kv_proj", ] @@ -1055,9 +1039,11 @@ def get_mm_mapping(self) -> MultiModelKeys: """ Get the module prefix in multimodal models """ - return MultiModelKeys.from_string_field(language_model="llm", - connector="resampler", - tower_model="vpm") + return MultiModelKeys.from_string_field( + language_model="transformer.h", + connector="transformer.visual.attn_pool", + tower_model="transformer.visual.transformer", + compose_type=ModelComposeMethod.Coupled) @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen) @@ -1085,9 +1071,11 @@ def __new__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): - if multimodal_config is None: - return QWenLLM(config, multimodal_config, cache_config, - quant_config) - else: + # Initialize VL + if hasattr(config, "visual"): return QWenVL(config, multimodal_config, cache_config, - quant_config) + quant_config, lora_config) + # Initialize LLM + else: + return QWenLLM(config, multimodal_config, cache_config, + quant_config, lora_config) From 804a361d7fdbe9d30680538fa1456beea200a135 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 29 Oct 2024 01:20:50 +0000 Subject: [PATCH 3/3] Delete redundant code Signed-off-by: Jee Jee Li --- vllm/lora/models.py | 25 ++---------- vllm/model_executor/models/minicpmv.py | 11 ++--- vllm/model_executor/models/module_mapping.py | 42 -------------------- vllm/model_executor/models/qwen.py | 6 +-- 4 files changed, 9 insertions(+), 75 deletions(-) diff --git 
a/vllm/lora/models.py b/vllm/lora/models.py index 78b6ad2e51238..d0279f273db7a 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -26,8 +26,7 @@ is_regex_target_modules, parse_fine_tuned_lora_name, replace_submodule) from vllm.model_executor.models import SupportsLoRA, supports_multimodal -from vllm.model_executor.models.module_mapping import (ModelComposeMethod, - MultiModelKeys) +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.utils import PPMissingLayer from vllm.utils import is_pin_memory_available @@ -578,29 +577,11 @@ def _filter_unsupported_mm_module(self, module_name: str) -> bool: language model. LoRA for other modules, such as the vision tower, will be filtered out. """ - module_mapping: MultiModelKeys = self.model.get_mm_mapping() - - def _verify_decoupled_model(): - """ - Suitable for MiniCPMV, InternVL, etc. - """ - prefix = module_name.split(".")[0] - return (prefix in module_mapping.connector - or prefix in module_mapping.tower_model) - - def _verify_coupled_model(): - """ - Suitable for QWenVL, GLM4V, etc. - """ + if self.supports_mm: + module_mapping: MultiModelKeys = self.model.get_mm_mapping() prefix_lst = module_mapping.connector + module_mapping.tower_model return any( [module_name.startswith(prefix) for prefix in prefix_lst]) - - if self.supports_mm: - if module_mapping.compose_type == ModelComposeMethod.Decoupled: - return _verify_decoupled_model() - else: - return _verify_coupled_model() return False def _register_packed_modules(self, module_full_name: str) -> None: diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 5410711d173f1..2ec51dc4647f5 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -48,8 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.minicpm import MiniCPMModel -from vllm.model_executor.models.module_mapping import (ModelComposeMethod, - MultiModelKeys) +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.models.utils import LLMWrapper from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -636,11 +635,9 @@ def get_mm_mapping(self) -> MultiModelKeys: """ Get the module prefix in multimodal models """ - return MultiModelKeys.from_string_field( - language_model="llm", - connector="resampler", - tower_model="vpm", - compose_type=ModelComposeMethod.Decoupled) + return MultiModelKeys.from_string_field(language_model="llm", + connector="resampler", + tower_model="vpm") def init_llm( self, diff --git a/vllm/model_executor/models/module_mapping.py b/vllm/model_executor/models/module_mapping.py index 269cdc640c8df..a9102a6073a2f 100644 --- a/vllm/model_executor/models/module_mapping.py +++ b/vllm/model_executor/models/module_mapping.py @@ -2,46 +2,9 @@ # https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py from dataclasses import dataclass, field -from enum import IntEnum from typing import List, Union -class ModelComposeMethod(IntEnum): - """ - `ModelComposeMethod` distinguishes between two architectural patterns in - multi-modal models, focusing on how vision model, language model, and - projector are implemented: - 1. 
Decoupled Architecture (like mllama, InternVL, miniCPMV), complete - independent implementation with its own layers, for example: - ``` - InternVLChatModel - ├── vision_model (visual encoder) - │ ├── embeddings - │ └── encoder - ├── language_model (language model) - │ ├── tok_embeddings - │ └── layers - └── mlp1 (projector) - ``` - 2. Coupled Architecture (like QWenVL, GLM4V), Integrated as a sub-module - with shared architectural patterns , for example: - - ``` - QWenVL - └── transformer - ├── wte - ├── h (language model) - ├── ln_f - └── visual (visual encoder) - ├── conv1 - ├── transformer - └── attn_pool (projector) - ``` - """ - Decoupled = 0 - Coupled = 1 - - @dataclass class ModelKeys: model_type: str = None @@ -78,8 +41,6 @@ class ModelKeys: output: str = None - compose_type: str = None - @dataclass class MultiModelKeys(ModelKeys): @@ -94,9 +55,7 @@ def from_string_field(language_model: Union[str, List[str]] = None, connector: Union[str, List[str]] = None, tower_model: Union[str, List[str]] = None, generator: Union[str, List[str]] = None, - compose_type: str = None, **kwargs) -> 'MultiModelKeys': - assert compose_type, "compose_type is not allowed to be None" def to_list(value): if value is None: @@ -107,5 +66,4 @@ def to_list(value): connector=to_list(connector), tower_model=to_list(tower_model), generator=to_list(generator), - compose_type=compose_type, **kwargs) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 04fdb27f42141..0a1b40927e9f9 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -40,8 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.module_mapping import (ModelComposeMethod, - MultiModelKeys) +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs @@ -1042,8 +1041,7 @@ def get_mm_mapping(self) -> MultiModelKeys: return MultiModelKeys.from_string_field( language_model="transformer.h", connector="transformer.visual.attn_pool", - tower_model="transformer.visual.transformer", - compose_type=ModelComposeMethod.Coupled) + tower_model="transformer.visual.transformer") @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
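
Editor's note: the prefix matching that the final patch settles on in `_filter_unsupported_mm_module` (vllm/lora/models.py) can be exercised on its own. The sketch below is illustrative only: it re-implements the `startswith` check against the prefixes returned by `QWenVL.get_mm_mapping()` in this series, and the module names fed to it are hypothetical examples rather than keys from a real checkpoint.

```python
# Illustrative sketch of the prefix matching used by
# _filter_unsupported_mm_module after patch 3 of this series.
# The module names below are hypothetical examples.
from dataclasses import dataclass, field
from typing import List


@dataclass
class MultiModelKeys:
    # Mirrors only the fields the filter reads; the real class lives in
    # vllm/model_executor/models/module_mapping.py.
    language_model: List[str] = field(default_factory=list)
    connector: List[str] = field(default_factory=list)
    tower_model: List[str] = field(default_factory=list)


def filter_unsupported_mm_module(module_name: str,
                                 mapping: MultiModelKeys) -> bool:
    """Return True if the LoRA weights for this module should be dropped."""
    prefix_lst = mapping.connector + mapping.tower_model
    return any(module_name.startswith(prefix) for prefix in prefix_lst)


# Prefixes as returned by QWenVL.get_mm_mapping() in this patch.
qwen_vl_mapping = MultiModelKeys(
    language_model=["transformer.h"],
    connector=["transformer.visual.attn_pool"],
    tower_model=["transformer.visual.transformer"],
)

# Language-model weights keep their LoRA; connector and visual-tower
# weights are filtered out, matching the docstring in lora/models.py.
assert not filter_unsupported_mm_module("transformer.h.0.attn.c_attn",
                                        qwen_vl_mapping)
assert filter_unsupported_mm_module(
    "transformer.visual.transformer.resblocks.0.attn.in_proj",
    qwen_vl_mapping)
```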
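
For context on how the QWenLLM/QWenVL split would be exercised end to end, here is a minimal serving sketch; it is not part of the patch. The adapter path, image file, prompt template, and LoRA rank are placeholders chosen for illustration, and the exact Qwen-VL image-prompt format should be verified against the model's documentation.

```python
# Hypothetical end-to-end use of the LoRA support added in this series.
# "path/to/qwen_vl_lora_adapter", "example.jpg", and the prompt template
# are placeholders, not values taken from the patch.
from PIL import Image

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="Qwen/Qwen-VL-Chat",
    trust_remote_code=True,  # Qwen-VL ships custom config/tokenizer code
    enable_lora=True,
    max_loras=1,
    max_lora_rank=64,
)

image = Image.open("example.jpg")
# Assumed Qwen-VL style prompt with an image placeholder; check the
# model's chat template before relying on this exact string.
prompt = "Picture 1: <img></img>\nWhat is shown in this picture?"

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("qwen_vl_adapter", 1,
                             "path/to/qwen_vl_lora_adapter"),
)
print(outputs[0].outputs[0].text)
```

Because the Qwen-VL config carries a `visual` field, the `__new__` hook added in this series resolves the model class to `QWenVL` rather than `QWenLLM`, so the LoRA attributes and `get_mm_mapping()` prefixes shown above apply to this request.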