logger = init_logger(__name__)

_PAD_SLOT_ID = 0  # FIXME(woosuk)
+ # FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
+ _ENABLE_TOP_P = False


class TPUModelRunner:
@@ -339,9 +341,34 @@ def _prepare_sample(
            assert seq_group_metadata.sampling_params is not None
            sampling_params = seq_group_metadata.sampling_params
+             # NOTE(woosuk): Here we mimic argmax sampling by applying a very
+             # low temperature. This is not accurate.
            t.append(sampling_params.temperature
                     if sampling_params.temperature >= 1e-5 else 1e-5)
+             if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
+                 raise NotImplementedError(
+                     "Top-p sampling is currently disabled for the TPU backend "
+                     "due to performance issues.")
            p.append(sampling_params.top_p)
+             if sampling_params.top_k != -1:
+                 raise NotImplementedError(
+                     "Top-k sampling is currently disabled for the TPU backend "
+                     "due to performance issues.")
+             if sampling_params.best_of > 1:
+                 raise NotImplementedError(
+                     "best_of > 1 is not currently supported by the TPU "
+                     "backend.")
+             if sampling_params.use_beam_search:
+                 raise NotImplementedError(
+                     "Beam search is not supported by the TPU backend.")
+             if sampling_params.logprobs is not None:
+                 raise NotImplementedError(
+                     "logprobs is not currently supported by the TPU backend.")
+             if sampling_params.prompt_logprobs is not None:
+                 raise NotImplementedError(
+                     "prompt_logprobs is not currently supported by the TPU "
+                     "backend.")
+
        num_paddings = padded_batch_size - len(seq_group_metadata_list)
        t += [1.0] * num_paddings
        p += [1.0] * num_paddings
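Aside (not part of the diff): the near-zero temperature clamp above approximates greedy decoding because dividing the logits by a tiny temperature pushes the softmax toward a one-hot distribution, so the later multinomial draw almost always returns the argmax token. A minimal PyTorch illustration using the same 1e-5 floor:

import torch

# Dividing by a tiny temperature makes the softmax essentially one-hot,
# so sampling from it behaves like argmax.
logits = torch.tensor([2.0, 1.0, 0.5])
probs = torch.softmax(logits / 1e-5, dim=-1)
print(probs)                                    # ~tensor([1., 0., 0.])
print(torch.multinomial(probs, num_samples=1))  # almost surely tensor([0])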
@@ -350,35 +377,32 @@ def _prepare_sample(
        p = torch.tensor(p, dtype=torch.float32, device=self.device)
        return t, p

-     def prepare_inputs(
+     def _execute_model(
        self,
-         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-     ):
-         assert seq_group_metadata_list is not None
+         seq_group_metadata_list: List[SequenceGroupMetadata],
+         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
+     ) -> List[CompletionSequenceGroupOutput]:
+         # Prepare inputs.
        assert len(seq_group_metadata_list) > 0
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
-         if seq_group_metadata_list[0].is_prompt:
+         is_prompt = seq_group_metadata_list[0].is_prompt
+         if is_prompt:
            inputs = self._prepare_prompt(seq_group_metadata_list)
        else:
            inputs = self._prepare_decode(seq_group_metadata_list)
        padded_batch_size = inputs[0].shape[0]
-         sample_inputs = self._prepare_sample(seq_group_metadata_list,
-                                              padded_batch_size)
-         return inputs + sample_inputs
+         t, p = self._prepare_sample(seq_group_metadata_list, padded_batch_size)

-     def _execute_model(
-         self,
-         seq_group_metadata_list: List[SequenceGroupMetadata],
-         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
-     ) -> List[CompletionSequenceGroupOutput]:
-         inputs = self.prepare_inputs(seq_group_metadata_list)
+         # Execute the model.
        next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
-                                     *inputs[2:])
-         if not self.is_driver_worker:
-             return []
+                                     *inputs[2:], t, p)
+         # Retrieve the outputs to CPU.
        next_token_ids = next_token_ids.cpu().tolist()

+         # NOTE(woosuk): Minimal code to construct the sampler outputs.
+         # The TPU backend does not reuse the sampler, since the TPU backend
+         # does not support the advanced sampling parameters such as logprobs.
        i = 0
        sampler_outputs = []
        for seq_group_metadata in seq_group_metadata_list:
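The hunk ends just as this loop begins, so the construction the NOTE refers to is not visible here. For orientation only, the pattern is roughly the sketch below; it assumes the vLLM sequence-output types SequenceOutput(parent_seq_id, output_token, logprobs), CompletionSequenceGroupOutput(samples, prompt_logprobs), and Logprob(logprob), so the exact code in this PR may differ:

            # Rough sketch of the loop body (assumed constructors, not the PR's exact code).
            seq_outputs = []
            for seq_id in seq_group_metadata.seq_data:
                next_token_id = next_token_ids[i]
                seq_outputs.append(
                    SequenceOutput(seq_id, next_token_id,
                                   {next_token_id: Logprob(0.0)}))
                i += 1
            sampler_outputs.append(
                CompletionSequenceGroupOutput(seq_outputs, None))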
@@ -400,6 +424,7 @@ def execute_model(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> SamplerOutput:
        assert seq_group_metadata_list is not None
+         assert len(seq_group_metadata_list) > 0
        if seq_group_metadata_list[0].is_prompt:
            # NOTE(woosuk): To reduce the compilation time, we only compile the
            # prefill inputs with batch size 1. Because the scheduler is not
@@ -492,8 +517,8 @@ def forward(
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        logits = logits / t.unsqueeze(dim=1)
-         # FIXME(woosuk): Disabled top-p sampling since it's too slow.
-         # logits = _apply_top_p(logits, p.unsqueeze(dim=1))
+         if _ENABLE_TOP_P:
+             logits = _apply_top_p(logits, p.unsqueeze(dim=1))
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        # FIXME(woosuk): best_of > 1 is not supported.
        next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
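The _apply_top_p helper re-enabled above is defined outside this hunk. For context, a typical nucleus (top-p) filter over logits of shape (batch, vocab) with per-row thresholds p of shape (batch, 1) looks roughly like this generic sketch, which is not necessarily the implementation in this file:

import torch

def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    # Sort each row's logits in descending order and accumulate probability mass.
    logits_sorted, indices = torch.sort(logits, dim=-1, descending=True)
    probs_sorted = torch.softmax(logits_sorted, dim=-1)
    cum_probs = torch.cumsum(probs_sorted, dim=-1)
    # Drop tokens whose preceding cumulative mass already exceeds p; the
    # highest-probability token in each row is always kept.
    mask = (cum_probs - probs_sorted) > p
    logits_sorted = logits_sorted.masked_fill(mask, float("-inf"))
    # Scatter the filtered logits back to their original vocabulary positions.
    return torch.empty_like(logits).scatter_(dim=-1, index=indices, src=logits_sorted)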