Commit 1721bdd

Isotr0py authored and mfournioux committed
[Model] Remove transformers attention porting in VITs (vllm-project#10414)
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com>
1 parent 76d81f2 commit 1721bdd

7 files changed: +139 -102 lines changed

vllm/model_executor/models/blip.py (+36 -30)

@@ -4,10 +4,11 @@

 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from PIL import Image
 from transformers import Blip2VisionConfig, BlipVisionConfig
-from transformers.models.blip.modeling_blip import BlipAttention

+from vllm.attention.selector import _Backend
 from vllm.config import ModelConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import DecoderOnlyInputs, token_inputs
@@ -21,11 +22,7 @@
                                    repeat_and_pad_placeholder_tokens)
 from vllm.sequence import SequenceData

-try:
-    from xformers import ops as xops
-    USE_XFORMERS_OPS = True
-except ImportError:
-    USE_XFORMERS_OPS = False
+from .utils import get_vit_attn_backend


 def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
@@ -168,7 +165,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         return embeddings


-class BlipParallelAttention(nn.Module):
+class BlipAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

     def __init__(
@@ -208,6 +205,12 @@ def __init__(
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

+        # Detect attention implementation.
+        self.attn_backend = get_vit_attn_backend(support_fa=False)
+        if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
+            raise RuntimeError(
+                f"BLIP does not support {self.attn_backend} backend now.")
+
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads,
                            self.head_dim).transpose(1, 2).contiguous()
@@ -231,11 +234,26 @@ def forward(
                                         self.num_heads_per_partition,
                                         self.head_dim)

-        out = xops.memory_efficient_attention_forward(query_states,
-                                                      key_states,
-                                                      value_states,
-                                                      p=self.dropout,
-                                                      scale=self.scale)
+        if self.attn_backend == _Backend.XFORMERS:
+            from xformers import ops as xops
+
+            out = xops.memory_efficient_attention_forward(query_states,
+                                                          key_states,
+                                                          value_states,
+                                                          p=self.dropout,
+                                                          scale=self.scale)
+        elif self.attn_backend == _Backend.TORCH_SDPA:
+            query_states, key_states, value_states = (x.transpose(1, 2)
+                                                      for x in (query_states,
+                                                                key_states,
+                                                                value_states))
+            out = F.scaled_dot_product_attention(query_states,
+                                                 key_states,
+                                                 value_states,
+                                                 dropout_p=self.dropout,
+                                                 scale=self.scale)
+            out = out.transpose(1, 2)
+
         out = out.view(bsz, tgt_len, -1)
         attn_output, _ = self.projection(out)

@@ -285,18 +303,11 @@ def __init__(
         super().__init__()

         # fallback to sdpa attention if tp unavailable
-        num_heads = config.num_attention_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
-            self.self_attn = BlipParallelAttention(
-                config,
-                quant_config=quant_config,
-                prefix=f"{prefix}.self_attn",
-            )
-        else:
-            # Blip doesn't have SDPA attention implemented in transformers
-            # use eager attention instead for cpu backend
-            self.self_attn = BlipAttention(config)
+        self.self_attn = BlipAttention(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+        )
         self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
         self.mlp = BlipMLP(config,
@@ -374,11 +385,6 @@ def __init__(
         prefix: str = "",
     ) -> None:
         super().__init__()
-
-        tp_size = get_tensor_model_parallel_world_size()
-        num_heads = config.num_attention_heads
-        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
-
         self.config = config

         self.embeddings = BlipVisionEmbeddings(config)
@@ -422,7 +428,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
-        ] if self.shard_weight else []
+        ]
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         layer_count = len(self.encoder.layers)

vllm/model_executor/models/clip.py (+36 -29)

@@ -5,10 +5,11 @@
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from PIL import Image
 from transformers import CLIPVisionConfig
-from transformers.models.clip.modeling_clip import CLIPSdpaAttention

+from vllm.attention.selector import _Backend
 from vllm.config import ModelConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import DecoderOnlyInputs, token_inputs
@@ -23,11 +24,7 @@
                                    repeat_and_pad_placeholder_tokens)
 from vllm.sequence import SequenceData

-try:
-    from xformers import ops as xops
-    USE_XFORMERS_OPS = True
-except ImportError:
-    USE_XFORMERS_OPS = False
+from .utils import get_vit_attn_backend


 def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
@@ -197,7 +194,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         return embeddings


-class CLIPParallelAttention(nn.Module):
+class CLIPAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

     def __init__(
@@ -237,6 +234,12 @@ def __init__(
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

+        # Detect attention implementation.
+        self.attn_backend = get_vit_attn_backend(support_fa=False)
+        if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
+            raise RuntimeError(
+                f"CLIP does not support {self.attn_backend} backend now.")
+
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads,
                            self.head_dim).transpose(1, 2).contiguous()
@@ -261,11 +264,26 @@ def forward(
                                         self.num_heads_per_partition,
                                         self.head_dim)

-        out = xops.memory_efficient_attention_forward(query_states,
-                                                      key_states,
-                                                      value_states,
-                                                      p=self.dropout,
-                                                      scale=self.scale)
+        if self.attn_backend == _Backend.XFORMERS:
+            from xformers import ops as xops
+
+            out = xops.memory_efficient_attention_forward(query_states,
+                                                          key_states,
+                                                          value_states,
+                                                          p=self.dropout,
+                                                          scale=self.scale)
+        elif self.attn_backend == _Backend.TORCH_SDPA:
+            query_states, key_states, value_states = (x.transpose(1, 2)
+                                                      for x in (query_states,
+                                                                key_states,
+                                                                value_states))
+            out = F.scaled_dot_product_attention(query_states,
+                                                 key_states,
+                                                 value_states,
+                                                 dropout_p=self.dropout,
+                                                 scale=self.scale)
+            out = out.transpose(1, 2)
+
         out = out.view(bsz, tgt_len, -1)
         attn_output, _ = self.out_proj(out)

@@ -311,17 +329,11 @@ def __init__(
         prefix: str = "",
     ) -> None:
         super().__init__()
-
-        num_heads = config.num_attention_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
-            self.self_attn = CLIPParallelAttention(
-                config,
-                quant_config=quant_config,
-                prefix=f"{prefix}.self_attn",
-            )
-        else:
-            self.self_attn = CLIPSdpaAttention(config)
+        self.self_attn = CLIPAttention(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+        )
         self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
         self.mlp = CLIPMLP(config,
@@ -461,11 +473,6 @@ def __init__(
         prefix: str = "",
     ) -> None:
         super().__init__()
-
-        tp_size = get_tensor_model_parallel_world_size()
-        num_heads = config.num_attention_heads
-        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
-
         self.vision_model = CLIPVisionTransformer(
             config=config,
             quant_config=quant_config,
@@ -490,7 +497,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
-        ] if self.shard_weight else []
+        ]
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         layer_count = len(self.vision_model.encoder.layers)

vllm/model_executor/models/intern_vit.py (+22 -10)

@@ -12,6 +12,7 @@
 import torch.nn.functional as F
 from transformers import PretrainedConfig

+from vllm.attention.selector import _Backend
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
@@ -24,11 +25,7 @@
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-try:
-    from xformers import ops as xops
-    USE_XFORMERS_OPS = True
-except ImportError:
-    USE_XFORMERS_OPS = False
+from .utils import get_vit_attn_backend

 NORM2FN = {
     'rms_norm': RMSNorm,
@@ -186,6 +183,11 @@ def __init__(
             prefix=f"{prefix}.proj",
         )

+        self.attn_backend = get_vit_attn_backend(support_fa=False)
+        if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
+            raise RuntimeError(
+                f"InternViT does not support {self.attn_backend} backend now.")
+
     def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
         if self.tp_size > 1:
             q = tensor_model_parallel_all_gather(q.contiguous())
@@ -211,11 +213,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
         v = v.view(B, N, self.num_heads_per_partition, self.head_dim)

-        x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
-        x = x.view(B, N, -1)
+        if self.attn_backend == _Backend.XFORMERS:
+            from xformers import ops as xops

-        x, _ = self.proj(x)
-        return x
+            out = xops.memory_efficient_attention_forward(q,
+                                                          k,
+                                                          v,
+                                                          scale=self.scale)
+        elif self.attn_backend == _Backend.TORCH_SDPA:
+            q, k, v = (x.transpose(1, 2) for x in (q, k, v))
+            out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
+            out = out.transpose(1, 2)
+
+        out = out.view(B, N, -1)
+        out, _ = self.proj(out)
+        return out


 class InternSdpaAttention(nn.Module):
@@ -362,7 +374,7 @@ def _init_attn(
         tp_size = get_tensor_model_parallel_world_size()
         num_heads = config.num_attention_heads

-        if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0:
+        if (num_heads + num_dummy_heads) % tp_size == 0:
             return InternParallelAttention(config,
                                            quant_config=quant_config,
                                            num_dummy_heads=num_dummy_heads,

vllm/model_executor/models/molmo.py (+1 -1)

@@ -187,7 +187,7 @@ def __init__(
         )

         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend()
+        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
         }:

vllm/model_executor/models/qwen2_vl.py (+1 -1)

@@ -260,7 +260,7 @@ def __init__(
                                       prefix=f"{prefix}.proj")

         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend()
+        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
         }:
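The two one-line changes above differ from the BLIP/CLIP/InternViT call sites only in passing support_fa=True, which keeps FLASH_ATTN in the accepted set. A minimal sketch of that gating logic, using a hypothetical standalone Backend enum in place of vLLM's internal _Backend:

# Hypothetical standalone sketch of the accepted-backend check; vLLM's real
# _Backend enum and get_vit_attn_backend() selector are not reproduced here.
from enum import Enum, auto


class Backend(Enum):
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()
    XFORMERS = auto()


def check_vit_backend(backend: Backend, support_fa: bool) -> None:
    # BLIP/CLIP/InternViT call with support_fa=False; Molmo/Qwen2-VL with True.
    allowed = {Backend.TORCH_SDPA, Backend.XFORMERS}
    if support_fa:
        allowed.add(Backend.FLASH_ATTN)
    if backend not in allowed:
        raise RuntimeError(f"ViT does not support {backend} backend now.")


check_vit_backend(Backend.TORCH_SDPA, support_fa=False)   # fine
check_vit_backend(Backend.FLASH_ATTN, support_fa=True)    # fine
# check_vit_backend(Backend.FLASH_ATTN, support_fa=False)  # would raise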
