@@ -200,6 +200,69 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
             assert torch.equal(results[j][i], results[0][i])


+@pytest.mark.parametrize("k", [1, 3, 6])
+@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
+@pytest.mark.parametrize("batch_size", [3, 8, 32, 128])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
+@torch.inference_mode()
+def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
+                            device: str, use_flashinfer: bool):
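+    """A batch mixing seeded and unseeded requests must produce exactly
+    the same tokens as running each request in its own batch of one."""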
+    torch.set_default_device(device)
+    set_random_seed(0)
+    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
+    bonus_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, 1),
+                                    dtype=torch.int64)
+    draft_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, k),
+                                    dtype=torch.int64)
+
+    single_batches = []
+    for i in range(batch_size):
+        single_batches.append((draft_probs[i].clone().unsqueeze(0),
+                               draft_token_ids[i].clone().unsqueeze(0),
+                               target_probs[i].clone().unsqueeze(0),
+                               bonus_token_ids[i].clone().unsqueeze(0),
+                               draft_token_ids[i].clone().unsqueeze(0)))
+
+    set_random_seed(0)
+    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
+    rejection_sampler.init_gpu_tensors(device=device)
+
+    results = []
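+    # Seed every sequence except index 0, which stays unseeded and so
+    # exercises the default-generator path within the same batch.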
+    seeded_seqs = {
+        i: torch.Generator(device=device).manual_seed(i)
+        for i in range(1, batch_size)  # 0 is seed None
+    }
+    batch_result = rejection_sampler(target_probs.clone(),
+                                     bonus_token_ids.clone(),
+                                     draft_probs.clone(),
+                                     draft_token_ids.clone(), seeded_seqs)
+
+    set_random_seed(0)
+
+    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
+    rejection_sampler.init_gpu_tensors(device=device)
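+    # Re-run each request as a batch of one, rebuilding its generator from
+    # the same seed; the tokens must match the mixed-batch run exactly.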
+    for i in range(batch_size):
+        request_seeded_seqs = {
+            0: torch.Generator(device=device).manual_seed(i)
+        } if seeded_seqs.get(i) is not None else None
+        (draft_probs, draft_token_ids, target_probs, bonus_token_ids,
+         draft_token_ids) = single_batches[i]
+        results.append(
+            rejection_sampler(target_probs, bonus_token_ids, draft_probs,
+                              draft_token_ids, request_seeded_seqs))
+    for i in range(batch_size):
+        assert torch.equal(batch_result[i], results[i].squeeze(0))
+
+
 @pytest.mark.parametrize("k", [1, 3, 6])
 @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
 @pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
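For context on the pattern this test exercises: a per-request torch.Generator passed into sampling makes the drawn tokens reproducible for that request, independent of the global RNG stream. Below is a minimal sketch of that behavior using only stock PyTorch (torch.Generator and torch.multinomial); the sample_with_optional_seed helper is a hypothetical name for illustration, not part of vLLM's API.

import torch
from typing import Optional

def sample_with_optional_seed(probs: torch.Tensor,
                              seed: Optional[int]) -> torch.Tensor:
    # A seeded per-request generator isolates this request's randomness,
    # mirroring how `seeded_seqs` maps batch indices to generators above.
    gen = (torch.Generator(device=probs.device).manual_seed(seed)
           if seed is not None else None)
    return torch.multinomial(probs, num_samples=1, generator=gen)

probs = torch.rand(4, 32_000).softmax(dim=-1)
# Same seed, same tokens, regardless of what the default RNG did in between.
assert torch.equal(sample_with_optional_seed(probs, seed=7),
                   sample_with_optional_seed(probs, seed=7))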