Skip to content

Commit c07b271

Browse files
Yard1 and joerunde
authored and committed
Revert "[Core] Remove unnecessary copies in flash attn backend" (vllm-project#5478)
1 parent a0a9a1e commit c07b271

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

vllm/attention/backends/flash_attn.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def forward(
317317
# normal attention
318318
# When block_tables are not filled, it means q and k are the
319319
# prompt, and they have the same length.
320-
flash_attn_varlen_func(
320+
out = flash_attn_varlen_func(
321321
q=query,
322322
k=key,
323323
v=value,
@@ -329,13 +329,14 @@ def forward(
329329
causal=True,
330330
window_size=self.sliding_window,
331331
alibi_slopes=self.alibi_slopes,
332-
out=output[:num_prefill_tokens],
333332
)
333+
assert output[:num_prefill_tokens].shape == out.shape
334+
output[:num_prefill_tokens] = out
334335
else:
335336
# prefix-enabled attention
336337
assert prefill_meta.seq_lens is not None
337338
max_seq_len = max(prefill_meta.seq_lens)
338-
flash_attn_varlen_func(
339+
output[:num_prefill_tokens] = flash_attn_varlen_func(
339340
q=query,
340341
k=key_cache,
341342
v=value_cache,
@@ -347,12 +348,11 @@ def forward(
347348
causal=True,
348349
alibi_slopes=self.alibi_slopes,
349350
block_table=prefill_meta.block_tables,
350-
out=output[:num_prefill_tokens],
351351
)
352352

353353
if decode_meta := attn_metadata.decode_metadata:
354354
# Decoding run.
355-
flash_attn_with_kvcache(
355+
output[num_prefill_tokens:] = flash_attn_with_kvcache(
356356
decode_query.unsqueeze(1),
357357
key_cache,
358358
value_cache,
@@ -361,8 +361,7 @@ def forward(
361361
softmax_scale=self.scale,
362362
causal=True,
363363
alibi_slopes=self.alibi_slopes,
364-
out=output[num_prefill_tokens:].unsqueeze(1),
365-
)
364+
).squeeze(1)
366365

367366
# Reshape the output tensor.
368367
return output.view(num_tokens, hidden_size)

0 commit comments

Comments (0)