Commit 8da0465

wallashss authored and lucast2021 committed
[Bugfix] Fix spec decoding when seed is none in a batch (vllm-project#10863)
Signed-off-by: Wallas Santos <wallashss@ibm.com>
Signed-off-by: lucast2021 <lucast2021@headroyce.org>
1 parent 547fafb commit 8da0465

2 files changed: +66, -7 lines


tests/samplers/test_rejection_sampler.py (+63)
@@ -200,6 +200,69 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
                 assert torch.equal(results[j][i], results[0][i])
 
 
+@pytest.mark.parametrize("k", [1, 3, 6])
+@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
+@pytest.mark.parametrize("batch_size", [3, 8, 32, 128])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
+@torch.inference_mode()
+def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
+                            device: str, use_flashinfer: bool):
+    torch.set_default_device(device)
+    set_random_seed(0)
+    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
+    bonus_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, 1),
+                                    dtype=torch.int64)
+    draft_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, k),
+                                    dtype=torch.int64)
+
+    single_batches = []
+    for i in range(batch_size):
+        single_batches.append((draft_probs[i].clone().unsqueeze(0),
+                               draft_token_ids[i].clone().unsqueeze(0),
+                               target_probs[i].clone().unsqueeze(0),
+                               bonus_token_ids[i].clone().unsqueeze(0),
+                               draft_token_ids[i].clone().unsqueeze(0)))
+
+    set_random_seed(0)
+    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
+    rejection_sampler.init_gpu_tensors(device=device)
+
+    results = []
+    seeded_seqs = {
+        i: torch.Generator(device=device).manual_seed(i)
+        for i in range(1, batch_size)  # 0 is seed None
+    }
+    batch_result = rejection_sampler(target_probs.clone(),
+                                     bonus_token_ids.clone(),
+                                     draft_probs.clone(),
+                                     draft_token_ids.clone(), seeded_seqs)
+
+    set_random_seed(0)
+
+    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
+    rejection_sampler.init_gpu_tensors(device=device)
+    for i in range(batch_size):
+        request_seeded_seqs = {
+            0: torch.Generator(device=device).manual_seed(i)
+        } if seeded_seqs.get(i) is not None else None
+        (draft_probs, draft_token_ids, target_probs, bonus_token_ids,
+         draft_token_ids) = single_batches[i]
+        results.append(
+            rejection_sampler(target_probs, bonus_token_ids, draft_probs,
+                              draft_token_ids, request_seeded_seqs))
+    for i in range(batch_size):
+        assert torch.equal(batch_result[i], results[i].squeeze(0))
+
+
 @pytest.mark.parametrize("k", [1, 3, 6])
 @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
 @pytest.mark.parametrize("batch_size", [1, 8, 32, 128])

vllm/model_executor/layers/rejection_sampler.py (+3, -7)
@@ -1,6 +1,6 @@
 from functools import cached_property
 from importlib.util import find_spec
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Optional, Tuple
 
 import torch
 import torch.jit
@@ -386,16 +386,12 @@ def _multinomial(
     if not seeded_seqs:
         q.exponential_(1.0)
     else:
-        non_seeded_indices: List[int] = []
         start = 0
         for idx in range(len(q) // k):
             end = start + k
             generator = seeded_seqs.get(idx)
-            if generator is None:
-                non_seeded_indices.extend(list(range(start, end)))
-            else:
-                q[start:end].exponential_(1.0, generator=generator)
+            # Note: generator might be None for non seeded
+            q[start:end].exponential_(1.0, generator=generator)
             start = end
-        q[non_seeded_indices].exponential_(1.0)
 
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
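The removed final line is the likely source of the bug: q[non_seeded_indices] indexes with a Python list, which in PyTorch is advanced indexing and therefore returns a copy, so the in-place exponential_(1.0) filled a temporary tensor and the non-seeded rows of q were never actually written. The fix fills every k-row slice inside the loop through basic slicing, which yields a view, and passes generator=None for non-seeded requests so exponential_ falls back to the default RNG. A standalone sketch of the copy-vs-view distinction (not part of the commit):

    import torch

    q = torch.zeros(4, 2)

    # List indexing is advanced indexing: it returns a copy, so the
    # in-place fill lands on a temporary and q stays untouched.
    q[[0, 1]].exponential_(1.0)
    assert q.sum() == 0.0

    # Basic slicing returns a view, so the in-place fill sticks.
    q[0:2].exponential_(1.0)
    assert q.sum() > 0.0

    # exponential_ accepts generator=None and falls back to the default
    # RNG, which lets one code path serve seeded and unseeded slices.
    q[2:4].exponential_(1.0, generator=None)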

Comments (0)