 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from PIL import Image
 from transformers import Blip2VisionConfig, BlipVisionConfig
-from transformers.models.blip.modeling_blip import BlipAttention

+from vllm.attention.selector import _Backend
 from vllm.config import ModelConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import DecoderOnlyInputs, token_inputs
@@ -21,11 +22,7 @@
                                    repeat_and_pad_placeholder_tokens)
 from vllm.sequence import SequenceData

-try:
-    from xformers import ops as xops
-    USE_XFORMERS_OPS = True
-except ImportError:
-    USE_XFORMERS_OPS = False
+from .utils import get_vit_attn_backend


 def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
@@ -168,7 +165,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         return embeddings


-class BlipParallelAttention(nn.Module):
+class BlipAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

     def __init__(
@@ -208,6 +205,12 @@ def __init__(
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

+        # Detect attention implementation.
+        self.attn_backend = get_vit_attn_backend(support_fa=False)
+        if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
+            raise RuntimeError(
+                f"BLIP does not support {self.attn_backend} backend now.")
+
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads,
                            self.head_dim).transpose(1, 2).contiguous()
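Note on the guard above: `get_vit_attn_backend(support_fa=False)` returns a `_Backend` enum member, and anything other than TORCH_SDPA or XFORMERS is now rejected at construction time. Below is a minimal sketch of probing that selection outside the model; it assumes this file lives under `vllm/model_executor/models/`, so the relative `.utils` import resolves to `vllm.model_executor.models.utils`, and that these helpers are importable in the vLLM version at hand.

```python
# Sketch only: ask vLLM which ViT attention backend BlipAttention would get
# on this machine. Module paths are inferred from the imports in this diff
# and are not guaranteed to be stable API.
from vllm.attention.selector import _Backend
from vllm.model_executor.models.utils import get_vit_attn_backend

backend = get_vit_attn_backend(support_fa=False)  # FlashAttention excluded, as in BlipAttention
if backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
    # Mirrors the RuntimeError guard added in __init__ above.
    raise RuntimeError(f"BLIP does not support {backend} backend now.")
print(f"BLIP vision attention backend: {backend}")
```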
@@ -231,11 +234,26 @@ def forward(
                                          self.num_heads_per_partition,
                                          self.head_dim)

-        out = xops.memory_efficient_attention_forward(query_states,
-                                                      key_states,
-                                                      value_states,
-                                                      p=self.dropout,
-                                                      scale=self.scale)
+        if self.attn_backend == _Backend.XFORMERS:
+            from xformers import ops as xops
+
+            out = xops.memory_efficient_attention_forward(query_states,
+                                                          key_states,
+                                                          value_states,
+                                                          p=self.dropout,
+                                                          scale=self.scale)
+        elif self.attn_backend == _Backend.TORCH_SDPA:
+            query_states, key_states, value_states = (x.transpose(1, 2)
+                                                      for x in (query_states,
+                                                                key_states,
+                                                                value_states))
+            out = F.scaled_dot_product_attention(query_states,
+                                                 key_states,
+                                                 value_states,
+                                                 dropout_p=self.dropout,
+                                                 scale=self.scale)
+            out = out.transpose(1, 2)
+
         out = out.view(bsz, tgt_len, -1)
         attn_output, _ = self.projection(out)
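Note on the two branches: xFormers' `memory_efficient_attention_forward` consumes tensors laid out as `(batch, seq, heads, head_dim)`, whereas `torch.nn.functional.scaled_dot_product_attention` expects `(batch, heads, seq, head_dim)`; that is why the SDPA path transposes in and back out before the shared `view(bsz, tgt_len, -1)` reshape. A self-contained sketch of just that layout handling, with dummy sizes standing in for the real ones (the `scale=` keyword needs a reasonably recent PyTorch, which the diff itself also relies on):

```python
import torch
import torch.nn.functional as F

# Dummy sizes standing in for (bsz, tgt_len, num_heads_per_partition, head_dim).
bsz, seq_len, num_heads, head_dim = 2, 16, 4, 32
scale = head_dim**-0.5

# Both branches start from (batch, seq, heads, head_dim), the layout the
# xFormers kernel takes directly.
q = torch.randn(bsz, seq_len, num_heads, head_dim)
k = torch.randn(bsz, seq_len, num_heads, head_dim)
v = torch.randn(bsz, seq_len, num_heads, head_dim)

# SDPA wants (batch, heads, seq, head_dim): transpose in...
q_t, k_t, v_t = (x.transpose(1, 2) for x in (q, k, v))
out = F.scaled_dot_product_attention(q_t, k_t, v_t, dropout_p=0.0, scale=scale)

# ...and transpose back so the later view(bsz, tgt_len, -1) sees the same
# layout the xFormers branch would have produced.
out = out.transpose(1, 2)
assert out.shape == (bsz, seq_len, num_heads, head_dim)
```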
@@ -285,18 +303,11 @@ def __init__(
         super().__init__()

         # fallback to sdpa attention if tp unavailable
-        num_heads = config.num_attention_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
-            self.self_attn = BlipParallelAttention(
-                config,
-                quant_config=quant_config,
-                prefix=f"{prefix}.self_attn",
-            )
-        else:
-            # Blip doesn't have SDPA attention implemented in transformers
-            # use eager attention instead for cpu backend
-            self.self_attn = BlipAttention(config)
+        self.self_attn = BlipAttention(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+        )
         self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
         self.mlp = BlipMLP(config,
@@ -374,11 +385,6 @@ def __init__(
         prefix: str = "",
     ) -> None:
         super().__init__()
-
-        tp_size = get_tensor_model_parallel_world_size()
-        num_heads = config.num_attention_heads
-        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
-
         self.config = config

         self.embeddings = BlipVisionEmbeddings(config)
@@ -422,7 +428,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
-        ] if self.shard_weight else []
+        ]
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         layer_count = len(self.encoder.layers)
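Note on the last hunk: with the xFormers-or-nothing condition gone, the q/k/v stacking onto `qkv_proj` is applied unconditionally, matching the always-parallel `BlipAttention`. A toy illustration of what the three mapping entries do to checkpoint-style names; the weight names below are hypothetical and the loop is a simplification of vLLM's usual `load_weights` pattern, not code from this file:

```python
# Hypothetical example: how a (param_name, shard_name, shard_id) mapping
# redirects separate q/k/v checkpoint weights into one fused qkv_proj
# parameter. Weight names are made up for illustration only.
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

checkpoint_names = [
    "encoder.layers.0.self_attn.q_proj.weight",
    "encoder.layers.0.self_attn.k_proj.weight",
    "encoder.layers.0.self_attn.v_proj.weight",
]

for name in checkpoint_names:
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            # In vLLM, the renamed parameter's weight_loader would then be
            # called with this shard_id to place the slice in the fused tensor.
            print(name.replace(shard_name, param_name), "-> shard", shard_id)
            break
```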