File tree 1 file changed +1
-2
lines changed
QEfficient/transformers/sampler
1 file changed +1
-2
lines changed Original file line number Diff line number Diff line change @@ -221,11 +221,10 @@ def sampler_forward(
221
221
logits = logits .reshape (- 1 , vocab_size ) # Reshape tensor to 2D
222
222
223
223
if input_ids .shape [1 ] > spec_length : # Prefill phase, initialize retained states
224
- repetition_penalty_retain_state_selected = torch .mul (repetition_penalty_retain_state_selected , 0 )
225
- presence_penalty_retain_state_selected = torch .mul (presence_penalty_retain_state_selected , 0 )
226
224
# TODO: Replace scatter_ with CtxScatterFunc; Replace -1 with int_max while exporting on onnx
227
225
# repetition_penalty_retain_state_selected = CtxScatterFunc.apply(repetition_penalty_retain_state_selected.unsqueeze(1), input_ids, 1).squeeze(1)
228
226
repetition_penalty_retain_state_selected .scatter_ (1 , input_ids , 1 )
227
+ presence_penalty_retain_state_selected .scatter_ (1 , input_ids , 0 )
229
228
else : # Decode phase, update retained states
230
229
repetition_penalty_retain_state_selected .scatter_ (1 , last_accepted_output_tokens , 1 )
231
230
presence_penalty_retain_state_selected .scatter_ (1 , last_accepted_output_tokens , 1 )
You can’t perform that action at this time.
0 commit comments