Commit 9595933

WoosukKwon authored and prashantgupta24 committed
[Hardware][TPU] Raise errors for unsupported sampling params (vllm-project#5850)
1 parent 8509fef commit 9595933

1 file changed: +44 -19 lines changed

vllm/worker/tpu_model_runner.py (+44 -19)
@@ -20,6 +20,8 @@
 logger = init_logger(__name__)
 
 _PAD_SLOT_ID = 0  # FIXME(woosuk)
+# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
+_ENABLE_TOP_P = False
 
 
 class TPUModelRunner:
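
For context, top-p (nucleus) filtering keeps only the smallest set of highest-probability tokens whose cumulative mass reaches p, which is what the new _ENABLE_TOP_P flag gates off. Below is a minimal illustrative sketch of such a filter on a batch of logits; it is an assumption-laden stand-in (hypothetical top_p_filter helper), not the backend's actual _apply_top_p, which is not shown in this diff:

import torch

def top_p_filter(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    # Illustrative nucleus (top-p) filtering; not the backend's _apply_top_p.
    probs = logits.softmax(dim=-1)
    probs_sorted, indices = probs.sort(dim=-1, descending=True)
    cum_probs = probs_sorted.cumsum(dim=-1)
    # Drop a token if the cumulative mass *before* it already exceeds p.
    drop_sorted = (cum_probs - probs_sorted) > p.unsqueeze(dim=-1)
    # Scatter the drop mask back to the original (unsorted) token positions.
    drop = drop_sorted.scatter(dim=-1, index=indices, src=drop_sorted)
    return logits.masked_fill(drop, float("-inf"))

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
print(top_p_filter(logits, p=torch.tensor([0.7])))  # only the top-2 logits survive

The sort/cumsum/scatter pattern is the part that is costly to compile and run efficiently on TPU, which is why the flag defaults to False here.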
@@ -339,9 +341,34 @@ def _prepare_sample(
             assert seq_group_metadata.sampling_params is not None
             sampling_params = seq_group_metadata.sampling_params
 
+            # NOTE(woosuk): Here we mimic argmax sampling by applying a very
+            # low temperature. This is not accurate.
             t.append(sampling_params.temperature
                      if sampling_params.temperature >= 1e-5 else 1e-5)
+            if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
+                raise NotImplementedError(
+                    "Top-p sampling is currently disabled for the TPU backend "
+                    "due to performance issues.")
             p.append(sampling_params.top_p)
+            if sampling_params.top_k != -1:
+                raise NotImplementedError(
+                    "Top-k sampling is currently disabled for the TPU backend "
+                    "due to performance issues.")
+            if sampling_params.best_of > 1:
+                raise NotImplementedError(
+                    "best_of > 1 is not currently supported by the TPU "
+                    "backend.")
+            if sampling_params.use_beam_search:
+                raise NotImplementedError(
+                    "Beam search is not supported by the TPU backend.")
+            if sampling_params.logprobs is not None:
+                raise NotImplementedError(
+                    "logprobs is not currently supported by the TPU backend.")
+            if sampling_params.prompt_logprobs is not None:
+                raise NotImplementedError(
+                    "prompt_logprobs is not currently supported by the TPU "
+                    "backend.")
+
         num_paddings = padded_batch_size - len(seq_group_metadata_list)
         t += [1.0] * num_paddings
         p += [1.0] * num_paddings
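
The new checks reject any request whose sampling parameters the TPU sampler cannot honor, instead of silently ignoring them. The sketch below mirrors that validation logic in a standalone form, using a hypothetical Params dataclass as a stand-in for vLLM's SamplingParams (field names and defaults are assumptions for illustration):

from dataclasses import dataclass
from typing import Optional

_ENABLE_TOP_P = False

@dataclass
class Params:  # hypothetical stand-in for vLLM's SamplingParams
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1
    best_of: int = 1
    use_beam_search: bool = False
    logprobs: Optional[int] = None
    prompt_logprobs: Optional[int] = None

def check_tpu_supported(sp: Params) -> None:
    # Mirrors the checks added in _prepare_sample above.
    if sp.top_p != 1 and not _ENABLE_TOP_P:
        raise NotImplementedError("Top-p sampling is disabled for the TPU backend.")
    if sp.top_k != -1:
        raise NotImplementedError("Top-k sampling is disabled for the TPU backend.")
    if sp.best_of > 1:
        raise NotImplementedError("best_of > 1 is not supported by the TPU backend.")
    if sp.use_beam_search:
        raise NotImplementedError("Beam search is not supported by the TPU backend.")
    if sp.logprobs is not None or sp.prompt_logprobs is not None:
        raise NotImplementedError("logprobs are not supported by the TPU backend.")

check_tpu_supported(Params(temperature=0.0))  # passes: greedy-like request
try:
    check_tpu_supported(Params(top_k=50))
except NotImplementedError as e:
    print(e)  # top-k request is rejected up front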
@@ -350,35 +377,32 @@ def _prepare_sample(
         p = torch.tensor(p, dtype=torch.float32, device=self.device)
         return t, p
 
-    def prepare_inputs(
+    def _execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ):
-        assert seq_group_metadata_list is not None
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
+    ) -> List[CompletionSequenceGroupOutput]:
+        # Prepare inputs.
         assert len(seq_group_metadata_list) > 0
         # NOTE: We assume that all sequences in the group are all prompts or
         # all decodes.
-        if seq_group_metadata_list[0].is_prompt:
+        is_prompt = seq_group_metadata_list[0].is_prompt
+        if is_prompt:
             inputs = self._prepare_prompt(seq_group_metadata_list)
         else:
             inputs = self._prepare_decode(seq_group_metadata_list)
         padded_batch_size = inputs[0].shape[0]
-        sample_inputs = self._prepare_sample(seq_group_metadata_list,
-                                             padded_batch_size)
-        return inputs + sample_inputs
+        t, p = self._prepare_sample(seq_group_metadata_list, padded_batch_size)
 
-    def _execute_model(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
-    ) -> List[CompletionSequenceGroupOutput]:
-        inputs = self.prepare_inputs(seq_group_metadata_list)
+        # Execute the model.
         next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
-                                    *inputs[2:])
-        if not self.is_driver_worker:
-            return []
+                                    *inputs[2:], t, p)
+        # Retrieve the outputs to CPU.
         next_token_ids = next_token_ids.cpu().tolist()
 
+        # NOTE(woosuk): Minimal code to construct the sampler outputs.
+        # The TPU backend does not reuse the sampler, since the TPU backend
+        # does not support the advanced sampling parameters such as logprobs.
         i = 0
        sampler_outputs = []
         for seq_group_metadata in seq_group_metadata_list:
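
The merged _execute_model runs the whole padded batch through the model and then builds outputs only for the real requests; the padding rows, whose t and p were filled with 1.0 in _prepare_sample, are simply discarded. A toy, self-contained illustration of that bookkeeping with made-up request names and token ids (not vLLM code):

# Toy illustration of how padded sampling outputs are consumed (made-up values).
seq_group_metadata_list = ["req-0", "req-1", "req-2"]   # 3 real requests
padded_batch_size = 8                                   # batch padded for XLA compilation

t = [0.7, 1e-5, 1.0]
p = [1.0, 1.0, 1.0]
num_paddings = padded_batch_size - len(seq_group_metadata_list)
t += [1.0] * num_paddings   # padding rows get a neutral temperature
p += [1.0] * num_paddings   # and a neutral top-p

# Pretend the model returned one next-token id per padded row.
next_token_ids = [11, 22, 33, 0, 0, 0, 0, 0]

# Only the leading entries correspond to real requests; padding rows are ignored.
i = 0
for request in seq_group_metadata_list:
    print(request, "->", next_token_ids[i])
    i += 1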
@@ -400,6 +424,7 @@ def execute_model(
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> SamplerOutput:
         assert seq_group_metadata_list is not None
+        assert len(seq_group_metadata_list) > 0
         if seq_group_metadata_list[0].is_prompt:
             # NOTE(woosuk): To reduce the compilation time, we only compile the
             # prefill inputs with batch size 1. Because the scheduler is not
@@ -492,8 +517,8 @@ def forward(
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
 
         logits = logits / t.unsqueeze(dim=1)
-        # FIXME(woosuk): Disabled top-p sampling since it's too slow.
-        # logits = _apply_top_p(logits, p.unsqueeze(dim=1))
+        if _ENABLE_TOP_P:
+            logits = _apply_top_p(logits, p.unsqueeze(dim=1))
         probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
         # FIXME(woosuk): best_of > 1 is not supported.
         next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
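
With top-p gated off, sampling reduces to dividing logits by temperature and drawing from the softmax. As the NOTE in _prepare_sample says, greedy requests are approximated by clamping the temperature to 1e-5, which makes the softmax essentially one-hot so the multinomial draw reproduces argmax. A small self-contained check of that effect in plain PyTorch, independent of the TPU backend (the tensor values are made up):

import torch

torch.manual_seed(0)
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])

# Temperature clamped to 1e-5, as in _prepare_sample: the softmax becomes
# essentially one-hot at the argmax, so multinomial reproduces argmax.
t = torch.tensor([1e-5])
probs = torch.softmax(logits / t.unsqueeze(dim=1), dim=-1, dtype=torch.float32)
sampled = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
assert sampled.item() == logits.argmax(dim=-1).item()

# With an ordinary temperature, the draw is genuinely stochastic.
t = torch.tensor([1.0])
probs = torch.softmax(logits / t.unsqueeze(dim=1), dim=-1, dtype=torch.float32)
print(torch.multinomial(probs, num_samples=1).squeeze(dim=1))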
