diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 8c64c2bfdeb8f..300bab72877b8 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -317,7 +317,7 @@ def forward(
                 # normal attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                flash_attn_varlen_func(
+                out = flash_attn_varlen_func(
                     q=query,
                     k=key,
                     v=value,
@@ -329,13 +329,14 @@ def forward(
                     causal=True,
                     window_size=self.sliding_window,
                     alibi_slopes=self.alibi_slopes,
-                    out=output[:num_prefill_tokens],
                 )
+                assert output[:num_prefill_tokens].shape == out.shape
+                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 assert prefill_meta.seq_lens is not None
                 max_seq_len = max(prefill_meta.seq_lens)
-                flash_attn_varlen_func(
+                output[:num_prefill_tokens] = flash_attn_varlen_func(
                     q=query,
                     k=key_cache,
                     v=value_cache,
@@ -347,12 +348,11 @@ def forward(
                     causal=True,
                     alibi_slopes=self.alibi_slopes,
                     block_table=prefill_meta.block_tables,
-                    out=output[:num_prefill_tokens],
                 )
 
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
-            flash_attn_with_kvcache(
+            output[num_prefill_tokens:] = flash_attn_with_kvcache(
                 decode_query.unsqueeze(1),
                 key_cache,
                 value_cache,
@@ -361,8 +361,7 @@ def forward(
                 softmax_scale=self.scale,
                 causal=True,
                 alibi_slopes=self.alibi_slopes,
-                out=output[num_prefill_tokens:].unsqueeze(1),
-            )
+            ).squeeze(1)
 
         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)
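For context, a minimal standalone sketch of the pattern this diff switches to: take the kernel's return value and copy it into the right slice of the preallocated `output` buffer, instead of handing that slice to the kernel via `out=`. The `attention_stub` function below is a hypothetical stand-in for `flash_attn_varlen_func` / `flash_attn_with_kvcache`, not vLLM or flash-attn code.

```python
import torch


def attention_stub(q: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a flash-attn kernel: returns a freshly
    # allocated tensor with the same shape as the query slice it was given.
    return q * 2.0


num_prefill_tokens, num_decode_tokens, hidden_size = 3, 2, 4
query = torch.randn(num_prefill_tokens + num_decode_tokens, hidden_size)
output = torch.empty_like(query)

# Prefill: assign the returned tensor into the prefill slice of `output`,
# with a shape check mirroring the assert added in the diff.
out = attention_stub(query[:num_prefill_tokens])
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out

# Decode: same idea, writing into the remaining slice of `output`.
output[num_prefill_tokens:] = attention_stub(query[num_prefill_tokens:])
```

The extra copy is the trade-off: the kernel allocates its own result tensor, and the slice assignment copies it into `output`, rather than the kernel writing into the caller's buffer directly.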