@@ -7,12 +7,18 @@
 import pytest
 from transformers import AutoModelForSeq2SeqLM
 
+from vllm.attention.selector import (_Backend,
+                                     global_force_attn_backend_context_manager)
 from vllm.platforms import current_platform
 from vllm.sequence import SampleLogprobs
 
 from ..conftest import DecoderPromptType
 from ..models.utils import check_logprobs_close
 
+LIST_ENC_DEC_SUPPORTED_BACKENDS = [
+    _Backend.XFORMERS, _Backend.FLASH_ATTN, None
+]
+
 
 def vllm_to_hf_output(
     vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
@@ -29,7 +35,8 @@ def vllm_to_hf_output(
 
 
 @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
-@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
@@ -48,6 +55,7 @@ def test_encoder_decoder_e2e(
     num_logprobs: int,
     decoder_prompt_type: DecoderPromptType,
     enforce_eager: bool,
+    attn_backend: _Backend,
 ) -> None:
     '''
     End-to-End (E2E) test for the encoder-decoder framework.
@@ -56,43 +64,49 @@ def test_encoder_decoder_e2e(
     implementations to ensure that both implementations produce consistent
     and correct results.
     '''
-    test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type]
+    with global_force_attn_backend_context_manager(attn_backend):
+        if attn_backend == _Backend.FLASH_ATTN:
+            # Flash Attention works only with bfloat16 data-type
+            dtype = 'bfloat16'
+        test_case_prompts = example_encoder_decoder_prompts[
+            decoder_prompt_type]
 
-    # Configuration settings for HF baseline
-    hf_kwargs = {
-        "top_k": None,
-        "num_beams": 1,
-        "repetition_penalty": 1.0,
-        "top_p": 1.0,
-        "length_penalty": 1.0,
-        "early_stopping": False,
-        "no_repeat_ngram_size": None,
-        "min_length": 0
-    }
+        # Configuration settings for HF baseline
+        hf_kwargs = {
+            "top_k": None,
+            "num_beams": 1,
+            "repetition_penalty": 1.0,
+            "top_p": 1.0,
+            "length_penalty": 1.0,
+            "early_stopping": False,
+            "no_repeat_ngram_size": None,
+            "min_length": 0
+        }
 
-    with hf_runner(model, dtype=dtype,
-                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
-        hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
-            test_case_prompts,
-            max_tokens,
-            num_logprobs,
-            **hf_kwargs,
-        ))
-    with vllm_runner(model, dtype=dtype,
-                     enforce_eager=enforce_eager) as vllm_model:
-        vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
-            test_case_prompts, max_tokens, num_logprobs)
+        with hf_runner(model, dtype=dtype,
+                       auto_cls=AutoModelForSeq2SeqLM) as hf_model:
+            hf_outputs = (
+                hf_model.generate_encoder_decoder_greedy_logprobs_limit(
+                    test_case_prompts,
+                    max_tokens,
+                    num_logprobs,
+                    **hf_kwargs,
+                ))
+        with vllm_runner(model, dtype=dtype,
+                         enforce_eager=enforce_eager) as vllm_model:
+            vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
+                test_case_prompts, max_tokens, num_logprobs)
 
-    hf_skip_tokens = (1
-                      if decoder_prompt_type == DecoderPromptType.NONE else 0)
+        hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
+                          else 0)
 
-    check_logprobs_close(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=[
-            vllm_to_hf_output(vllm_output, decoder_prompt_type)
-            for vllm_output in vllm_outputs
-        ],
-        name_0="hf",
-        name_1="vllm",
-        num_outputs_0_skip_tokens=hf_skip_tokens,
-    )
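For context on the new parametrization: the test now runs once per entry of LIST_ENC_DEC_SUPPORTED_BACKENDS, and global_force_attn_backend_context_manager scopes the override so each run forces one backend (with None, as the test passes it, meaning no override, i.e. let vLLM's selector pick). A minimal sketch of the pattern, separate from the diff and assuming only the two names imported above; the test body here is a hypothetical placeholder:

    from vllm.attention.selector import (
        _Backend, global_force_attn_backend_context_manager)

    # Force the XFORMERS attention backend for anything constructed inside
    # the block; the previous global selection applies again on exit.
    with global_force_attn_backend_context_manager(_Backend.XFORMERS):
        run_encoder_decoder_checks()  # hypothetical test body

This is why the per-test FLASH_ATTN dtype fix-up in the diff sits inside the with block: the backend override and the bfloat16 requirement apply to exactly the same scope.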