@@ -22,6 +22,7 @@
 from PIL import Image
 from torch import nn
 # Temporary solution for transformers below 4.46.0.
+from transformers import PretrainedConfig as Idefics3Config
 from transformers import ProcessorMixin as Idefics3ImageProcessor
 
 from vllm.attention import AttentionMetadata
@@ -31,6 +32,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.module_mapping import MultiModelKeys
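Note on the aliased imports above: the "temporary solution" comment means generic transformers classes stand in for the real Idefics3 ones until transformers 4.46.0 is the minimum supported version. A hedged sketch of what a version guard could look like once that lands (it assumes transformers >= 4.46.0 exports the concrete classes; this is not part of the diff):

try:
    # Assumed to be available in transformers >= 4.46.0.
    from transformers import Idefics3Config, Idefics3ImageProcessor
except ImportError:
    # Fallback aliases used by this diff for older transformers releases.
    from transformers import PretrainedConfig as Idefics3Config
    from transformers import ProcessorMixin as Idefics3ImageProcessor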
@@ -374,12 +376,23 @@ def dummy_data_for_idefics3(
 
 class Idefics3SimpleMLP(nn.Module):
 
-    def __init__(self, config):
+    def __init__(
+        self,
+        config: Idefics3Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
         super().__init__()
         input_size = config.vision_config.hidden_size * (config.scale_factor**
                                                          2)
         output_size = config.text_config.hidden_size
-        self.proj = ReplicatedLinear(input_size, output_size, bias=False)
+        self.proj = ReplicatedLinear(
+            input_size,
+            output_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "proj"),
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out, _ = self.proj(x)
@@ -388,10 +401,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class Idefics3Connector(nn.Module):
 
-    def __init__(self, config):
+    def __init__(
+        self,
+        config: Idefics3Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
         super().__init__()
         self.scale_factor = config.scale_factor
-        self.modality_projection = Idefics3SimpleMLP(config)
+        self.modality_projection = Idefics3SimpleMLP(
+            config,
+            quant_config,
+            prefix=maybe_prefix(prefix, "modality_projection"),
+        )
 
     def pixel_shuffle(self,
                       x: torch.Tensor,
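The connector change above threads quant_config and prefix into an MLP whose input width is hidden_size * scale_factor**2; that factor comes from pixel_shuffle, which folds each scale_factor x scale_factor block of vision patches into a single token. A self-contained sketch of that reshaping (assumed to mirror the HF Idefics3 semantics; pixel_shuffle_sketch is an illustrative helper, not code from this file):

import torch

def pixel_shuffle_sketch(x: torch.Tensor, scale_factor: int = 2) -> torch.Tensor:
    # x: (batch, num_patches, embed_dim), with num_patches forming a square grid.
    bsz, seq, embed_dim = x.shape
    height = width = int(seq**0.5)
    x = x.view(bsz, height, width, embed_dim)
    # Fold scale_factor columns, then scale_factor rows, into the channel dim.
    x = x.reshape(bsz, height, width // scale_factor, embed_dim * scale_factor)
    x = x.permute(0, 2, 1, 3)
    x = x.reshape(bsz, width // scale_factor, height // scale_factor,
                  embed_dim * scale_factor**2)
    x = x.permute(0, 2, 1, 3)
    # Sequence length shrinks by scale_factor**2, embed_dim grows by the same factor.
    return x.reshape(bsz, seq // scale_factor**2, embed_dim * scale_factor**2)

x = torch.randn(1, 64, 32)               # an 8x8 grid of 32-dim patch embeddings
print(pixel_shuffle_sketch(x, 2).shape)  # torch.Size([1, 16, 128])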
@@ -431,9 +453,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         self.padding_idx = self.config.text_config.pad_token_id
         self.vocab_size = self.config.text_config.vocab_size
-        self.vision_model = Idefics3VisionTransformer(config.vision_config,
-                                                      quant_config)
-        self.connector = Idefics3Connector(config)
+        self.vision_model = Idefics3VisionTransformer(
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "vision_model"))
+        self.connector = Idefics3Connector(
+            config,
+            quant_config,
+            prefix=maybe_prefix(prefix, "connector"),
+        )
         self.text_model = LlamaModel(
             vllm_config=vllm_config.with_hf_config(config.text_config),
             prefix=maybe_prefix(prefix, "text_model"),
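maybe_prefix is threaded through the submodules above so that weight names compose into fully qualified names (e.g. vision_model..., connector.modality_projection.proj), which quantization and weight-loading logic match against. A minimal sketch of the assumed helper behaviour, shown only for illustration (join parent and child with a dot, skip the dot for an empty prefix):

def maybe_prefix(prefix: str, name: str) -> str:
    # Assumed behaviour, not copied from this diff.
    return name if not prefix else f"{prefix}.{name}"

assert maybe_prefix("", "vision_model") == "vision_model"
assert maybe_prefix("connector", "modality_projection") == "connector.modality_projection"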
@@ -637,6 +665,32 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
         "gate_up_proj",
         "down_proj",
     ]
+
+    # BitsAndBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+        # vision_model
+        ".fc1.",
+        ".fc2.",
+        ".out_proj.",
+        # connector
+        ".proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     embedding_modules = {}
     embedding_padding_modules = []
 
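With the BitsAndBytes target modules and stacked-params mapping added above, the model can in principle be loaded with in-flight bitsandbytes quantization. A hedged usage sketch (the checkpoint name and prompt are placeholders; it assumes vLLM's LLM entry point accepts quantization="bitsandbytes" together with load_format="bitsandbytes"):

from vllm import LLM, SamplingParams

llm = LLM(
    model="HuggingFaceM4/Idefics3-8B-Llama3",  # placeholder checkpoint name
    quantization="bitsandbytes",
    load_format="bitsandbytes",
)
outputs = llm.generate("Describe Idefics3 in one sentence.",
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)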