
Commit 08512af

hmellor authored and jimpang committed

[Misc] Remove dangling references to SamplingType.BEAM (vllm-project#13402)

1 parent 64a1faa commit 08512af
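For context on why these references are dangling: beam search is no longer driven from inside the sampler, so the `SamplingType.BEAM` branches below are dead code. A hypothetical sketch of requesting beam search through the engine-level API instead, assuming `LLM.beam_search` and `BeamSearchParams` are available in the installed vLLM release (neither name appears in this diff, so treat the exact signature as an assumption):

```python
# Hypothetical usage sketch; LLM.beam_search and BeamSearchParams are
# assumptions about the engine-level API, not part of this diff.
from vllm import LLM
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="facebook/opt-125m")
outputs = llm.beam_search(
    [{"prompt": "The capital of France is"}],
    BeamSearchParams(beam_width=4, max_tokens=32),
)
print(outputs)
```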

1 file changed  (+0 −78 lines)

vllm/model_executor/layers/sampler.py

@@ -68,7 +68,6 @@ class SampleResultArgsType:
     sample_results_dict: SampleResultsDictType
     sampling_metadata: SamplingMetadata
     greedy_samples: Optional[torch.Tensor]
-    beam_search_logprobs: Optional[torch.Tensor]
 
 
 # Union of non-deferred (single-step scheduling)

@@ -510,74 +509,6 @@ def _random_sample(
     return results
 
 
-def _beam_search_sample(
-    selected_seq_groups: List[SequenceGroupToSample],
-    logprobs: torch.Tensor,
-) -> SampleResultType:
-    """Run beam sampling on a given samples.
-
-    Args:
-        selected_seq_groups: A list of sequence groups batched.
-        logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
-            on selected sample indices.
-    Returns:
-        Tuple of (next_token_ids, parent_ids). The length of returned list is
-        same as the length of selected_seq_groups. If the corresponding
-        seq_group has do_sample=False, tuple contains ([], [])
-    """
-    # We sample 2 * beam_width candidates to make sure that with high
-    # probability we can get `beam_width` candidates in addition to
-    # the finished sequences for the next iteration. See
-    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
-    # for details. See also HF reference:
-    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
-    #
-    # NOTE: Beam search is not vectorized, so its speed can be slower than
-    # other sampling methods.
-    sample_idx = 0
-    results: SampleResultType = []
-    for seq_group in selected_seq_groups:
-        if not seq_group.do_sample:
-            results.append(([], []))
-            continue
-
-        is_prompt = seq_group.is_prompt
-        seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
-        num_parent_seqs = len(seq_ids)
-        beam_width = sampling_params.n
-        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
-        if is_prompt:
-            # Prompt phase.
-            assert num_parent_seqs == 1, (
-                "Prompt input should have only one seq.")
-            parent_ids = [0] * (2 * beam_width)
-            _, next_token_ids = torch.topk(seq_group_logprobs[0],
-                                           2 * beam_width)
-            next_token_ids = next_token_ids.tolist()
-        else:
-            # Generation phase.
-            cumulative_logprobs: List[float] = [
-                seq_group.seq_data[seq_id].cumulative_logprob
-                for seq_id in seq_ids
-            ]
-            cumulative_logprobs_tensor = torch.tensor(
-                cumulative_logprobs,
-                dtype=torch.float,
-                device=seq_group_logprobs.device)
-            seq_group_logprobs = (seq_group_logprobs +
-                                  cumulative_logprobs_tensor.unsqueeze(dim=1))
-            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
-                                     2 * beam_width)
-            topk_ids = topk_ids.tolist()
-            vocab_size = seq_group_logprobs.size(-1)
-            parent_ids = [i // vocab_size for i in topk_ids]
-            next_token_ids = [i % vocab_size for i in topk_ids]
-        results.append((next_token_ids, parent_ids))
-        sample_idx += num_parent_seqs
-    assert sample_idx == logprobs.size(0)
-    return results
-
-
 # torch.multinomial forces a GPU<->CPU sync.
 # Therefore, we use an optimized implementation instead.
 # Note that we always sample with replacement.
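The removed `_beam_search_sample` boils down to one expansion step per scheduler iteration: add each parent beam's cumulative logprob to its next-token logprobs, take a top-k of size `2 * beam_width` over the flattened scores, and recover the parent index and token id with integer division and modulo. A self-contained sketch of that step (the function name and tensor shapes are illustrative, not part of vLLM):

```python
import torch

def expand_beams(parent_logprobs: torch.Tensor,
                 cumulative: torch.Tensor,
                 beam_width: int):
    """One beam-search expansion step, mirroring the removed generation phase.

    parent_logprobs: (num_parents, vocab_size) next-token logprobs per beam.
    cumulative:      (num_parents,) running logprob of each parent beam.
    Returns 2 * beam_width (parent_id, next_token_id) pairs, best first.
    """
    # Score every (parent, token) pair by total sequence logprob.
    scores = parent_logprobs + cumulative.unsqueeze(dim=1)
    # Over-sample 2 * beam_width candidates, as the removed code did, so that
    # enough live beams remain after finished sequences are set aside.
    _, topk_ids = torch.topk(scores.flatten(), 2 * beam_width)
    vocab_size = parent_logprobs.size(-1)
    parent_ids = (topk_ids // vocab_size).tolist()
    next_token_ids = (topk_ids % vocab_size).tolist()
    return list(zip(parent_ids, next_token_ids))

# Tiny example: 3 parent beams over a 5-token vocabulary.
logprobs = torch.log_softmax(torch.randn(3, 5), dim=-1)
cumulative = torch.tensor([-1.2, -0.7, -2.5])
print(expand_beams(logprobs, cumulative, beam_width=2))
```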
@@ -666,14 +597,12 @@ def get_pythonized_sample_results(
         sampling_metadata,
         greedy_samples,
         multinomial_samples,
-        beam_search_logprobs,
         sample_results_dict,
     ) = (
         sample_result_args.sample_metadata,
         sample_result_args.sampling_metadata,
         sample_result_args.greedy_samples,
         sample_result_args.multinomial_samples,
-        sample_result_args.beam_search_logprobs,
         sample_result_args.sample_results_dict,
     )

@@ -686,9 +615,6 @@ def get_pythonized_sample_results(
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             sample_results = _random_sample(seq_groups,
                                             multinomial_samples[sampling_type])
-        elif sampling_type == SamplingType.BEAM:
-            sample_results = _beam_search_sample(seq_groups,
-                                                 beam_search_logprobs)
         sample_results_dict.update(zip(seq_group_id, sample_results))
 
     return [

@@ -731,7 +657,6 @@ def _sample_with_torch(
     sample_metadata: SampleMetadataType = {}
     multinomial_samples: MultinomialSamplesType = {}
     greedy_samples: Optional[torch.Tensor] = None
-    beam_search_logprobs: Optional[torch.Tensor] = None
 
     # Create output tensor for sampled token ids.

@@ -800,8 +725,6 @@ def _sample_with_torch(
                 sampled_token_ids_tensor[long_sample_indices] = \
                     multinomial_samples[sampling_type].to(torch.long)
 
-        elif sampling_type == SamplingType.BEAM:
-            beam_search_logprobs = logprobs[sample_indices]
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")

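The random branch kept above writes `multinomial_samples[sampling_type]` into the output tensor; per the context comment earlier in this diff ("torch.multinomial forces a GPU<->CPU sync"), those samples come from an optimized sampler. A minimal sketch of one sync-free way to draw categorical samples, using exponential noise (an illustration of the general trick, not a copy of vLLM's `_multinomial`):

```python
import torch

def sample_one_token_per_row(probs: torch.Tensor) -> torch.Tensor:
    """Draw one sample per row of `probs` without torch.multinomial.

    probs: (num_seqs, vocab_size), non-negative; rows need not sum to 1.
    Dividing each probability by i.i.d. Exponential(1) noise and taking the
    argmax is equivalent to categorical sampling (the Gumbel-max trick in a
    different parametrization), and it avoids the device sync that
    torch.multinomial can trigger.
    """
    q = torch.empty_like(probs).exponential_(1.0)
    return probs.div(q).argmax(dim=-1)

probs = torch.softmax(torch.randn(4, 16), dim=-1)
print(sample_one_token_per_row(probs))
```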

@@ -812,7 +735,6 @@ def _sample_with_torch(
         sample_metadata=sample_metadata,
         multinomial_samples=multinomial_samples,
         greedy_samples=greedy_samples,
-        beam_search_logprobs=beam_search_logprobs,
         sample_results_dict=sample_results_dict)
 
     if not sampling_metadata.skip_sampler_cpu_output:
