Skip to content

Commit 3b63ecb

Browse files
committed
Allow chunked prompts during prefill
Signed-off-by: quic-sanising <quic_sanising@quicinc.com>
1 parent 05c0bf0 commit 3b63ecb

File tree

1 file changed: +1 addition, -2 deletions

QEfficient/transformers/sampler/sampler.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -221,11 +221,10 @@ def sampler_forward(
         logits = logits.reshape(-1, vocab_size)  # Reshape tensor to 2D

         if input_ids.shape[1] > spec_length:  # Prefill phase, initialize retained states
-            repetition_penalty_retain_state_selected = torch.mul(repetition_penalty_retain_state_selected, 0)
-            presence_penalty_retain_state_selected = torch.mul(presence_penalty_retain_state_selected, 0)
             # TODO: Replace scatter_ with CtxScatterFunc; Replace -1 with int_max while exporting on onnx
             # repetition_penalty_retain_state_selected = CtxScatterFunc.apply(repetition_penalty_retain_state_selected.unsqueeze(1), input_ids, 1).squeeze(1)
             repetition_penalty_retain_state_selected.scatter_(1, input_ids, 1)
+            presence_penalty_retain_state_selected.scatter_(1, input_ids, 0)
         else:  # Decode phase, update retained states
             repetition_penalty_retain_state_selected.scatter_(1, last_accepted_output_tokens, 1)
             presence_penalty_retain_state_selected.scatter_(1, last_accepted_output_tokens, 1)

0 commit comments

Comments (0)