@@ -20,21 +20,21 @@ def test_env(name: str, device: str, monkeypatch):
 
     if device == "cpu":
         with patch("vllm.attention.selector.is_cpu", return_value=True):
-            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
-                                        16, False)
+            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
+                                        False)
         assert backend.name == "TORCH_SDPA"
     elif device == "hip":
         with patch("vllm.attention.selector.is_hip", return_value=True):
-            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
-                                        16, False)
+            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
+                                        False)
         assert backend.name == "ROCM_FLASH"
     elif device == "openvino":
         with patch("vllm.attention.selector.is_openvino", return_value=True):
-            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
-                                        16, False)
+            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
+                                        False)
         assert backend.name == "OPENVINO"
     else:
-        backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
+        backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                     False)
         assert backend.name == name
 
@@ -46,42 +46,37 @@ def test_flash_attn(monkeypatch):
 
     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
-        backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
+        backend = which_attn_to_use(16, torch.float16, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported data type
-    backend = which_attn_to_use(16, None, torch.float8_e4m3fn, None, 16, False)
+    backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported kv cache data type
-    backend = which_attn_to_use(16, None, torch.float16, "fp8", 16, False)
+    backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported block size
-    backend = which_attn_to_use(16, None, torch.float16, None, 8, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
-
-    # Unsupported sliding window
-    backend = which_attn_to_use(16, 1, torch.float16, None, 16, False)
+    backend = which_attn_to_use(16, torch.float16, None, 8, False)
     assert backend.name != STR_FLASH_ATTN_VAL
 
     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-        backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
+        backend = which_attn_to_use(16, torch.float16, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported head size
-    backend = which_attn_to_use(17, None, torch.float16, None, 16, False)
+    backend = which_attn_to_use(17, torch.float16, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL
 
     # Attention-free models should bypass env and use PlaceholderAttention
-    backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
-                                True)
+    backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
     assert backend.name != STR_FLASH_ATTN_VAL
 
 
 def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
-        which_attn_to_use(16, None, torch.float16, None, 16, False)
+        which_attn_to_use(16, torch.float16, None, 16, False)
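For reference, a minimal sketch of the post-change call shape these tests exercise. It assumes which_attn_to_use is importable from vllm.attention.selector (the module patched in the tests); the argument meanings are inferred from the surrounding test comments rather than a documented signature (head size, dtype, kv-cache dtype, block size, attention-free flag), and the dropped second argument was the sliding-window size.

from unittest.mock import patch

import torch

from vllm.attention.selector import which_attn_to_use  # assumed import path

# Post-change call: five positional arguments; the former second argument
# (the sliding window) is no longer passed.
with patch("vllm.attention.selector.is_cpu", return_value=True):
    backend = which_attn_to_use(16, torch.float16, torch.float16, 16, False)

# On the forced-CPU path the tests expect the Torch SDPA backend.
assert backend.name == "TORCH_SDPA"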