Commit 27be13e

sroy745 authored and sumitd2 committed
[Encoder Decoder] Add flash_attn kernel support for encoder-decoder models (vllm-project#9559)
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
1 parent 1c39c76 commit 27be13e
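
A quick way to exercise the new kernel path from user code is to force the attention backend before building the engine. The sketch below is illustrative only and not part of this commit; it assumes the VLLM_ATTENTION_BACKEND environment variable and the vllm.LLM entry point, and that flash_attn requires a half-precision dtype:

import os

# Force the flash_attn backend before the engine is constructed;
# XFORMERS can be selected the same way.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

from vllm import LLM, SamplingParams

# Encoder-decoder model; flash_attn only supports fp16/bf16 activations,
# hence the explicit bfloat16 dtype.
llm = LLM(model="facebook/bart-large-cnn", dtype="bfloat16")
outputs = llm.generate(
    ["vLLM now runs encoder-decoder models with the flash_attn backend."],
    SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)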

File tree

11 files changed: +716 −317 lines changed


tests/encoder_decoder/test_e2e_correctness.py (+51 −37)
@@ -7,12 +7,18 @@
 import pytest
 from transformers import AutoModelForSeq2SeqLM
 
+from vllm.attention.selector import (_Backend,
+                                     global_force_attn_backend_context_manager)
 from vllm.platforms import current_platform
 from vllm.sequence import SampleLogprobs
 
 from ..conftest import DecoderPromptType
 from ..models.utils import check_logprobs_close
 
+LIST_ENC_DEC_SUPPORTED_BACKENDS = [
+    _Backend.XFORMERS, _Backend.FLASH_ATTN, None
+]
+
 
 def vllm_to_hf_output(
     vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
@@ -29,7 +35,8 @@ def vllm_to_hf_output(
 
 
 @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
-@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
@@ -48,6 +55,7 @@ def test_encoder_decoder_e2e(
     num_logprobs: int,
     decoder_prompt_type: DecoderPromptType,
     enforce_eager: bool,
+    attn_backend: _Backend,
 ) -> None:
     '''
     End-to-End (E2E) test for the encoder-decoder framework.
@@ -56,43 +64,49 @@ def test_encoder_decoder_e2e(
     implementations to ensure that both implementations produce consistent
     and correct results.
     '''
-    test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type]
+    with global_force_attn_backend_context_manager(attn_backend):
+        if attn_backend == _Backend.FLASH_ATTN:
+            # Flash Attention works only with bfloat16 data-type
+            dtype = 'bfloat16'
+        test_case_prompts = example_encoder_decoder_prompts[
+            decoder_prompt_type]
 
-    # Configuration settings for HF baseline
-    hf_kwargs = {
-        "top_k": None,
-        "num_beams": 1,
-        "repetition_penalty": 1.0,
-        "top_p": 1.0,
-        "length_penalty": 1.0,
-        "early_stopping": False,
-        "no_repeat_ngram_size": None,
-        "min_length": 0
-    }
+        # Configuration settings for HF baseline
+        hf_kwargs = {
+            "top_k": None,
+            "num_beams": 1,
+            "repetition_penalty": 1.0,
+            "top_p": 1.0,
+            "length_penalty": 1.0,
+            "early_stopping": False,
+            "no_repeat_ngram_size": None,
+            "min_length": 0
+        }
 
-    with hf_runner(model, dtype=dtype,
-                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
-        hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
-            test_case_prompts,
-            max_tokens,
-            num_logprobs,
-            **hf_kwargs,
-        ))
-    with vllm_runner(model, dtype=dtype,
-                     enforce_eager=enforce_eager) as vllm_model:
-        vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
-            test_case_prompts, max_tokens, num_logprobs)
+        with hf_runner(model, dtype=dtype,
+                       auto_cls=AutoModelForSeq2SeqLM) as hf_model:
+            hf_outputs = (
+                hf_model.generate_encoder_decoder_greedy_logprobs_limit(
+                    test_case_prompts,
+                    max_tokens,
+                    num_logprobs,
+                    **hf_kwargs,
+                ))
+        with vllm_runner(model, dtype=dtype,
+                         enforce_eager=enforce_eager) as vllm_model:
+            vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
+                test_case_prompts, max_tokens, num_logprobs)
 
-    hf_skip_tokens = (1
-                      if decoder_prompt_type == DecoderPromptType.NONE else 0)
+        hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
+                          else 0)
 
-    check_logprobs_close(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=[
-            vllm_to_hf_output(vllm_output, decoder_prompt_type)
-            for vllm_output in vllm_outputs
-        ],
-        name_0="hf",
-        name_1="vllm",
-        num_outputs_0_skip_tokens=hf_skip_tokens,
-    )
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=[
+                vllm_to_hf_output(vllm_output, decoder_prompt_type)
+                for vllm_output in vllm_outputs
+            ],
+            name_0="hf",
+            name_1="vllm",
+            num_outputs_0_skip_tokens=hf_skip_tokens,
+        )
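
The test refactor above hinges on global_force_attn_backend_context_manager, which temporarily overrides the globally forced attention backend and restores the previous setting on exit; the whole test body moves inside the with block so the backend is already forced when vllm_runner constructs the engine. A minimal sketch of the pattern in isolation (illustrative only, reusing the imports added in this diff):

from vllm.attention.selector import (_Backend,
                                     global_force_attn_backend_context_manager)

# Everything constructed inside this block sees FLASH_ATTN as the forced
# backend; on exit the previously forced backend (possibly None) is restored.
with global_force_attn_backend_context_manager(_Backend.FLASH_ATTN):
    pass  # build the vLLM engine and run generation here

With the new attn_backend parametrization, the flash_attn variants of this test can be selected with a pytest -k filter matching "FLASH_ATTN"; the exact test IDs depend on how pytest renders the _Backend enum values.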
