@@ -317,7 +317,7 @@ def forward(
                 # normal attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                flash_attn_varlen_func(
+                out = flash_attn_varlen_func(
                     q=query,
                     k=key,
                     v=value,
@@ -329,13 +329,14 @@ def forward(
                     causal=True,
                     window_size=self.sliding_window,
                     alibi_slopes=self.alibi_slopes,
-                    out=output[:num_prefill_tokens],
                 )
+                assert output[:num_prefill_tokens].shape == out.shape
+                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 assert prefill_meta.seq_lens is not None
                 max_seq_len = max(prefill_meta.seq_lens)
-                flash_attn_varlen_func(
+                output[:num_prefill_tokens] = flash_attn_varlen_func(
                     q=query,
                     k=key_cache,
                     v=value_cache,
@@ -347,12 +348,11 @@ def forward(
                     causal=True,
                     alibi_slopes=self.alibi_slopes,
                     block_table=prefill_meta.block_tables,
-                    out=output[:num_prefill_tokens],
                 )
 
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
-            flash_attn_with_kvcache(
+            output[num_prefill_tokens:] = flash_attn_with_kvcache(
                 decode_query.unsqueeze(1),
                 key_cache,
                 value_cache,
@@ -361,8 +361,7 @@ def forward(
                 softmax_scale=self.scale,
                 causal=True,
                 alibi_slopes=self.alibi_slopes,
-                out=output[num_prefill_tokens:].unsqueeze(1),
-            )
+            ).squeeze(1)
 
         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)
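For context, a minimal sketch of the pattern this diff moves to, using a hypothetical fake_attention stand-in instead of the real flash-attn kernels (the actual code calls flash_attn_varlen_func and flash_attn_with_kvcache): the attention result is captured as a return value and then copied into the pre-allocated output slice, rather than being written through an out= argument.

import torch


# Hypothetical stand-in for an attention kernel that returns a fresh tensor
# of the same shape as its query input.
def fake_attention(q: torch.Tensor) -> torch.Tensor:
    return torch.softmax(q, dim=-1)


num_tokens, num_heads, head_size = 8, 4, 16
num_prefill_tokens = 5

query = torch.randn(num_tokens, num_heads, head_size)
output = torch.empty(num_tokens, num_heads, head_size)

# Pattern after the change: capture the returned tensor, check its shape
# against the destination slice, then copy it into the pre-allocated output.
out = fake_attention(query[:num_prefill_tokens])
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out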