14 | 14 | from vllm.engine.async_timeout import asyncio_timeout
15 | 15 | from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
16 | 16 | from vllm.engine.metrics_types import StatLoggerBase
   | 17 | +from vllm.entrypoints.llm import BeamSearchSequence
17 | 18 | from vllm.executor.executor_base import ExecutorAsyncBase
18 | 19 | from vllm.executor.gpu_executor import GPUExecutorAsync
19 | 20 | from vllm.executor.ray_utils import initialize_ray_cluster
20 |    | -from vllm.inputs import PromptType
   | 21 | +from vllm.inputs import PromptType, TokensPrompt
21 | 22 | from vllm.logger import init_logger
22 | 23 | from vllm.lora.request import LoRARequest
23 | 24 | from vllm.model_executor.guided_decoding import (
24 | 25 |     get_guided_decoding_logits_processor)
25 | 26 | from vllm.model_executor.layers.sampler import SamplerOutput
26 |    | -from vllm.outputs import EmbeddingRequestOutput, RequestOutput
   | 27 | +from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
   | 28 | +                          RequestOutput)
27 | 29 | from vllm.pooling_params import PoolingParams
28 | 30 | from vllm.prompt_adapter.request import PromptAdapterRequest
29 |    | -from vllm.sampling_params import SamplingParams
   | 31 | +from vllm.sampling_params import BeamSearchParams, SamplingParams
30 | 32 | from vllm.sequence import ExecuteModelRequest
31 | 33 | from vllm.transformers_utils.tokenizer import AnyTokenizer
32 | 34 | from vllm.usage.usage_lib import UsageContext
33 |    | -from vllm.utils import deprecate_kwargs, weak_bind
   | 35 | +from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
   | 36 | +                        random_uuid, weak_bind)
34 | 37 |
35 | 38 | logger = init_logger(__name__)
36 | 39 | ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -1036,6 +1039,102 @@ async def generate(
1036 | 1039 |         ):
1037 | 1040 |             yield LLMEngine.validate_output(output, RequestOutput)
1038 | 1041 |
     | 1042 | +    async def beam_search(
     | 1043 | +        self,
     | 1044 | +        prompt: Union[PromptType, List[int]],
     | 1045 | +        request_id: str,
     | 1046 | +        params: BeamSearchParams,
     | 1047 | +    ) -> AsyncGenerator[RequestOutput, None]:
     | 1048 | +
     | 1049 | +        beam_width = params.beam_width
     | 1050 | +        max_tokens = params.max_tokens
     | 1051 | +        ignore_eos = params.ignore_eos
     | 1052 | +        temperature = params.temperature
     | 1053 | +
     | 1054 | +        tokenizer = await self.get_tokenizer()
     | 1055 | +        tokenizedPrompt = prompt if isinstance(
     | 1056 | +            prompt, list) else tokenizer.encode(prompt)
     | 1057 | +        tokenizedLength = len(tokenizedPrompt)
     | 1058 | +
     | 1059 | +        beam_search_params = SamplingParams(logprobs=2 * beam_width,
     | 1060 | +                                            max_tokens=1,
     | 1061 | +                                            temperature=temperature)
     | 1062 | +        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
     | 1063 | +        completed = []
     | 1064 | +
     | 1065 | +        for _ in range(max_tokens):
     | 1066 | +            prompts_batch = [
     | 1067 | +                TokensPrompt(prompt_token_ids=beam.tokens)
     | 1068 | +                for beam in all_beams
     | 1069 | +            ]
     | 1070 | +
     | 1071 | +            tasks = []
     | 1072 | +
     | 1073 | +            request_id = f"beam_search-{random_uuid()}"
     | 1074 | +            for i, individual_prompt in enumerate(prompts_batch):
     | 1075 | +                request_id_item = f"{request_id}-{i}"
     | 1076 | +                task = asyncio.create_task(
     | 1077 | +                    collect_from_async_generator(
     | 1078 | +                        self.generate(individual_prompt, beam_search_params,
     | 1079 | +                                      request_id_item)))
     | 1080 | +                tasks.append(task)
     | 1081 | +
     | 1082 | +            output = await asyncio.gather(*tasks)
     | 1083 | +
     | 1084 | +            output = [x[0] for x in output]
     | 1085 | +
     | 1086 | +            logger.info(output)
     | 1087 | +
     | 1088 | +            new_beams = []
     | 1089 | +            for i, current_beam in enumerate(all_beams):
     | 1090 | +                result = output[i]
     | 1091 | +
     | 1092 | +                if result.outputs[0].logprobs is not None:
     | 1093 | +                    logprobs = result.outputs[0].logprobs[0]
     | 1094 | +                    for token_id, logprob_obj in logprobs.items():
     | 1095 | +                        new_beam = BeamSearchSequence(
     | 1096 | +                            tokens=current_beam.tokens + [token_id],
     | 1097 | +                            cum_logprob=current_beam.cum_logprob +
     | 1098 | +                            logprob_obj.logprob)
     | 1099 | +
     | 1100 | +                        if token_id == tokenizer.eos_token_id and \
     | 1101 | +                            not ignore_eos:
     | 1102 | +                            completed.append(new_beam)
     | 1103 | +                        else:
     | 1104 | +                            new_beams.append(new_beam)
     | 1105 | +
     | 1106 | +            sorted_beams = sorted(new_beams,
     | 1107 | +                                  key=lambda x: x.cum_logprob,
     | 1108 | +                                  reverse=True)
     | 1109 | +            all_beams = sorted_beams[:beam_width]
     | 1110 | +
     | 1111 | +        completed.extend(all_beams)
     | 1112 | +        sorted_completed = sorted(completed,
     | 1113 | +                                  key=lambda x: x.cum_logprob,
     | 1114 | +                                  reverse=True)
     | 1115 | +        best_beams = sorted_completed[:beam_width]
     | 1116 | +
     | 1117 | +        for beam in best_beams:
     | 1118 | +            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
     | 1119 | +
     | 1120 | +        beam_search_output = RequestOutput(
     | 1121 | +            request_id=request_id,
     | 1122 | +            prompt=prompt,
     | 1123 | +            outputs=[
     | 1124 | +                CompletionOutput(
     | 1125 | +                    text=beam.text,
     | 1126 | +                    cumulative_logprob=beam.cum_logprob,
     | 1127 | +                    token_ids=beam.tokens,
     | 1128 | +                    index=i,
     | 1129 | +                    logprobs=beam.cum_logprob,
     | 1130 | +                ) for (i, beam) in enumerate(best_beams)
     | 1131 | +            ],
     | 1132 | +            finished=True,
     | 1133 | +            prompt_token_ids=tokenizedPrompt,
     | 1134 | +            prompt_logprobs=None)
     | 1135 | +
     | 1136 | +        yield LLMEngine.validate_output(beam_search_output, RequestOutput)
     | 1137 | +
1039 | 1138 |     async def encode(
1040 | 1139 |         self,
1041 | 1140 |         prompt: PromptType,
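For reference, the search loop added in this diff reduces to the standalone sketch below (an editor's illustration, not code from the PR). `Beam` and `expand_step` are hypothetical names standing in for `BeamSearchSequence` and the body of the `for _ in range(max_tokens)` loop; the `ignore_eos` branch is omitted for brevity, and `step_logprobs[i]` corresponds to what the diff reads from `result.outputs[0].logprobs[0]` for beam `i`.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Beam:
    """Simplified stand-in for BeamSearchSequence."""
    tokens: List[int]
    cum_logprob: float = 0.0


def expand_step(beams: List[Beam],
                step_logprobs: List[Dict[int, float]],
                completed: List[Beam],
                beam_width: int,
                eos_token_id: int) -> List[Beam]:
    """One search step: extend every live beam with its candidate tokens,
    retire EOS-terminated candidates to `completed`, and keep only the
    beam_width highest-scoring beams live for the next step."""
    candidates: List[Beam] = []
    for beam, logprobs in zip(beams, step_logprobs):
        for token_id, logprob in logprobs.items():
            new_beam = Beam(tokens=beam.tokens + [token_id],
                            cum_logprob=beam.cum_logprob + logprob)
            if token_id == eos_token_id:
                completed.append(new_beam)   # finished beam leaves the search
            else:
                candidates.append(new_beam)  # stays live for the next step
    candidates.sort(key=lambda b: b.cum_logprob, reverse=True)
    return candidates[:beam_width]
```

Requesting `logprobs=2 * beam_width` per step, as the diff does, means candidates that hit EOS and retire to `completed` cannot starve the pool: enough non-EOS candidates remain at every step to refill `beam_width` live beams.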
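A minimal driver for the new API could look like the sketch below, assuming the engine is built with `AsyncLLMEngine.from_engine_args`; the model name, prompt, request id, and parameter values are placeholders. `beam_search` runs the whole search internally and yields a single finished `RequestOutput` whose `CompletionOutput`s are the top `beam_width` beams.

```python
import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import BeamSearchParams


async def main() -> None:
    # Placeholder model; any completion model served by vLLM works here.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))

    params = BeamSearchParams(beam_width=4, max_tokens=32, temperature=0.0)

    # The generator yields one finished RequestOutput at the end of the search.
    async for output in engine.beam_search("The capital of France is",
                                           request_id="beam-demo-0",
                                           params=params):
        for completion in output.outputs:
            print(f"{completion.cumulative_logprob:.3f}  {completion.text!r}")


if __name__ == "__main__":
    asyncio.run(main())
```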