Commit f5f5e2d

Reformat code

Signed-off-by: quic-sanising <quic_sanising@quicinc.com>

1 parent abbaf53 commit f5f5e2d
File tree: 1 file changed, +29 -29 lines


QEfficient/transformers/sampler/sampler.py (+29 -29)
@@ -8,6 +8,7 @@
 from transformers.modeling_outputs import ModelOutput, CausalLMOutputWithPast
 from typing import List, Optional, Tuple, Union
 
+
 @dataclass
 class QEffCausalLMOutputWithPast(ModelOutput):
     loss: Optional[torch.FloatTensor] = None
@@ -105,45 +106,46 @@ def sampler_forward(
 
         last_accepted_output_tokens (`torch.Tensor`, *optional*):
             Output tokens accepted by the Speculative Decoding Draft Language Model.
-
+
         repetition_penalty_retain_state (`torch.Tensor`, *optional*):
-            RetainedState buffer used as a mask to apply repetition penalty to the input 
+            RetainedState buffer used as a mask to apply repetition penalty to the input
             prompt and the output generated so far.
-
+
         repetition_penalties (`torch.Tensor`, *optional*):
-            Sampling parameter that penalizes new tokens based on whether they appear in the 
-            prompt and the generated text so far. Values > 1 encourage the model to use 
+            Sampling parameter that penalizes new tokens based on whether they appear in the
+            prompt and the generated text so far. Values > 1 encourage the model to use
             new tokens, while values < 1 encourage the model to repeat tokens.
-
+
         presence_penalty_retain_state (`torch.Tensor`, *optional*):
-            RetainedState buffer used as a mask to apply presence penalty to the output 
+            RetainedState buffer used as a mask to apply presence penalty to the output
             generated so far.
-
+
         presence_penalties (`torch.Tensor`, *optional*):
-            Sampling parameter that penalizes new tokens based on whether they appear in the 
-            generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens.
-
+            Sampling parameter that penalizes new tokens based on whether they appear in the
+            generated text so far. Values > 0 encourage the model to use new tokens, while
+            values < 0 encourage the model to repeat tokens.
+
         temperatures (`torch.Tensor`, *optional*):
-            Sampling parameter that controls the randomness of the sampling. Lower values 
-            make the model more deterministic, while higher values make the model more 
+            Sampling parameter that controls the randomness of the sampling. Lower values
+            make the model more deterministic, while higher values make the model more
             random. Zero means greedy sampling.
-
+
         top_ks (`torch.Tensor`, *optional*):
             Sampling parameter that controls the number of top tokens to consider.
-
+
         top_ps (`torch.Tensor`, *optional*):
-            Sampling parameter that controls the cumulative probability of the top tokens to 
+            Sampling parameter that controls the cumulative probability of the top tokens to
             consider. Must be in (0, 1]. Set to 1.0 to consider all tokens.
-
+
         min_ps (`torch.Tensor`, *optional*):
-            Sampling parameter that represents the minimum probability for a token to be 
-            considered, relative to the probability of the most likely token. Must be in 
+            Sampling parameter that represents the minimum probability for a token to be
+            considered, relative to the probability of the most likely token. Must be in
             [0, 1]. Set to 0.0 to disable this.
-
+
         random_numbers (`torch.Tensor`, *optional*):
-            Sampling parameter that represents the random seeds to use for random sampling. 
+            Sampling parameter that represents the random seeds to use for random sampling.
             Must be in [-1, 1].
-
+
     Returns:
 
     Example:
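For context, the parameters documented above follow the usual logits-processing recipe used by most samplers. The sketch below is a hypothetical, self-contained illustration of how temperature, top-k, top-p, and min-p are conventionally combined on a single logits vector; the function name apply_sampling_params and its defaults are invented for this example and are not part of sampler.py.

# Hypothetical sketch (names invented here, not part of sampler.py) of how the
# documented per-request sampling parameters are conventionally applied.
import torch

def apply_sampling_params(logits, temperature=1.0, top_k=0, top_p=1.0, min_p=0.0):
    if temperature > 0:  # lower -> more deterministic; zero is handled as greedy elsewhere
        logits = logits / temperature
    probs = torch.softmax(logits, dim=-1)
    if top_k > 0:  # keep only the k most probable tokens
        kth = torch.topk(probs, top_k).values[..., -1, None]
        probs = torch.where(probs < kth, torch.zeros_like(probs), probs)
    if top_p < 1.0:  # nucleus: smallest set of tokens with cumulative probability >= top_p
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        sorted_probs[cumulative - sorted_probs > top_p] = 0.0
        probs = torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_probs)
    if min_p > 0.0:  # drop tokens far below the most likely token
        probs = torch.where(probs < min_p * probs.max(dim=-1, keepdim=True).values,
                            torch.zeros_like(probs), probs)
    return probs / probs.sum(dim=-1, keepdim=True)

# Toy usage with a 5-token vocabulary:
print(apply_sampling_params(torch.tensor([2.0, 1.0, 0.5, -1.0, -3.0]),
                            temperature=0.8, top_k=3, top_p=0.9, min_p=0.05))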
@@ -210,12 +212,12 @@ def sampler_forward(
 
     # Perform Sampling
     batch_size, spec_length, vocab_size = logits.shape
-
+
     # Select relevant rows
     batch_index_reshaped = batch_index.view(-1)
     repetition_penalty_retain_state_selected = torch.index_select(repetition_penalty_retain_state, 0, batch_index_reshaped)
     presence_penalty_retain_state_selected = torch.index_select(presence_penalty_retain_state, 0, batch_index_reshaped)
-
+
     logits = logits.reshape(-1, vocab_size) # Reshape tensor to 2D
 
     if input_ids.shape[1] > spec_length: # Prefill phase, initialize retained states
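The index_select calls above pick out the retained-state rows belonging to the requests in the current decode batch. A toy illustration with made-up shapes, only to show the semantics of that selection:

# Toy example (invented shapes): select per-request retained-state rows by batch index.
import torch

retain_state = torch.arange(12.0).reshape(4, 3)  # (full_batch_size, vocab_size)
batch_index = torch.tensor([[2], [0]])           # rows owned by the two active requests
selected = torch.index_select(retain_state, 0, batch_index.view(-1))
print(selected)                                  # rows 2 and 0 of retain_state, shape (2, 3)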
@@ -238,16 +240,14 @@ def sampler_forward(
         repetition_penalties = repetition_penalties.repeat(spec_length, vocab_size) # (batch_size, 1) -> (batch_size * spec_length, vocab_size)
         repetition_penalty_retain_state_selected = repetition_penalty_retain_state_selected.repeat(spec_length, 1) # (batch_size, vocab_size) -> (batch_size * spec_length, vocab_size)
         repetition_penalties[~repetition_penalty_retain_state_selected.bool()] = 1.0
-        logits = torch.where(
-            logits > 0, logits / repetition_penalties, logits * repetition_penalties
-        )
+        logits = torch.where(logits > 0, logits / repetition_penalties, logits * repetition_penalties)
 
     # Presence Penalty
     if (presence_penalties != 0.).any():
         presence_penalties = presence_penalties.repeat(spec_length, 1) # (batch_size, 1) -> (batch_size * spec_length, 1)
         presence_penalty_retain_state_selected = presence_penalty_retain_state_selected.repeat(spec_length, 1) # (batch_size, vocab_size) -> (batch_size * spec_length, vocab_size)
         logits -= presence_penalties * presence_penalty_retain_state_selected
-
+
     # TODO: Frequency Penalty
 
     # Temperature Scaling
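To make the penalty arithmetic concrete, here is a small worked example with invented numbers. It mirrors the torch.where division/multiplication and the presence-penalty subtraction in the hunk above, but it is standalone and not code from sampler.py.

# Toy values only: one row of logits over a 3-token vocabulary.
import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])
seen_mask = torch.tensor([[1.0, 1.0, 0.0]])       # tokens 0 and 1 already appeared
repetition_penalties = torch.full_like(logits, 1.5)

# Repetition penalty: only penalize tokens flagged in the retained-state mask.
repetition_penalties[~seen_mask.bool()] = 1.0
logits = torch.where(logits > 0, logits / repetition_penalties, logits * repetition_penalties)
# -> [[1.3333, -1.5, 0.5]]

# Presence penalty: flat subtraction for every token that has appeared at least once.
logits -= 0.4 * seen_mask
# -> [[0.9333, -1.9, 0.5]]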
@@ -284,7 +284,7 @@ def sampler_forward(
     probs = torch.softmax(logits, dim=1) # (batch_size * spec_length, vocab_size)
 
     # Sample the next tokens
-    # TODO (Optimization): if self.return_pds: skip 
+    # TODO (Optimization): if self.return_pds: skip
     greedy_samples = torch.argmax(probs, dim=-1, keepdim=True) # Greedy Sampling
     gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1))) # Gumbel-Max Trick
     y = probs + gumbel_noise
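The random-sampling branch above relies on the Gumbel-Max trick. For background, here is a generic sketch of the textbook form (not this file's implementation, which instead derives the noise from the random_numbers seeds): adding Gumbel(0, 1) noise to log-probabilities and taking the argmax draws samples from the underlying categorical distribution.

# Generic Gumbel-Max illustration with invented probabilities.
import torch

torch.manual_seed(0)
probs = torch.tensor([0.6, 0.3, 0.1])
u = torch.rand(10000, probs.numel())                   # Uniform(0, 1) draws
gumbel = -torch.log(-torch.log(u))                     # Gumbel(0, 1) noise
samples = torch.argmax(torch.log(probs) + gumbel, dim=-1)
print(torch.bincount(samples, minlength=3) / 10000.0)  # empirically close to [0.6, 0.3, 0.1]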
