
Commit ac6fa44

[Model] Add BNB quantization support for Idefics3 (vllm-project#10310)

B-201 authored and jeejeelee committed

Signed-off-by: B-201 <Joy25810@foxmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
1 parent e4a392b commit ac6fa44
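
For context, a minimal usage sketch (not part of this commit): with this change, an Idefics3 checkpoint can be served with in-flight BitsAndBytes quantization through vLLM's standard BNB switches. The model name below is an illustrative assumption, not taken from the commit.

# A hedged sketch, assuming vLLM's usual BitsAndBytes entry points
# (quantization="bitsandbytes" together with load_format="bitsandbytes").
from vllm import LLM, SamplingParams

llm = LLM(
    model="HuggingFaceM4/Idefics3-8B-Llama3",  # any Idefics3 checkpoint (assumption)
    quantization="bitsandbytes",               # quantize linear weights with BNB
    load_format="bitsandbytes",                # route loading through the BNB weight loader
)
out = llm.generate(["Describe the image."], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)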

File tree

1 file changed: +61 −7 lines changed

vllm/model_executor/models/idefics3.py (+61 −7)
@@ -22,6 +22,7 @@
 from PIL import Image
 from torch import nn
 # Temporary solution for transformers below 4.46.0.
+from transformers import PretrainedConfig as Idefics3Config
 from transformers import ProcessorMixin as Idefics3ImageProcessor
 
 from vllm.attention import AttentionMetadata
@@ -31,6 +32,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -374,12 +376,23 @@ def dummy_data_for_idefics3(
 
 class Idefics3SimpleMLP(nn.Module):
 
-    def __init__(self, config):
+    def __init__(
+        self,
+        config: Idefics3Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
         super().__init__()
         input_size = config.vision_config.hidden_size * (config.scale_factor**
                                                          2)
         output_size = config.text_config.hidden_size
-        self.proj = ReplicatedLinear(input_size, output_size, bias=False)
+        self.proj = ReplicatedLinear(
+            input_size,
+            output_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "proj"),
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out, _ = self.proj(x)
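
Editorial note on the pattern in this hunk and the ones that follow: every submodule now accepts and forwards quant_config plus a dot-joined prefix, so the quantized-weight loader can match parameters by their fully qualified checkpoint names. A minimal, framework-free sketch of that threading (the classes below are illustrative stand-ins, not vLLM's):

from typing import Optional

def maybe_prefix(prefix: str, name: str) -> str:
    # Mirrors the intent of vLLM's maybe_prefix helper: join a parent prefix
    # and a child module name with a dot, skipping an empty prefix.
    return f"{prefix}.{name}" if prefix else name

class FakeLinear:
    # Stand-in for ReplicatedLinear: records the quant_config and the fully
    # qualified prefix, which is what a quantized-weight loader keys on.
    def __init__(self, quant_config: Optional[str] = None, prefix: str = ""):
        self.quant_config = quant_config
        self.prefix = prefix

class FakeSimpleMLP:
    def __init__(self, quant_config: Optional[str] = None, prefix: str = ""):
        self.proj = FakeLinear(quant_config, maybe_prefix(prefix, "proj"))

class FakeConnector:
    def __init__(self, quant_config: Optional[str] = None, prefix: str = ""):
        self.modality_projection = FakeSimpleMLP(
            quant_config, maybe_prefix(prefix, "modality_projection"))

connector = FakeConnector(quant_config="bnb-4bit", prefix="model.connector")
print(connector.modality_projection.proj.prefix)
# -> model.connector.modality_projection.proj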
@@ -388,10 +401,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class Idefics3Connector(nn.Module):
 
-    def __init__(self, config):
+    def __init__(
+        self,
+        config: Idefics3Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
         super().__init__()
         self.scale_factor = config.scale_factor
-        self.modality_projection = Idefics3SimpleMLP(config)
+        self.modality_projection = Idefics3SimpleMLP(
+            config,
+            quant_config,
+            prefix=maybe_prefix(prefix, "modality_projection"),
+        )
 
     def pixel_shuffle(self,
                       x: torch.Tensor,
@@ -431,9 +453,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         self.padding_idx = self.config.text_config.pad_token_id
         self.vocab_size = self.config.text_config.vocab_size
-        self.vision_model = Idefics3VisionTransformer(config.vision_config,
-                                                      quant_config)
-        self.connector = Idefics3Connector(config)
+        self.vision_model = Idefics3VisionTransformer(
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "vision_model"))
+        self.connector = Idefics3Connector(
+            config,
+            quant_config,
+            prefix=maybe_prefix(prefix, "connector"),
+        )
         self.text_model = LlamaModel(
             vllm_config=vllm_config.with_hf_config(config.text_config),
             prefix=maybe_prefix(prefix, "text_model"),
@@ -637,6 +665,32 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
         "gate_up_proj",
         "down_proj",
     ]
+
+    # BitsAndBytes-specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+        # vision_model
+        ".fc1.",
+        ".fc2.",
+        ".out_proj.",
+        # connector
+        ".proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     embedding_modules = {}
     embedding_padding_modules = []

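Editorial note on bitsandbytes_stacked_params_mapping: vLLM fuses the attention and MLP projections (q/k/v into qkv_proj, gate/up into gate_up_proj), while HF checkpoints store them separately, so the BNB loader needs to know which fused weight and shard index each checkpoint tensor fills. A hedged sketch of how such a mapping is typically consumed; the remap helper below is illustrative, not vLLM's actual loader:

from typing import Optional, Tuple

# Mapping copied from the diff: shard_name -> (fused_weight_name, shard_index).
stacked_params_mapping = {
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

def remap(checkpoint_name: str) -> Tuple[str, Optional[int]]:
    # Rewrite a per-projection checkpoint name to its fused counterpart and
    # report which shard slot of the fused weight it should fill.
    for shard, (fused, idx) in stacked_params_mapping.items():
        if f".{shard}." in checkpoint_name:
            return checkpoint_name.replace(shard, fused), idx
    return checkpoint_name, None  # unfused weights (e.g. ".proj.") load as-is

print(remap("model.layers.0.self_attn.q_proj.weight"))
# -> ('model.layers.0.self_attn.qkv_proj.weight', 0)
print(remap("connector.modality_projection.proj.weight"))
# -> ('connector.modality_projection.proj.weight', None)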