
Commit 955f789

pavanimajetyanko-intel authored and committed

[Core] Add Sliding Window Support with Flashinfer (vllm-project#10462)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>

1 parent 7590527 · commit 955f789

File tree

2 files changed: +18 -7 lines changed

tests/core/block/e2e/test_correctness_sliding_window.py (+10 -2)

@@ -3,6 +3,7 @@
 
 import pytest
 
+from tests.kernels.utils import override_backend_env_variable
 from vllm import LLM, SamplingParams
 
 from .conftest import get_text_from_llm_generator
@@ -28,8 +29,9 @@
 @pytest.mark.parametrize("test_llm_kwargs", [{}])
 @pytest.mark.parametrize("batch_size", [5])
 @pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
 def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
-                                 batch_size, seed):
+                                 batch_size, seed, backend, monkeypatch):
     """
     The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
     asks for value of one of them (which is outside the sliding window).
@@ -38,6 +40,8 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
 
     Additionally, we compare the results of the v1 and v2 managers.
     """
+    override_backend_env_variable(monkeypatch, backend)
+
     sampling_params = SamplingParams(
         max_tokens=1024,
         ignore_eos=True,
@@ -84,7 +88,9 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
 @pytest.mark.parametrize("test_llm_kwargs", [{"enable_chunked_prefill": True}])
 @pytest.mark.parametrize("batch_size", [5])
 @pytest.mark.parametrize("seed", [1])
-def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed):
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
+def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
+                                        backend, monkeypatch):
     """
     This is similar to test_sliding_window_retrival, however, it doesn't
     compare against the v1 block manager since v1 doesn't support
@@ -93,6 +99,8 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed):
     The results with and without chunked prefill are not the same due to
     numerical instabilities.
     """
+    override_backend_env_variable(monkeypatch, backend)
+
     sampling_params = SamplingParams(
         max_tokens=10,
         ignore_eos=True,
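
The new "backend" parametrization re-runs both tests under each attention backend. For context, a minimal sketch of what the override_backend_env_variable helper from tests/kernels/utils.py does, assuming it simply routes through pytest's monkeypatch to set vLLM's backend-selection environment variable (the helper name and VLLM_ATTENTION_BACKEND are taken from vLLM; the body below is a sketch, not the repository's exact code):

import pytest

def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
                                  backend_name: str) -> None:
    # monkeypatch scopes the change to the current test, so each
    # parametrized backend ("FLASH_ATTN", "FLASHINFER", "XFORMERS")
    # runs in isolation and the variable is restored afterwards.
    mpatch.setenv("VLLM_ATTENTION_BACKEND", backend_name)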

vllm/attention/backends/flashinfer.py (+8 -5)
@@ -757,9 +757,8 @@ def __init__(
         if alibi_slopes is not None:
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
-        if sliding_window is not None:
-            raise ValueError("Sliding window is not supported in FlashInfer.")
-        self.sliding_window = (-1, -1)
+        self.sliding_window = ((sliding_window - 1,
+                                0) if sliding_window is not None else (-1, -1))
         self.kv_cache_dtype = kv_cache_dtype
         self.logits_soft_cap = logits_soft_cap
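The constructor change above replaces the hard error with a (left, right) window tuple: a sliding window of size W lets each query attend to itself plus the W - 1 tokens before it, hence (W - 1, 0), while (-1, -1) means unbounded. A small illustrative sketch of that convention (not part of the commit):

from typing import Tuple

def visible_positions(query_pos: int, seq_len: int,
                      window: Tuple[int, int]) -> range:
    # left/right bound how far before/after the query a key may sit;
    # -1 on either side means "no limit in that direction".
    left, right = window
    lo = 0 if left < 0 else max(0, query_pos - left)
    hi = seq_len - 1 if right < 0 else min(seq_len - 1, query_pos + right)
    return range(lo, hi + 1)

# sliding_window = 4 maps to (3, 0): position 10 sees positions 7..10.
assert list(visible_positions(10, 16, (3, 0))) == [7, 8, 9, 10]
# (-1, 0) is plain causal attention: everything up to the query is visible.
assert list(visible_positions(3, 16, (-1, 0))) == [0, 1, 2, 3]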

@@ -865,6 +864,8 @@ def unified_flash_infer(
     assert query.shape[0] == num_prefill_tokens
     assert decode_query.shape[0] == num_decode_tokens
 
+    window_left = window_size[0] if window_size is not None else -1
+
     prefill_output: Optional[torch.Tensor] = None
     decode_output: Optional[torch.Tensor] = None
     if prefill_meta := attn_metadata.prefill_metadata:
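
FlashInfer's wrapper calls take a single window_left argument rather than a tuple, so only the left half of the window tuple is forwarded; under causal attention the right half is always 0, and -1 again disables the window. A toy dense-mask equivalent of that semantics (illustrative only; the real FlashInfer kernels never materialize such a mask):

import torch

def causal_window_mask(seq_len: int, window_left: int) -> torch.Tensor:
    # True where a query row may attend to a key column.
    q = torch.arange(seq_len).unsqueeze(1)  # query positions (rows)
    k = torch.arange(seq_len).unsqueeze(0)  # key positions (columns)
    mask = k <= q                           # causal: no attending ahead
    if window_left >= 0:
        mask &= k >= q - window_left        # at most window_left tokens back
    return mask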
@@ -895,7 +896,8 @@
                 logits_soft_cap=logits_soft_cap,
                 causal=True,
                 k_scale=k_scale,
-                v_scale=v_scale)
+                v_scale=v_scale,
+                window_left=window_left)
     if decode_meta := attn_metadata.decode_metadata:
         assert attn_metadata.decode_metadata is not None
         assert attn_metadata.decode_metadata.decode_wrapper is not None
@@ -905,7 +907,8 @@
             sm_scale=softmax_scale,
             logits_soft_cap=logits_soft_cap,
             k_scale=k_scale,
-            v_scale=v_scale)
+            v_scale=v_scale,
+            window_left=window_left)
 
     if prefill_output is None and decode_output is not None:
         # Decode only batch.
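
With these two changes, a model whose config declares a sliding window can run on the FlashInfer backend instead of raising the old ValueError. An end-to-end sketch mirroring the tests above (the model name and prompt are illustrative assumptions; the environment variable must be set before the engine selects its backend):

import os

# Select the backend the same way the tests do.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

# Any sliding-window model works here; Mistral-7B is a common example.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.1")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=10, ignore_eos=True))
print(outputs[0].outputs[0].text)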
