From 0af3abe3d3225449c907d75eb3d2ae4b83bd21a1 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 3 Sep 2024 13:29:24 -0700 Subject: [PATCH] [TPU][Bugfix] Fix next_token_ids shape (#8128) --- vllm/worker/tpu_model_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index a0498315516b8..684c54b7d8139 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -601,7 +601,7 @@ def _execute_model(*args): batch_idx += 1 else: for seq_id in seq_ids: - next_token_id = next_token_ids[batch_idx][0] + next_token_id = next_token_ids[batch_idx] seq_outputs.append( SequenceOutput(seq_id, next_token_id, {next_token_id: zero_logprob})) @@ -722,6 +722,9 @@ def forward( sampled_token_ids = torch.multinomial(probs, num_samples, replacement=True) + if num_samples == 1: + argmax_token_ids = argmax_token_ids.squeeze(dim=-1) + sampled_token_ids = sampled_token_ids.squeeze(dim=-1) next_token_ids = torch.where(t != 0, sampled_token_ids, argmax_token_ids) return next_token_ids