3
3
4
4
import pytest
5
5
6
+ from tests .kernels .utils import override_backend_env_variable
6
7
from vllm import LLM , SamplingParams
7
8
8
9
from .conftest import get_text_from_llm_generator
28
29
@pytest .mark .parametrize ("test_llm_kwargs" , [{}])
29
30
@pytest .mark .parametrize ("batch_size" , [5 ])
30
31
@pytest .mark .parametrize ("seed" , [1 ])
32
+ @pytest .mark .parametrize ("backend" , ["FLASH_ATTN" , "FLASHINFER" , "XFORMERS" ])
31
33
def test_sliding_window_retrival (baseline_llm_generator , test_llm_generator ,
32
- batch_size , seed ):
34
+ batch_size , seed , backend , monkeypatch ):
33
35
"""
34
36
The test does a bunch of assignments "x1 = 10\n x2 = 33\n ..." and then
35
37
asks for value of one of them (which is outside the sliding window).
@@ -38,6 +40,8 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
38
40
39
41
Additionally, we compare the results of the v1 and v2 managers.
40
42
"""
43
+ override_backend_env_variable (monkeypatch , backend )
44
+
41
45
sampling_params = SamplingParams (
42
46
max_tokens = 1024 ,
43
47
ignore_eos = True ,
@@ -84,7 +88,9 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
84
88
@pytest .mark .parametrize ("test_llm_kwargs" , [{"enable_chunked_prefill" : True }])
85
89
@pytest .mark .parametrize ("batch_size" , [5 ])
86
90
@pytest .mark .parametrize ("seed" , [1 ])
87
- def test_sliding_window_chunked_prefill (test_llm_generator , batch_size , seed ):
91
+ @pytest .mark .parametrize ("backend" , ["FLASH_ATTN" , "FLASHINFER" , "XFORMERS" ])
92
+ def test_sliding_window_chunked_prefill (test_llm_generator , batch_size , seed ,
93
+ backend , monkeypatch ):
88
94
"""
89
95
This is similar to test_sliding_window_retrival, however, it doesn't
90
96
compare against the v1 block manager since v1 doesn't support
@@ -93,6 +99,8 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed):
93
99
The results with and without chunked prefill are not the same due to
94
100
numerical instabilities.
95
101
"""
102
+ override_backend_env_variable (monkeypatch , backend )
103
+
96
104
sampling_params = SamplingParams (
97
105
max_tokens = 10 ,
98
106
ignore_eos = True ,
0 commit comments