@@ -68,7 +68,6 @@ class SampleResultArgsType:
     sample_results_dict: SampleResultsDictType
     sampling_metadata: SamplingMetadata
     greedy_samples: Optional[torch.Tensor]
-    beam_search_logprobs: Optional[torch.Tensor]


 # Union of non-deferred (single-step scheduling)
@@ -510,74 +509,6 @@ def _random_sample(
     return results


-def _beam_search_sample(
-    selected_seq_groups: List[SequenceGroupToSample],
-    logprobs: torch.Tensor,
-) -> SampleResultType:
-    """Run beam sampling on a given samples.
-
-    Args:
-        selected_seq_groups: A list of sequence groups batched.
-        logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
-            on selected sample indices.
-    Returns:
-        Tuple of (next_token_ids, parent_ids). The length of returned list is
-        same as the length of selected_seq_groups. If the corresponding
-        seq_group has do_sample=False, tuple contains ([], [])
-    """
-    # We sample 2 * beam_width candidates to make sure that with high
-    # probability we can get `beam_width` candidates in addition to
-    # the finished sequences for the next iteration. See
-    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
-    # for details. See also HF reference:
-    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
-    #
-    # NOTE: Beam search is not vectorized, so its speed can be slower than
-    # other sampling methods.
-    sample_idx = 0
-    results: SampleResultType = []
-    for seq_group in selected_seq_groups:
-        if not seq_group.do_sample:
-            results.append(([], []))
-            continue
-
-        is_prompt = seq_group.is_prompt
-        seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
-        num_parent_seqs = len(seq_ids)
-        beam_width = sampling_params.n
-        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
-        if is_prompt:
-            # Prompt phase.
-            assert num_parent_seqs == 1, (
-                "Prompt input should have only one seq.")
-            parent_ids = [0] * (2 * beam_width)
-            _, next_token_ids = torch.topk(seq_group_logprobs[0],
-                                           2 * beam_width)
-            next_token_ids = next_token_ids.tolist()
-        else:
-            # Generation phase.
-            cumulative_logprobs: List[float] = [
-                seq_group.seq_data[seq_id].cumulative_logprob
-                for seq_id in seq_ids
-            ]
-            cumulative_logprobs_tensor = torch.tensor(
-                cumulative_logprobs,
-                dtype=torch.float,
-                device=seq_group_logprobs.device)
-            seq_group_logprobs = (seq_group_logprobs +
-                                  cumulative_logprobs_tensor.unsqueeze(dim=1))
-            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
-                                     2 * beam_width)
-            topk_ids = topk_ids.tolist()
-            vocab_size = seq_group_logprobs.size(-1)
-            parent_ids = [i // vocab_size for i in topk_ids]
-            next_token_ids = [i % vocab_size for i in topk_ids]
-        results.append((next_token_ids, parent_ids))
-        sample_idx += num_parent_seqs
-    assert sample_idx == logprobs.size(0)
-    return results
-
-
 # torch.multinomial forces a GPU<->CPU sync.
 # Therefore, we use an optimized implementation instead.
 # Note that we always sample with replacement.
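For reference, the removed generation branch scores every (parent beam, token) pair by adding each parent's cumulative logprob to its next-token logprobs, takes the top 2 * beam_width entries of the flattened matrix, and recovers the parent and token of each candidate via integer division and modulo. Below is a small self-contained sketch of that step with made-up numbers; it is an illustration only, not code from this repository.

import torch

# Toy illustration of the candidate-selection step performed by the removed
# generation branch of _beam_search_sample. All values are made up.
beam_width = 2
vocab_size = 5

# Next-token logprobs for each parent beam: shape (num_parent_seqs, vocab_size).
seq_group_logprobs = torch.log(torch.tensor([
    [0.10, 0.40, 0.20, 0.20, 0.10],  # parent beam 0
    [0.05, 0.05, 0.60, 0.25, 0.05],  # parent beam 1
]))
# Cumulative logprob accumulated by each parent beam so far (hypothetical).
cumulative_logprobs = torch.tensor([-1.0, -2.5])

# Score every (parent, token) pair, then take 2 * beam_width candidates from
# the flattened matrix, as the removed code did.
scores = seq_group_logprobs + cumulative_logprobs.unsqueeze(dim=1)
_, topk_ids = torch.topk(scores.flatten(), 2 * beam_width)
topk_ids = topk_ids.tolist()

# Recover which parent each candidate extends (row) and which token it
# appends (column) from the flattened index.
parent_ids = [i // vocab_size for i in topk_ids]
next_token_ids = [i % vocab_size for i in topk_ids]
print(list(zip(parent_ids, next_token_ids)))  # 4 (parent, token) candidates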
@@ -666,14 +597,12 @@ def get_pythonized_sample_results(
         sampling_metadata,
         greedy_samples,
         multinomial_samples,
-        beam_search_logprobs,
         sample_results_dict,
     ) = (
         sample_result_args.sample_metadata,
         sample_result_args.sampling_metadata,
         sample_result_args.greedy_samples,
         sample_result_args.multinomial_samples,
-        sample_result_args.beam_search_logprobs,
         sample_result_args.sample_results_dict,
     )

@@ -686,9 +615,6 @@ def get_pythonized_sample_results(
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             sample_results = _random_sample(seq_groups,
                                             multinomial_samples[sampling_type])
-        elif sampling_type == SamplingType.BEAM:
-            sample_results = _beam_search_sample(seq_groups,
-                                                 beam_search_logprobs)
         sample_results_dict.update(zip(seq_group_id, sample_results))

     return [
@@ -731,7 +657,6 @@ def _sample_with_torch(
     sample_metadata: SampleMetadataType = {}
     multinomial_samples: MultinomialSamplesType = {}
     greedy_samples: Optional[torch.Tensor] = None
-    beam_search_logprobs: Optional[torch.Tensor] = None

     # Create output tensor for sampled token ids.
     if include_gpu_probs_tensor:
@@ -800,8 +725,6 @@ def _sample_with_torch(
             sampled_token_ids_tensor[long_sample_indices] = \
                 multinomial_samples[sampling_type].to(torch.long)

-        elif sampling_type == SamplingType.BEAM:
-            beam_search_logprobs = logprobs[sample_indices]
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")

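The multinomial_samples consumed above come from the sampler's optimized replacement for torch.multinomial; as the comment retained in the earlier hunk notes, torch.multinomial forces a GPU<->CPU sync. One common way to avoid that sync is the exponential-noise (Gumbel-max) trick sketched below. This is an illustrative stand-in under that assumption, not necessarily the exact implementation used here, and multinomial_no_sync is a hypothetical name.

import torch

def multinomial_no_sync(probs: torch.Tensor, num_samples: int) -> torch.Tensor:
    """Sample token ids with replacement without calling torch.multinomial.

    argmax(p_i / E_i) with E_i ~ Exp(1) picks index i with probability
    proportional to p_i (the Gumbel-max trick in exponential form), so the
    whole operation stays on the GPU with no host synchronization.
    """
    if num_samples > 1:
        # Give every requested sample its own row of noise.
        probs = probs.repeat_interleave(num_samples, dim=0)
    q = torch.empty_like(probs).exponential_(1.0)
    return probs.div(q).argmax(dim=-1).view(-1, num_samples)

# Toy usage with a tiny vocabulary: two sequences, two samples each.
probs = torch.tensor([[0.1, 0.2, 0.7],
                      [0.5, 0.4, 0.1]])
print(multinomial_no_sync(probs, num_samples=2))  # tensor of shape (2, 2)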
@@ -812,7 +735,6 @@ def _sample_with_torch(
         sample_metadata=sample_metadata,
         multinomial_samples=multinomial_samples,
         greedy_samples=greedy_samples,
-        beam_search_logprobs=beam_search_logprobs,
         sample_results_dict=sample_results_dict)

     if not sampling_metadata.skip_sampler_cpu_output: