1
1
"""Minimal implementation of CLIPVisionModel intended to be only used
2
2
within a vision language model."""
3
- from typing import Iterable , List , Optional , Tuple , Union
3
+ from typing import Iterable , List , Optional , Set , Tuple , Union
4
4
5
5
import numpy as np
6
6
import torch
@@ -483,14 +483,16 @@ def device(self):
483
483
484
484
# (TODO) Add prefix argument for filtering out weights to be loaded
485
485
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
486
- def load_weights (self , weights : Iterable [Tuple [str , torch .Tensor ]]):
486
+ def load_weights (self , weights : Iterable [Tuple [str ,
487
+ torch .Tensor ]]) -> Set [str ]:
487
488
stacked_params_mapping = [
488
489
# (param_name, shard_name, shard_id)
489
490
("qkv_proj" , "q_proj" , "q" ),
490
491
("qkv_proj" , "k_proj" , "k" ),
491
492
("qkv_proj" , "v_proj" , "v" ),
492
493
] if self .shard_weight else []
493
494
params_dict = dict (self .named_parameters ())
495
+ loaded_params : Set [str ] = set ()
494
496
layer_count = len (self .vision_model .encoder .layers )
495
497
496
498
for name , loaded_weight in weights :
@@ -508,8 +510,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
508
510
for (param_name , weight_name , shard_id ) in stacked_params_mapping :
509
511
if weight_name not in name :
510
512
continue
513
+ name = name .replace (weight_name , param_name )
511
514
512
- param = params_dict [name . replace ( weight_name , param_name ) ]
515
+ param = params_dict [name ]
513
516
weight_loader = param .weight_loader
514
517
weight_loader (param , loaded_weight , shard_id )
515
518
break
@@ -518,3 +521,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
518
521
weight_loader = getattr (param , "weight_loader" ,
519
522
default_weight_loader )
520
523
weight_loader (param , loaded_weight )
524
+ loaded_params .add (name )
525
+ return loaded_params
0 commit comments