From 454b842336ac4cbc52feaf7a7408f6a102dfa32e Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 21 Feb 2024 11:47:00 -0800 Subject: [PATCH] Support per-request seed (#2514) --- tests/samplers/test_sampler.py | 222 +++++++++++++++-------- tests/samplers/test_seeded_generate.py | 82 +++++++++ vllm/core/scheduler.py | 1 + vllm/engine/arg_utils.py | 1 - vllm/entrypoints/openai/protocol.py | 4 + vllm/model_executor/layers/sampler.py | 29 ++- vllm/model_executor/sampling_metadata.py | 3 + vllm/sampling_params.py | 9 +- vllm/sequence.py | 12 ++ vllm/worker/model_runner.py | 10 + 10 files changed, 289 insertions(+), 84 deletions(-) create mode 100644 tests/samplers/test_seeded_generate.py diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index d34f32d03fee0..31e865f42ff3b 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -1,10 +1,11 @@ import random -from typing import Tuple +from typing import Tuple, List from unittest.mock import patch import pytest import torch from transformers import GenerationConfig, GenerationMixin +from typing import Optional from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.utils import set_random_seed @@ -46,15 +47,13 @@ def _prepare_test( ] -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_all_greedy(seed: int, device: str): - set_random_seed(seed) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) - +def _do_sample( + batch_size: int, + input_tensor: torch.Tensor, + sampler: MockLogitsSampler, + model_runner: ModelRunner, + sampling_params: SamplingParams, +): seq_group_metadata_list = [] prompt_lens = [] for i in range(batch_size): @@ -63,7 +62,7 @@ def test_sampler_all_greedy(seed: int, device: str): request_id=f"test_{i}", is_prompt=True, seq_data={0: SequenceData([1, 2, 3])}, - sampling_params=SamplingParams(temperature=0, ), + sampling_params=sampling_params, block_tables={0: [1]}, )) prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) @@ -71,9 +70,23 @@ def test_sampler_all_greedy(seed: int, device: str): sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) - sampler_output = sampler(embedding=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) + return sampler(embedding=None, + hidden_states=input_tensor, + sampling_metadata=sampling_metadata) + + +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_sampler_all_greedy(seed: int, device: str): + set_random_seed(seed) + torch.set_default_device(device) + batch_size = random.randint(1, 256) + input_tensor, fake_logits, sampler, model_runner = _prepare_test( + batch_size) + + sampling_params = SamplingParams(temperature=0) + sampler_output = _do_sample(batch_size, input_tensor, sampler, + model_runner, sampling_params) expected = torch.argmax(fake_logits, dim=-1) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: @@ -94,28 +107,40 @@ def test_sampler_all_random(seed: int, device: str): for i in range(batch_size): fake_logits[i, i] = 1e2 - seq_group_metadata_list = [] - prompt_lens = [] + sampling_params = SamplingParams( + temperature=1.0, + n=random.randint(1, 10), + ) + sampler_output = _do_sample(batch_size, input_tensor, sampler, + model_runner, sampling_params) + + for i, sequence_output in enumerate(sampler_output): + for nth_output in sequence_output.samples: + assert nth_output.output_token == i + + del model_runner + + +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_sampler_all_random_seed(seed: int, device: str): + set_random_seed(seed) + torch.set_default_device(device) + batch_size = random.randint(1, 256) + input_tensor, fake_logits, sampler, model_runner = _prepare_test( + batch_size) + for i in range(batch_size): - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData([1, 2, 3])}, - sampling_params=SamplingParams( - temperature=1.0, - n=random.randint(1, 10), - ), - block_tables={0: [1]}, - )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + fake_logits[i, i] = 1e2 + + sampling_params = SamplingParams( + temperature=1.0, + n=random.randint(1, 10), + seed=random.randint(0, 10000), + ) + sampler_output = _do_sample(batch_size, input_tensor, sampler, + model_runner, sampling_params) - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) - sampler_output = sampler(embedding=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == i @@ -123,6 +148,31 @@ def test_sampler_all_random(seed: int, device: str): del model_runner +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_sampler_all_random_seed_deterministic(seed: int, device: str): + set_random_seed(seed) + torch.set_default_device(device) + batch_size = random.randint(1, 256) + input_tensor, fake_logits, sampler, model_runner = _prepare_test( + batch_size) + + sampling_params = SamplingParams( + temperature=1.0, + n=random.randint(1, 10), + seed=random.randint(0, 10000), + ) + first_sampler_output = _do_sample(batch_size, input_tensor, sampler, + model_runner, sampling_params) + + second_sampler_output = _do_sample(batch_size, input_tensor, sampler, + model_runner, sampling_params) + + assert first_sampler_output == second_sampler_output + + del model_runner + + @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) def test_sampler_all_beam(seed: int, device: str): @@ -131,29 +181,13 @@ def test_sampler_all_beam(seed: int, device: str): batch_size = random.randint(1, 256) input_tensor, _, sampler, model_runner = _prepare_test(batch_size) - seq_group_metadata_list = [] - prompt_lens = [] - for i in range(batch_size): - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData([1, 2, 3])}, - sampling_params=SamplingParams( - temperature=0, - best_of=2, - use_beam_search=True, - ), - block_tables={0: [1]}, - )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) - sampler(embedding=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) + sampling_params = SamplingParams( + temperature=0, + best_of=2, + use_beam_search=True, + ) + _do_sample(batch_size, input_tensor, sampler, model_runner, + sampling_params) # no assertion here as I am not sure how to determine whether # the outputs are expected - in other words, this just tests # whether there are no exceptions in the sampler @@ -171,14 +205,15 @@ def test_sampler_mixed(seed: int, device: str): batch_size) seq_group_metadata_list = [] - expected_tokens = [] + expected_tokens: List[Optional[List[int]]] = [] prompt_lens = [] for i in range(batch_size): - n = 1 - sampling_type = random.randint(0, 2) + expected: Optional[List[int]] = None + sampling_type = random.randint(0, 3) if sampling_type == 0: sampling_params = SamplingParams(temperature=0) - elif sampling_type == 1: + expected = [torch.argmax(fake_logits[i], dim=-1).item()] + elif sampling_type in (1, 2): n = random.randint(1, 10) sampling_params = SamplingParams( temperature=random.random() + 0.1, @@ -187,13 +222,17 @@ def test_sampler_mixed(seed: int, device: str): n=n, presence_penalty=random.randint(0, 1), ) + if sampling_type == 2: + sampling_params.seed = random.randint(0, 10000) + else: + for idx in range(n): + fake_logits[i, i + idx] = 1e2 + expected = list(range(i, i + n)) else: sampling_params = SamplingParams(temperature=0, use_beam_search=True, best_of=2) - for idx in range(n): - fake_logits[i, i + idx] = 1e2 - expected_tokens.append(i + idx) + expected_tokens.append(expected) seq_group_metadata_list.append( SequenceGroupMetadata( request_id=f"test_{i}", @@ -204,17 +243,50 @@ def test_sampler_mixed(seed: int, device: str): )) prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) - sampler_output = sampler(embedding=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) - for i, sequence_output in enumerate(sampler_output): - if seq_group_metadata_list[i].sampling_params.use_beam_search: - continue - for nth_output in sequence_output.samples: - assert nth_output.output_token in expected_tokens + def test_sampling(model_runner: ModelRunner): + sampling_metadata = model_runner._prepare_sample( + seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) + sampler_output = sampler(embedding=None, + hidden_states=input_tensor, + sampling_metadata=sampling_metadata) + + for i, (sequence_output, metadata) in enumerate( + zip(sampler_output, seq_group_metadata_list)): + if metadata.sampling_params.use_beam_search: + continue + + if metadata.sampling_params.seed is not None \ + and expected_tokens[i] is None: + # Record seeded random result to compare with results of second invocation + expected_tokens[i] = [ + nth_output.output_token + for nth_output in sequence_output.samples + ] + continue + + for n, nth_output in enumerate(sequence_output.samples): + if metadata.sampling_params.temperature == 0 or metadata.sampling_params.seed is not None: + # Ensure exact matches for greedy or random with seed + assert nth_output.output_token == expected_tokens[i][n] + else: + # For non-seeded random check that one of the high-logit tokens were chosen + assert nth_output.output_token in expected_tokens[i] + + # Test batch + test_sampling(model_runner) + + # Shuffle the batch and resample + target_index = list(range(batch_size)) + for list_to_shuffle in (target_index, seq_group_metadata_list, + expected_tokens, prompt_lens): + random.Random(seed).shuffle(list_to_shuffle) + target_index = torch.tensor(target_index) + input_tensor.data = input_tensor.index_select(0, target_index) + fake_logits.data = fake_logits.index_select(0, target_index) + + # This time, results of seeded random samples will be compared with the corresponding + # sample in the pre-shuffled batch + test_sampling(model_runner) del model_runner diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py new file mode 100644 index 0000000000000..fcb0e09d46143 --- /dev/null +++ b/tests/samplers/test_seeded_generate.py @@ -0,0 +1,82 @@ +"""Verify that seeded random sampling is deterministic. + +Run `pytest tests/samplers/test_seeded_generate.py --forked`. +""" +import copy +import random +from itertools import combinations + +import pytest + +from vllm.model_executor.utils import set_random_seed +from vllm import SamplingParams + +MODEL = "facebook/opt-125m" +RANDOM_SEEDS = list(range(5)) + + +@pytest.fixture +def vllm_model(vllm_runner): + vllm_model = vllm_runner(MODEL, dtype="half") + yield vllm_model + del vllm_model + + +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +def test_random_sample_with_seed( + vllm_model, + example_prompts, + seed: int, +) -> None: + set_random_seed(seed) + + sampling_params = SamplingParams( + # Parameters to ensure sufficient randomness + temperature=2.0, + top_p=min(random.random() + 0.3, 1), + top_k=random.randint(5, 20), + n=random.randint(1, 10), + presence_penalty=random.randint(0, 1), + max_tokens=8, + ignore_eos=True, + ) + + sampling_params_seed_1 = copy.deepcopy(sampling_params) + sampling_params_seed_1.seed = 100 + sampling_params_seed_2 = copy.deepcopy(sampling_params) + sampling_params_seed_2.seed = 200 + + llm = vllm_model.model + + for prompt in example_prompts: + for params in ( + sampling_params, + sampling_params_seed_1, + sampling_params_seed_2, + sampling_params, + sampling_params_seed_1, + sampling_params_seed_2, + ): + llm._add_request( + prompt=prompt, + prompt_token_ids=None, + sampling_params=params, + ) + + results = llm._run_engine(use_tqdm=False) + all_outputs = [[out.token_ids for out in output.outputs] + for output in results] + + for i in range(0, len(example_prompts), 6): + outputs = all_outputs[i:i + 6] + + # verify all non-seeded requests differ + for output_a, output_b in combinations( + (outputs[0], outputs[1], outputs[2], outputs[3]), + 2, + ): + assert output_a != output_b + + # verify requests with the same seed match + assert outputs[1] == outputs[4] + assert outputs[2] == outputs[5] diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 5dde9097a3d57..f4ac2d6dc59fe 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -387,6 +387,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: block_tables=block_tables, lora_request=seq_group.lora_request, prefix=seq_group.prefix, + state=seq_group.state, ) seq_group_metadata_list.append(seq_group_metadata) return seq_group_metadata_list, scheduler_outputs diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8ac0157151d8e..a4efd171b871d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -173,7 +173,6 @@ def add_cli_args( default=EngineArgs.block_size, choices=[8, 16, 32], help='token block size') - # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). parser.add_argument('--seed', type=int, default=EngineArgs.seed, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index fc15b7833ecf2..727fec870293c 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -60,6 +60,7 @@ class ChatCompletionRequest(BaseModel): top_p: Optional[float] = 1.0 n: Optional[int] = 1 max_tokens: Optional[int] = None + seed: Optional[int] = None stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False presence_penalty: Optional[float] = 0.0 @@ -90,6 +91,7 @@ def to_sampling_params(self) -> SamplingParams: temperature=self.temperature, top_p=self.top_p, min_p=self.min_p, + seed=self.seed, stop=self.stop, stop_token_ids=self.stop_token_ids, max_tokens=self.max_tokens, @@ -117,6 +119,7 @@ class CompletionRequest(BaseModel): logprobs: Optional[int] = None echo: Optional[bool] = False stop: Optional[Union[str, List[str]]] = Field(default_factory=list) + seed: Optional[int] = None presence_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0 best_of: Optional[int] = None @@ -147,6 +150,7 @@ def to_sampling_params(self): top_p=self.top_p, top_k=self.top_k, min_p=self.min_p, + seed=self.seed, stop=self.stop, stop_token_ids=self.stop_token_ids, ignore_eos=self.ignore_eos, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index bc86a916b5bbf..884d84387e505 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -342,7 +342,9 @@ def _beam_search_sample( def _multinomial( probs: torch.Tensor, num_samples: int, -): + seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None, + generators: Optional[List[torch.Generator]] = None, +) -> torch.Tensor: if num_samples > 1: # This is equivalent to torch.repeat_interleaved (which also # forces a GPU<->CPU sync). @@ -352,7 +354,15 @@ def _multinomial( probs = probs[:, None, :].expand(probs.shape[0], num_samples, probs.shape[1]).contiguous().view( -1, probs.shape[1]) - q = torch.empty_like(probs).exponential_(1) + q = torch.empty_like(probs) + if seq_groups is None: + q.exponential_() + else: + sample_idx = 0 + for (seq_ids, _), generator in zip(seq_groups, generators): + next_sample_idx = sample_idx + len(seq_ids) * num_samples + q[sample_idx:next_sample_idx].exponential_(generator=generator) + sample_idx = next_sample_idx return probs.div_(q).argmax(dim=1).view(-1, num_samples) @@ -370,6 +380,7 @@ def _sample( sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} sample_metadata = {} + multinomial_samples = {} # Counterintiutively, having two loops here is actually faster. # The first loop can run without waiting on GPU<->CPU sync. @@ -385,14 +396,18 @@ def _sample( is_prompts, sample_indices) if sampling_type == SamplingType.GREEDY: greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1) - elif sampling_type == SamplingType.RANDOM: + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): max_best_of = 1 for seq_group, is_prompt in zip(seq_groups, is_prompts): if is_prompt: _, sampling_params = seq_group max_best_of = max(max_best_of, sampling_params.best_of) - multinomial_samples = _multinomial(probs[sample_indices], - max_best_of) + seeded_args = {} if sampling_type == SamplingType.RANDOM else { + "seq_groups": seq_groups, + "generators": sampling_metadata.generators, + } + multinomial_samples[sampling_type] = _multinomial( + probs[sample_indices], max_best_of, **seeded_args) elif sampling_type == SamplingType.BEAM: beam_search_logprobs = logprobs[sample_indices] else: @@ -407,9 +422,9 @@ def _sample( sampling_type] if sampling_type == SamplingType.GREEDY: sample_results = _greedy_sample(seq_groups, greedy_samples) - elif sampling_type == SamplingType.RANDOM: + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): sample_results = _random_sample(seq_groups, is_prompts, - multinomial_samples) + multinomial_samples[sampling_type]) elif sampling_type == SamplingType.BEAM: sample_results = _beam_search_sample(seq_groups, is_prompts, sampling_metadata.seq_data, diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 2d41d40e04678..d0ffeecd2d74d 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -19,6 +19,7 @@ class SamplingMetadata: prompt_lens: Lengths of prompts. selected_token_indices: Token indices selected for sampling. categorized_sample_indices: SamplingType -> token indices to sample. + generators: List of torch.Generators to use for seeded sampling perform_sampling: Whether to perform sampling. This option is used to make the sampling only happens in the driver worker, and disable sampling in other worker processes. @@ -31,6 +32,7 @@ def __init__( prompt_lens: Optional[List[int]], selected_token_indices: torch.Tensor, categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]], + generators: Optional[List[torch.Generator]] = None, perform_sampling: bool = True, ) -> None: self.seq_groups = seq_groups @@ -38,6 +40,7 @@ def __init__( self.prompt_lens = prompt_lens self.selected_token_indices = selected_token_indices self.categorized_sample_indices = categorized_sample_indices + self.generators = generators self.perform_sampling = perform_sampling self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0 diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index bb7d0002c910c..51d39220ca9ca 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -11,7 +11,8 @@ class SamplingType(IntEnum): GREEDY = 0 RANDOM = 1 - BEAM = 2 + RANDOM_SEED = 2 + BEAM = 3 LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor] @@ -56,6 +57,7 @@ class SamplingParams: min_p: Float that represents the minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in [0, 1]. Set to 0 to disable this. + seed: Random seed to use for the generation. use_beam_search: Whether to use beam search instead of sampling. length_penalty: Float that penalizes sequences based on their length. Used in beam search. @@ -101,6 +103,7 @@ def __init__( top_p: float = 1.0, top_k: int = -1, min_p: float = 0.0, + seed: Optional[int] = None, use_beam_search: bool = False, length_penalty: float = 1.0, early_stopping: Union[bool, str] = False, @@ -124,6 +127,7 @@ def __init__( self.top_p = top_p self.top_k = top_k self.min_p = min_p + self.seed = seed self.use_beam_search = use_beam_search self.length_penalty = length_penalty self.early_stopping = early_stopping @@ -229,6 +233,8 @@ def sampling_type(self) -> SamplingType: return SamplingType.BEAM if self.temperature < _SAMPLING_EPS: return SamplingType.GREEDY + if self.seed is not None: + return SamplingType.RANDOM_SEED return SamplingType.RANDOM def __repr__(self) -> str: @@ -242,6 +248,7 @@ def __repr__(self) -> str: f"top_p={self.top_p}, " f"top_k={self.top_k}, " f"min_p={self.min_p}, " + f"seed={self.seed}, " f"use_beam_search={self.use_beam_search}, " f"length_penalty={self.length_penalty}, " f"early_stopping={self.early_stopping}, " diff --git a/vllm/sequence.py b/vllm/sequence.py index 44adb058a5ba5..040e9756e15c6 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -248,6 +248,14 @@ def __repr__(self) -> str: f"num_blocks={len(self.logical_token_blocks)})") +@dataclass +class SequenceGroupState: + """Mutable state tied to a specific sequence group""" + + # torch.Generator used in seeded sampling + generator: Optional = None + + class SequenceGroup: """A group of sequences that are generated from the same prompt. @@ -280,6 +288,7 @@ def __init__( self.lora_request = lora_request self.prefix: Optional[Prefix] = prefix self.prompt_logprobs: Optional[PromptLogprobs] = None + self.state = SequenceGroupState() @property def prompt(self) -> str: @@ -397,6 +406,7 @@ class SequenceGroupMetadata: sampling_params: The sampling parameters used to generate the outputs. block_tables: The block tables. (Seq id -> list of physical block numbers) + state: Internal state tied to this sequence group. lora_request: LoRA request. prefix: The prefix of the prompt of the sequence group. """ @@ -410,6 +420,7 @@ def __init__( block_tables: Dict[int, List[int]], lora_request: Optional[LoRARequest] = None, prefix: Optional[Prefix] = None, + state: Optional[SequenceGroupState] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt @@ -418,6 +429,7 @@ def __init__( self.block_tables = block_tables self.lora_request = lora_request self.prefix = prefix + self.state = SequenceGroupState() if state is None else state @property def lora_int_id(self) -> int: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a27b7d9cb4dcf..b99a409e02d1e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -389,6 +389,7 @@ def _prepare_sample( ) -> SamplingMetadata: seq_groups: List[Tuple[List[int], SamplingParams]] = [] selected_token_indices: List[int] = [] + generators: List[torch.Generator] = [] selected_token_start_idx = 0 categorized_sample_indices = {t: [] for t in SamplingType} categorized_sample_indices_start_idx = 0 @@ -419,6 +420,10 @@ def _prepare_sample( selected_token_indices.append(selected_token_start_idx + subquery_len - 1) selected_token_start_idx += max_subquery_len + + if sampling_params.seed is not None: + seq_group_metadata.state.generator = torch.Generator( + device="cuda").manual_seed(sampling_params.seed) else: num_seqs = len(seq_ids) selected_token_indices.extend( @@ -432,6 +437,9 @@ def _prepare_sample( categorized_sample_indices_start_idx + num_seqs)) categorized_sample_indices_start_idx += num_seqs + if sampling_params.seed is not None: + generators.append(seq_group_metadata.state.generator) + selected_token_indices = _async_h2d(selected_token_indices, dtype=torch.long, target_device=self.device, @@ -454,6 +462,7 @@ def _prepare_sample( prompt_lens=prompt_lens, selected_token_indices=selected_token_indices, categorized_sample_indices=categorized_sample_indices, + generators=generators, ) return sampling_metadata @@ -536,6 +545,7 @@ def prepare_input_tensors( prompt_lens=None, selected_token_indices=metadata_dict["selected_token_indices"], categorized_sample_indices=None, + generators=None, perform_sampling=False, )