 from transformers.modeling_outputs import ModelOutput, CausalLMOutputWithPast
 from typing import List, Optional, Tuple, Union

+
 @dataclass
 class QEffCausalLMOutputWithPast(ModelOutput):
     loss: Optional[torch.FloatTensor] = None
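For context, `ModelOutput` subclasses are plain dataclasses with optional fields that support both attribute and dict-style access. A hedged sketch of the pattern follows; the class name and the extra fields are illustrative assumptions, since only `loss` is visible in the lines above.

from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import ModelOutput

@dataclass
class SamplerOutputSketch(ModelOutput):
    # Hypothetical fields for illustration; only `loss` is confirmed by this diff.
    loss: Optional[torch.FloatTensor] = None
    probs: Optional[torch.FloatTensor] = None        # per-position token distributions
    next_tokens: Optional[torch.LongTensor] = None   # sampled token ids

out = SamplerOutputSketch(next_tokens=torch.tensor([[42]]))
assert out.next_tokens is out["next_tokens"]  # attribute and key access agree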
@@ -105,45 +106,46 @@ def sampler_forward(

         last_accepted_output_tokens (`torch.Tensor`, *optional*):
             Output tokens accepted by the Speculative Decoding Draft Language Model.
-
+
         repetition_penalty_retain_state (`torch.Tensor`, *optional*):
-            RetainedState buffer used as a mask to apply repetition penalty to the input
+            RetainedState buffer used as a mask to apply repetition penalty to the input
             prompt and the output generated so far.
-
+
         repetition_penalties (`torch.Tensor`, *optional*):
-            Sampling parameter that penalizes new tokens based on whether they appear in the
-            prompt and the generated text so far. Values > 1 encourage the model to use
+            Sampling parameter that penalizes new tokens based on whether they appear in the
+            prompt and the generated text so far. Values > 1 encourage the model to use
             new tokens, while values < 1 encourage the model to repeat tokens.
-
+
         presence_penalty_retain_state (`torch.Tensor`, *optional*):
-            RetainedState buffer used as a mask to apply presence penalty to the output
+            RetainedState buffer used as a mask to apply presence penalty to the output
             generated so far.
-
+
         presence_penalties (`torch.Tensor`, *optional*):
-            Sampling parameter that penalizes new tokens based on whether they appear in the
-            generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens.
-
+            Sampling parameter that penalizes new tokens based on whether they appear in the
+            generated text so far. Values > 0 encourage the model to use new tokens, while
+            values < 0 encourage the model to repeat tokens.
+
         temperatures (`torch.Tensor`, *optional*):
-            Sampling parameter that controls the randomness of the sampling. Lower values
-            make the model more deterministic, while higher values make the model more
+            Sampling parameter that controls the randomness of the sampling. Lower values
+            make the model more deterministic, while higher values make the model more
             random. Zero means greedy sampling.
-
+
         top_ks (`torch.Tensor`, *optional*):
             Sampling parameter that controls the number of top tokens to consider.
-
+
         top_ps (`torch.Tensor`, *optional*):
-            Sampling parameter that controls the cumulative probability of the top tokens to
+            Sampling parameter that controls the cumulative probability of the top tokens to
             consider. Must be in (0, 1]. Set to 1.0 to consider all tokens.
-
+
         min_ps (`torch.Tensor`, *optional*):
-            Sampling parameter that represents the minimum probability for a token to be
-            considered, relative to the probability of the most likely token. Must be in
+            Sampling parameter that represents the minimum probability for a token to be
+            considered, relative to the probability of the most likely token. Must be in
             [0, 1]. Set to 0.0 to disable this.
-
+
         random_numbers (`torch.Tensor`, *optional*):
-            Sampling parameter that represents the random seeds to use for random sampling.
+            Sampling parameter that represents the random seeds to use for random sampling.
             Must be in [-1, 1].
-
+
     Returns:

     Example:
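The parameters documented above correspond to standard logit-filtering steps. As a point of reference, here is a minimal, self-contained sketch of top-k, top-p, and min-p filtering for a single row of logits; the function name and structure are illustrative assumptions, not the `sampler_forward` implementation from this diff.

import torch

def filter_logits(logits: torch.Tensor, top_k: int, top_p: float, min_p: float) -> torch.Tensor:
    # Illustrative sketch only (1-D logits); not code from this PR.
    probs = torch.softmax(logits, dim=-1)

    # min-p: drop tokens whose probability is below min_p times the top probability.
    if min_p > 0.0:
        logits[probs < min_p * probs.max()] = float("-inf")

    # top-k: keep only the k highest-scoring tokens.
    if 0 < top_k < logits.numel():
        kth_best = torch.topk(logits, top_k).values[-1]
        logits[logits < kth_best] = float("-inf")

    # top-p: keep the smallest set of tokens whose cumulative probability covers top_p.
    if top_p < 1.0:
        sorted_probs, sorted_idx = torch.sort(torch.softmax(logits, dim=-1), descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        drop = cumulative - sorted_probs > top_p  # always keeps the most likely token
        logits[sorted_idx[drop]] = float("-inf")
    return logits

Temperature would be applied by dividing the logits before the softmax, with zero special-cased as greedy decoding, consistent with the "Zero means greedy sampling" note above.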
@@ -210,12 +212,12 @@ def sampler_forward(

     # Perform Sampling
     batch_size, spec_length, vocab_size = logits.shape
-
+
     # Select relevant rows
     batch_index_reshaped = batch_index.view(-1)
     repetition_penalty_retain_state_selected = torch.index_select(repetition_penalty_retain_state, 0, batch_index_reshaped)
     presence_penalty_retain_state_selected = torch.index_select(presence_penalty_retain_state, 0, batch_index_reshaped)
-
+
     logits = logits.reshape(-1, vocab_size)  # Reshape tensor to 2D

     if input_ids.shape[1] > spec_length:  # Prefill phase, initialize retained states
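As a small illustration of the row selection above (made-up tensors, not the repository's buffers): each entry of `batch_index` picks the retained-state row belonging to one active request, so a buffer sized for the full batch can be reused for any subset of requests.

import torch

retain_state = torch.tensor([[1., 0., 0.],
                             [0., 1., 1.],
                             [1., 1., 0.]])  # (full_batch_size, vocab_size)
batch_index = torch.tensor([[2], [0]])       # two active requests: slots 2 and 0
selected = torch.index_select(retain_state, 0, batch_index.view(-1))
print(selected)  # rows 2 and 0 of retain_state, shape (2, 3)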
@@ -238,16 +240,14 @@ def sampler_forward(
         repetition_penalties = repetition_penalties.repeat(spec_length, vocab_size)  # (batch_size, 1) -> (batch_size * spec_length, vocab_size)
         repetition_penalty_retain_state_selected = repetition_penalty_retain_state_selected.repeat(spec_length, 1)  # (batch_size, vocab_size) -> (batch_size * spec_length, vocab_size)
         repetition_penalties[~repetition_penalty_retain_state_selected.bool()] = 1.0
-        logits = torch.where(
-            logits > 0, logits / repetition_penalties, logits * repetition_penalties
-        )
+        logits = torch.where(logits > 0, logits / repetition_penalties, logits * repetition_penalties)

     # Presence Penalty
     if (presence_penalties != 0.).any():
         presence_penalties = presence_penalties.repeat(spec_length, 1)  # (batch_size, 1) -> (batch_size * spec_length, 1)
         presence_penalty_retain_state_selected = presence_penalty_retain_state_selected.repeat(spec_length, 1)  # (batch_size, vocab_size) -> (batch_size * spec_length, vocab_size)
         logits -= presence_penalties * presence_penalty_retain_state_selected
-
+
     # TODO: Frequency Penalty

     # Temperature Scaling
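To make the two penalty updates concrete, a toy example with invented numbers, independent of the diff: the repetition penalty divides positive logits and multiplies negative ones for tokens already seen, while the presence penalty subtracts a flat amount from every seen token.

import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])   # one row over a 3-token vocabulary
seen = torch.tensor([[1.0, 1.0, 0.0]])      # tokens 0 and 1 already appeared
rep_pen = torch.full_like(logits, 1.5)
rep_pen[~seen.bool()] = 1.0                 # unseen tokens keep a neutral penalty of 1.0

# Repetition penalty: shrink positive logits, push negative logits further down.
logits = torch.where(logits > 0, logits / rep_pen, logits * rep_pen)
# logits is now [[1.3333, -1.5, 0.5]]

# Presence penalty: flat subtraction for every token seen so far.
logits -= 0.2 * seen
# logits is now [[1.1333, -1.7, 0.5]]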
@@ -284,7 +284,7 @@ def sampler_forward(
     probs = torch.softmax(logits, dim=1)  # (batch_size * spec_length, vocab_size)

     # Sample the next tokens
-    # TODO (Optimization): if self.return_pds: skip
+    # TODO (Optimization): if self.return_pds: skip
     greedy_samples = torch.argmax(probs, dim=-1, keepdim=True)  # Greedy Sampling
     gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1)))  # Gumbel-Max Trick
     y = probs + gumbel_noise
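For reference, a minimal sketch of the textbook Gumbel-Max trick named in the comment above: adding Gumbel noise to log-probabilities and taking the argmax yields an exact sample from the categorical distribution. The kernel in this diff instead derives its noise from the precomputed `random_numbers` seeds and adds it to `probs`; the sketch below shows the general formulation, not this PR's implementation.

import torch

def gumbel_max_sample(probs: torch.Tensor) -> torch.Tensor:
    # g = -log(-log(u)) with u ~ Uniform(0, 1); argmax(log p + g) samples from p.
    u = torch.rand(probs.shape).clamp_min(1e-10)  # avoid log(0)
    gumbel = -torch.log(-torch.log(u))
    return torch.argmax(torch.log(probs) + gumbel, dim=-1, keepdim=True)

# Example: probs = torch.tensor([[0.7, 0.2, 0.1]]); gumbel_max_sample(probs)
# returns index 0 roughly 70% of the time over repeated calls.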