
Commit 945ad61

WoosukKwon authored and sumitd2 committed
[V1] Support sliding window attention (vllm-project#9679)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
1 parent 97615d2 commit 945ad61

File tree: 1 file changed (+4, -8 lines)

vllm/v1/attention/backends/flash_attn.py (+4, -8)
@@ -82,8 +82,10 @@ def __init__(
         if alibi_slopes is not None:
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
-        self.sliding_window = ((sliding_window, sliding_window)
-                               if sliding_window is not None else (-1, -1))
+        if sliding_window is None:
+            self.sliding_window = (-1, -1)
+        else:
+            self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
@@ -93,12 +95,6 @@ def __init__(
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        if sliding_window is not None:
-            # NOTE(woosuk): flash-attn's sliding window does not work with
-            # paged KV cache.
-            raise ValueError(
-                "Sliding window is not supported in FlashAttention.")
-
         support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
         if head_size not in support_head_sizes:
             raise ValueError(
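
Context on the changed tuple: the diff removes the old ValueError and stores the window as (sliding_window - 1, 0) instead of (sliding_window, sliding_window), or (-1, -1) when no window is configured. The sketch below is not vLLM or flash-attn code; it is a minimal illustration, under the assumption that a (left, right) window tuple under a causal mask lets query position q attend to key positions in [q - left, q], with -1 meaning unbounded. The helper name sliding_window_causal_mask is made up for this example.

# Illustrative only: shows why (sliding_window - 1, 0) gives each token a
# window of exactly `sliding_window` positions (itself plus the previous
# sliding_window - 1 tokens), while (-1, -1) leaves the window unbounded.
import torch

def sliding_window_causal_mask(seq_len: int, window: tuple) -> torch.Tensor:
    left, right = window  # right is 0 here because the mask is causal
    q = torch.arange(seq_len).unsqueeze(1)   # query positions (column)
    k = torch.arange(seq_len).unsqueeze(0)   # key positions (row)
    causal = k <= q                          # never attend to future tokens
    if left < 0:
        return causal                        # -1: no left bound, plain causal
    return causal & (k >= q - left)          # keep only the last `left` + 1 keys

# Example: sliding_window = 4 -> window tuple (3, 0); every row of the mask
# has at most four True entries.
mask = sliding_window_causal_mask(seq_len=6, window=(3, 0))
print(mask.int())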
