
Commit 9474b30

Fix bug

Author: sanising
1 parent 85d46cc, commit 9474b30

File tree

1 file changed: +4 −1 lines


QEfficient/transformers/sampler/sampler.py

+4 −1
@@ -257,7 +257,9 @@ def sampler_forward(
 
         # Top K
         # TODO (Optimization): if (top_ks != -1 or top_ks != Constants.MAX_TOP_K_IDS).any(): skip
-        topk_values_asc, topk_indices_asc = torch.topk(logits, k=Constants.MAX_TOP_K_IDS, dim=1, largest=False)  # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
+        topk_values, topk_indices = torch.topk(logits, k=Constants.MAX_TOP_K_IDS, dim=1)  # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
+        topk_values_asc = topk_values.flip(dims=[1])
+        topk_indices_asc = topk_indices.flip(dims=[1])
         top_ks[top_ks > Constants.MAX_TOP_K_IDS] = Constants.MAX_TOP_K_IDS  # Clip k to max value
         # True values in this mask indicate the positions of the non-top K values
         topk_mask = torch.arange(topk_values_asc.shape[1]).unsqueeze(0) < (topk_values_asc.size(1) - top_ks.to(torch.long)).repeat(spec_length, 1)  # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
@@ -277,6 +279,7 @@ def sampler_forward(
         min_p_mask = top_probs < scaled_min_p  # (batch_size * spec_length, Constants.MAX_TOP_K_IDS)
         topk_values_asc[min_p_mask] = torch.finfo(torch.float16).min
 
+        logits.fill_(torch.finfo(torch.float16).min)
         logits = logits.scatter(1, topk_indices_asc, topk_values_asc)  # (batch_size * spec_length, vocab_size)
 
         # Softmax
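A minimal sketch (not part of the commit, using toy inputs) of what the change appears to do: torch.topk with largest=False returns the k smallest values, so the removed line selected the smallest logits rather than the largest ones in ascending order; the new lines take the k largest values and flip them to ascending order, and the added logits.fill_ masks every position outside the kept top k before the scatter.

    import torch

    # Toy logits, one row, k = 3 (illustrative values only).
    logits = torch.tensor([[0.1, 2.0, -1.0, 3.0, 0.5]])
    k = 3

    # Old behaviour: largest=False selects the k *smallest* logits, in ascending order.
    smallest_asc, _ = torch.topk(logits, k=k, dim=1, largest=False)
    # smallest_asc -> tensor([[-1.0000, 0.1000, 0.5000]])

    # New behaviour: take the k largest, then flip to ascending order.
    topk_values, topk_indices = torch.topk(logits, k=k, dim=1)
    topk_values_asc = topk_values.flip(dims=[1])
    topk_indices_asc = topk_indices.flip(dims=[1])
    # topk_values_asc -> tensor([[0.5000, 2.0000, 3.0000]])

    # Second hunk: reset every logit to the float16 minimum before scattering the
    # kept top-k values back, so positions outside the top k are masked instead of
    # keeping their original values.
    logits.fill_(torch.finfo(torch.float16).min)
    logits = logits.scatter(1, topk_indices_asc, topk_values_asc)
    # logits -> tensor([[-65504., 2., -65504., 3., 0.5]])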
