Commit 7f616e5

LunrEclipse authored and youkaichao committed
[Frontend] API support for beam search (vllm-project#9087)
Co-authored-by: youkaichao <youkaichao@126.com>
1 parent e8097e0 commit 7f616e5

12 files changed: +275, -68 lines

benchmarks/benchmark_throughput.py (+8, -4)
@@ -15,6 +15,7 @@
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args)
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.sampling_params import BeamSearchParams
 from vllm.utils import FlexibleArgumentParser, merge_async_iterators
 
 
@@ -145,10 +146,13 @@ def run_vllm(
         for prompt, input_len, _output_len in requests:
             assert _output_len == output_len
         start = time.perf_counter()
-        llm.beam_search(prompts,
-                        beam_width=n,
-                        max_tokens=output_len,
-                        ignore_eos=True)
+        llm.beam_search(
+            prompts,
+            BeamSearchParams(
+                beam_width=n,
+                max_tokens=output_len,
+                ignore_eos=True,
+            ))
         end = time.perf_counter()
     return end - start

tests/conftest.py (+4, -1)
@@ -35,6 +35,7 @@
                          to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
+from vllm.sampling_params import BeamSearchParams
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
                         identity, is_cpu)
 
@@ -812,7 +813,9 @@ def generate_beam_search_new(
         beam_width: int,
         max_tokens: int,
     ) -> List[Tuple[List[List[int]], List[str]]]:
-        outputs = self.model.beam_search(prompts, beam_width, max_tokens)
+        outputs = self.model.beam_search(
+            prompts,
+            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
         returned_outputs = []
         for output in outputs:
             token_ids = [x.tokens for x in output.sequences]

tests/entrypoints/openai/test_completion.py (+24, -19)
@@ -495,25 +495,30 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
     assert len(batch.choices) == 2
     assert batch.choices[0].text == batch.choices[1].text
 
-    # test n = 2
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=prompts,
-        n=2,
-        max_tokens=5,
-        temperature=0.0,
-        extra_body=dict(
-            # NOTE: this has to be true for n > 1 in vLLM, but not necessary
-            # for official client.
-            use_beam_search=True),
-    )
-    assert len(batch.choices) == 4
-    assert batch.choices[0].text != batch.choices[
-        1].text, "beam search should be different"
-    assert batch.choices[0].text == batch.choices[
-        2].text, "two copies of the same prompt should be the same"
-    assert batch.choices[1].text == batch.choices[
-        3].text, "two copies of the same prompt should be the same"
+    try:
+        # test n = 2
+        batch = await client.completions.create(
+            model=model_name,
+            prompt=prompts,
+            n=2,
+            max_tokens=5,
+            temperature=0.0,
+            extra_body=dict(
+                # NOTE: this has to be true for n > 1 in vLLM, but
+                # not necessary for official client.
+                use_beam_search=True),
+        )
+        assert len(batch.choices) == 4
+        assert batch.choices[0].text != batch.choices[
+            1].text, "beam search should be different"
+        assert batch.choices[0].text == batch.choices[
+            2].text, "two copies of the same prompt should be the same"
+        assert batch.choices[1].text == batch.choices[
+            3].text, "two copies of the same prompt should be the same"
+    except BadRequestError as e:
+        # the only allowed exception is when beam search is not supported
+        # in the default mqllmengine
+        assert "--disable-frontend-multiprocessing" in str(e)
 
     # test streaming
     batch = await client.completions.create(
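For context, the same beam-search request can be issued against a running vLLM OpenAI-compatible server with the synchronous client; a minimal sketch, assuming a server at localhost:8000 and a placeholder model name:

import openai

# Assumes a vLLM OpenAI-compatible server is already running; the base_url
# and model name below are placeholders.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="my-model",
    prompt="San Francisco is a",
    n=2,
    max_tokens=5,
    temperature=0.0,
    # vLLM-specific extension; required for beam search with n > 1.
    extra_body=dict(use_beam_search=True),
)
for choice in completion.choices:
    print(choice.index, choice.text)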

vllm/engine/async_llm_engine.py (+103, -4)
@@ -14,23 +14,26 @@
 from vllm.engine.async_timeout import asyncio_timeout
 from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
 from vllm.engine.metrics_types import StatLoggerBase
+from vllm.entrypoints.llm import BeamSearchSequence
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutorAsync
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptType
+from vllm.inputs import PromptType, TokensPrompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
+                          RequestOutput)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import deprecate_kwargs, weak_bind
+from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
+                        random_uuid, weak_bind)
 
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -1036,6 +1039,102 @@ async def generate(
         ):
             yield LLMEngine.validate_output(output, RequestOutput)
 
+    async def beam_search(
+        self,
+        prompt: Union[PromptType, List[int]],
+        request_id: str,
+        params: BeamSearchParams,
+    ) -> AsyncGenerator[RequestOutput, None]:
+
+        beam_width = params.beam_width
+        max_tokens = params.max_tokens
+        ignore_eos = params.ignore_eos
+        temperature = params.temperature
+
+        tokenizer = await self.get_tokenizer()
+        tokenizedPrompt = prompt if isinstance(
+            prompt, list) else tokenizer.encode(prompt)
+        tokenizedLength = len(tokenizedPrompt)
+
+        beam_search_params = SamplingParams(logprobs=2 * beam_width,
+                                            max_tokens=1,
+                                            temperature=temperature)
+        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
+        completed = []
+
+        for _ in range(max_tokens):
+            prompts_batch = [
+                TokensPrompt(prompt_token_ids=beam.tokens)
+                for beam in all_beams
+            ]
+
+            tasks = []
+
+            request_id = f"beam_search-{random_uuid()}"
+            for i, individual_prompt in enumerate(prompts_batch):
+                request_id_item = f"{request_id}-{i}"
+                task = asyncio.create_task(
+                    collect_from_async_generator(
+                        self.generate(individual_prompt, beam_search_params,
                                      request_id_item)))
+                tasks.append(task)
+
+            output = await asyncio.gather(*tasks)
+
+            output = [x[0] for x in output]
+
+            logger.info(output)
+
+            new_beams = []
+            for i, current_beam in enumerate(all_beams):
+                result = output[i]
+
+                if result.outputs[0].logprobs is not None:
+                    logprobs = result.outputs[0].logprobs[0]
+                    for token_id, logprob_obj in logprobs.items():
+                        new_beam = BeamSearchSequence(
+                            tokens=current_beam.tokens + [token_id],
+                            cum_logprob=current_beam.cum_logprob +
+                            logprob_obj.logprob)
+
+                        if token_id == tokenizer.eos_token_id and \
+                            not ignore_eos:
+                            completed.append(new_beam)
+                        else:
+                            new_beams.append(new_beam)
+
+            sorted_beams = sorted(new_beams,
+                                  key=lambda x: x.cum_logprob,
+                                  reverse=True)
+            all_beams = sorted_beams[:beam_width]
+
+        completed.extend(all_beams)
+        sorted_completed = sorted(completed,
+                                  key=lambda x: x.cum_logprob,
+                                  reverse=True)
+        best_beams = sorted_completed[:beam_width]
+
+        for beam in best_beams:
+            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
+
+        beam_search_output = RequestOutput(
+            request_id=request_id,
+            prompt=prompt,
+            outputs=[
+                CompletionOutput(
+                    text=beam.text,
+                    cumulative_logprob=beam.cum_logprob,
+                    token_ids=beam.tokens,
+                    index=i,
+                    logprobs=beam.cum_logprob,
+                ) for (i, beam) in enumerate(best_beams)
+            ],
+            finished=True,
+            prompt_token_ids=tokenizedPrompt,
+            prompt_logprobs=None)
+
+        yield LLMEngine.validate_output(beam_search_output, RequestOutput)
+
     async def encode(
         self,
         prompt: PromptType,
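Conceptually, the new AsyncLLMEngine.beam_search drives the engine one token at a time: each live beam is resubmitted as a TokensPrompt with SamplingParams(logprobs=2 * beam_width, max_tokens=1), candidates are ranked by cumulative logprob, beams that emit EOS are set aside as completed, and only the top beam_width beams survive each step. A self-contained sketch of that expansion step follows; the Beam dataclass and expand_beams helper are illustrative only and not part of vLLM:

from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Beam:
    # Illustrative stand-in for vLLM's BeamSearchSequence.
    tokens: List[int]
    cum_logprob: float = 0.0


def expand_beams(beams: List[Beam],
                 step_logprobs: List[Dict[int, float]],
                 beam_width: int,
                 eos_token_id: int,
                 completed: List[Beam]) -> List[Beam]:
    """One expansion step: extend every live beam by its candidate tokens,
    move beams that hit EOS into `completed`, keep the best `beam_width`."""
    candidates = []
    for beam, logprobs in zip(beams, step_logprobs):
        for token_id, logprob in logprobs.items():
            new_beam = Beam(tokens=beam.tokens + [token_id],
                            cum_logprob=beam.cum_logprob + logprob)
            if token_id == eos_token_id:
                completed.append(new_beam)
            else:
                candidates.append(new_beam)
    candidates.sort(key=lambda b: b.cum_logprob, reverse=True)
    return candidates[:beam_width]


# Tiny worked example: two live beams, two candidate tokens each, EOS id 0.
done: List[Beam] = []
live = expand_beams(beams=[Beam([1, 2]), Beam([1, 3])],
                    step_logprobs=[{5: -0.1, 6: -2.0}, {5: -0.5, 0: -0.3}],
                    beam_width=2,
                    eos_token_id=0,
                    completed=done)
print([b.tokens for b in live])  # the two best non-EOS continuations
print([b.tokens for b in done])  # the beam that ended with EOS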

vllm/entrypoints/llm.py (+10, -10)
@@ -22,8 +22,8 @@
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
-                                  SamplingParams)
+from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
+                                  RequestOutputKind, SamplingParams)
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@@ -394,25 +394,25 @@ def generate(
     def beam_search(
         self,
         prompts: List[Union[str, List[int]]],
-        beam_width: int,
-        max_tokens: int,
-        ignore_eos: bool = False,
-        temperature: float = 0.0,
+        params: BeamSearchParams,
     ) -> List[BeamSearchOutput]:
         """
         Generate sequences using beam search.
 
         Args:
             prompts: A list of prompts. Each prompt can be a string or a list
                 of token IDs.
-            beam_width: The number of beams to keep at each step.
-            max_tokens: The max number of tokens to generate for each prompt.
-            temperature: The temperature to use for generation.
-
+            params: The beam search parameters.
+
         TODO: how does beam search work together with length penalty, frequency
         penalty, and stopping criteria, etc.?
         """
 
+        beam_width = params.beam_width
+        max_tokens = params.max_tokens
+        temperature = params.temperature
+        ignore_eos = params.ignore_eos
+
         tokenizer = self.get_tokenizer()
         # generate 2 * beam_width candidates at each step
         # following the huggingface transformers implementation
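For reference, callers of LLM.beam_search migrate from loose keyword arguments to a single BeamSearchParams object (the benchmark change above uses the same pattern). A minimal offline sketch; the model name is a placeholder, and reading seq.tokens follows the output shape exercised in tests/conftest.py:

from vllm import LLM
from vllm.sampling_params import BeamSearchParams

# Placeholder model, chosen only for illustration.
llm = LLM(model="facebook/opt-125m")

# Old call, removed by this commit:
#   llm.beam_search(["The capital of France is"], beam_width=4, max_tokens=16)
# New call, with a single BeamSearchParams object:
outputs = llm.beam_search(
    ["The capital of France is"],
    BeamSearchParams(beam_width=4, max_tokens=16),
)
for output in outputs:
    for seq in output.sequences:
        print(seq.tokens)  # token IDs of each surviving beam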

vllm/entrypoints/logger.py (+3, -2)
@@ -4,7 +4,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 
 logger = init_logger(__name__)
 
@@ -21,7 +21,8 @@ def log_inputs(
         request_id: str,
         prompt: Optional[str],
         prompt_token_ids: Optional[List[int]],
-        params: Optional[Union[SamplingParams, PoolingParams]],
+        params: Optional[Union[SamplingParams, PoolingParams,
+                               BeamSearchParams]],
         lora_request: Optional[LoRARequest],
         prompt_adapter_request: Optional[PromptAdapterRequest],
     ) -> None:

vllm/entrypoints/openai/protocol.py (+34, -2)
@@ -11,8 +11,8 @@
 
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
-                                  SamplingParams)
+from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
+                                  RequestOutputKind, SamplingParams)
 from vllm.sequence import Logprob
 from vllm.utils import random_uuid
 
@@ -288,6 +288,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
     # doc: end-chat-completion-extra-params
 
+    def to_beam_search_params(self,
+                              default_max_tokens: int) -> BeamSearchParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
+
+        n = self.n if self.n is not None else 1
+        temperature = self.temperature if self.temperature is not None else 0.0
+
+        return BeamSearchParams(
+            beam_width=n,
+            max_tokens=max_tokens,
+            ignore_eos=self.ignore_eos,
+            temperature=temperature,
+        )
+
     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
         max_tokens = self.max_tokens
         if max_tokens is None:
@@ -567,6 +583,22 @@ class CompletionRequest(OpenAIBaseModel):
 
     # doc: end-completion-extra-params
 
+    def to_beam_search_params(self,
+                              default_max_tokens: int) -> BeamSearchParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
+
+        n = self.n if self.n is not None else 1
+        temperature = self.temperature if self.temperature is not None else 0.0
+
+        return BeamSearchParams(
+            beam_width=n,
+            max_tokens=max_tokens,
+            ignore_eos=self.ignore_eos,
+            temperature=temperature,
+        )
+
     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
         max_tokens = self.max_tokens
         if max_tokens is None:
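For context, to_beam_search_params maps the OpenAI request fields onto BeamSearchParams: n becomes beam_width (default 1), max_tokens falls back to the server default, and temperature defaults to 0.0. A hedged client-side sketch of a chat request exercising this path, assuming the chat endpoint accepts the same use_beam_search extra parameter shown in the completion test above and that the server and model name below exist:

import openai

# Placeholders: a vLLM OpenAI-compatible server is assumed to be running.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat = client.chat.completions.create(
    model="my-model",
    messages=[{"role": "user", "content": "Name three prime numbers."}],
    n=3,              # mapped to BeamSearchParams.beam_width
    max_tokens=32,    # mapped to BeamSearchParams.max_tokens
    temperature=0.0,  # mapped to BeamSearchParams.temperature
    extra_body=dict(use_beam_search=True),  # assumed vLLM-specific extension
)
for choice in chat.choices:
    print(choice.index, choice.message.content)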
