@@ -147,16 +147,25 @@ def rejection_sample(
     if not sampling_metadata.all_random:
         # Rejection sampling for greedy sampling requests.
         target_argmax = target_probs.argmax(dim=-1)
-        rejection_greedy_sample_pytorch(
-            output_token_ids,
-            cu_num_draft_tokens,
-            draft_token_ids,
-            target_argmax,
-            bonus_token_ids,
-            is_greedy,
-            max_spec_len,
-            # num_warps=1,
-        )
+        if min(num_draft_tokens) == 1 and max(
+                num_draft_tokens) == 1 and sampling_metadata.all_greedy:
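+            # Fast path: every request is greedy and proposes exactly one
+            # draft token, so accept/reject reduces to a single comparison.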
+            rejection_greedy_sample_spec_len_1_pytorch(
+                output_token_ids,
+                draft_token_ids,
+                target_argmax,
+                bonus_token_ids,
+            )
+        else:
+            rejection_greedy_sample_pytorch(
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                target_argmax,
+                bonus_token_ids,
+                num_draft_tokens,
+                max_spec_len,
+                is_greedy,
+            )

     if sampling_metadata.all_greedy:
         return output_token_ids
@@ -284,47 +293,89 @@ def sample_recovered_tokens(
     return recovered_token_ids


-def rejection_greedy_sample_pytorch(
-        output_token_ids,  # [batch_size, max_spec_len + 1]
-        cu_num_draft_tokens,  # [batch_size]
-        draft_token_ids,  # [num_tokens]
-        target_argmax,  # [num_tokens]
-        bonus_token_ids,  # [batch_size]
-        is_greedy=None,  # [batch_size] or None
-        max_spec_len=None,
+def rejection_greedy_sample_spec_len_1_pytorch(
+        output_token_ids,  # [batch_size, 2]
+        draft_token_ids,  # [num_tokens]
+        target_argmax,  # [num_tokens]
+        bonus_token_ids,  # [batch_size, 1]
 ):
-    batch_size = output_token_ids.shape[0]
-
-    if is_greedy is None:
-        is_greedy = torch.ones(batch_size,
-                               dtype=torch.bool,
-                               device=output_token_ids.device)
-
-    for req_idx in range(batch_size):
-        if not is_greedy[req_idx]:
-            continue
-
-        if req_idx == 0:
-            start_idx = 0
-        else:
-            start_idx = cu_num_draft_tokens[req_idx - 1].item()
-        end_idx = cu_num_draft_tokens[req_idx].item()
-        num_draft_tokens = end_idx - start_idx
-
-        rejected = False
-        for pos in range(num_draft_tokens):
-            if not rejected:
-                draft_token_id = draft_token_ids[start_idx + pos].item()
-                target_argmax_id = target_argmax[start_idx + pos].item()
-
-                output_token_ids[req_idx, pos] = target_argmax_id
+    batch_size = output_token_ids.size(0)
+    num_tokens = draft_token_ids.size(0)
+    assert batch_size == num_tokens
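+    # A draft token is accepted iff it matches the target model's argmax.
+    # Position 0 always receives the target token, which is either the
+    # accepted draft token or its correction.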
+    accept_req_mask = draft_token_ids == target_argmax
+    output_token_ids[:, 0] = target_argmax
+    bonus_token_ids = bonus_token_ids.squeeze(1)
+    output_token_ids[accept_req_mask, 1] = bonus_token_ids[accept_req_mask]
-
-            if draft_token_id != target_argmax_id:
-                rejected = True
-
-        if not rejected:
-            bonus_token_id = bonus_token_ids[req_idx].item()
-            output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
+
+
+def rejection_greedy_sample_pytorch(
+        output_token_ids,  # [batch_size, max_spec_len + 1]
+        cu_num_draft_tokens,  # [batch_size]
+        draft_token_ids,  # [num_tokens]
+        target_argmax,  # [num_tokens]
+        bonus_token_ids,  # [batch_size, 1]
+        draft_tokens_per_req,  # [batch_size], list
+        max_spec_len,
+        is_greedy=None,  # [batch_size] or None
+):
+    batch_size = output_token_ids.size(0)
+    num_tokens = draft_token_ids.size(0)
+    device = output_token_ids.device
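+    # draft_tokens_per_req arrives as a Python list; materialize it as a
+    # tensor on the output device so it can drive vectorized indexing.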
+    draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
+        device, non_blocking=True)
+    if is_greedy is None:
+        is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)
+
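+    # Recover, for every flattened draft token, which request it belongs to
+    # and its position within that request.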
+    start_indices = cu_num_draft_tokens - draft_tokens_per_req
+    req_ids = torch.arange(batch_size, device=device)
+    token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
+    token_positions = torch.arange(
+        num_tokens, device=device) - start_indices[token_req_ids]
+
+    # Find the first mismatch position of each request.
+    mismatch_global = (draft_token_ids != target_argmax)
+    if max_spec_len == 0:
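+        # No draft tokens anywhere in the batch: nothing can be rejected,
+        # and each request only receives its bonus token at position 0.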
+        first_mismatch_pos_per_req = torch.zeros(batch_size,
+                                                 dtype=torch.long,
+                                                 device=device)
+    else:
+        # [bs, max_spec_len]
+        pos_matrix = torch.full((batch_size, max_spec_len),
+                                -1,
+                                dtype=torch.long,
+                                device=device)
+        pos_matrix[token_req_ids, token_positions] = token_positions
+        mismatch_matrix = torch.full((batch_size, max_spec_len),
+                                     False,
+                                     dtype=torch.bool,
+                                     device=device)
+        mismatch_matrix[token_req_ids, token_positions] = mismatch_global
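+        # max_spec_len * 2 acts as an out-of-range sentinel for positions
+        # with no mismatch, so torch.min returns the earliest mismatch.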
+        mismatch_positions = torch.where(mismatch_matrix, pos_matrix,
+                                         max_spec_len * 2)
+        first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
+        no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
+        first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
+            no_mismatch_mask]
+
+    # Copy matched target tokens into output.
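+    # The +1 also copies the corrected target token at the first mismatch
+    # position; fully accepted requests are capped at their draft length.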
+    copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
+                             draft_tokens_per_req)
+    copy_indices = torch.arange(max_spec_len + 1,
+                                device=device).expand(batch_size, -1)
+    copy_mask = copy_indices < copy_len.unsqueeze(1)
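+    # Only greedy requests are written here; rows for random-sampling
+    # requests are left untouched for the random rejection path.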
+    greedy_mask = is_greedy.unsqueeze(1)
+    final_copy_mask = copy_mask & greedy_mask
+    global_idx = start_indices.unsqueeze(1) + copy_indices
+    output_token_ids[final_copy_mask] = target_argmax[
+        global_idx[final_copy_mask]].to(output_token_ids.dtype)
+    # Fill bonus token.
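+    # A request keeps its bonus token only if all of its draft tokens were
+    # accepted, i.e. no mismatch occurred before the end of its draft.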
+    needs_bonus = is_greedy & (first_mismatch_pos_per_req
+                               >= draft_tokens_per_req)
+    if torch.any(needs_bonus):
+        bonus_rows = torch.where(needs_bonus)[0]
+        bonus_cols = draft_tokens_per_req[bonus_rows]
+        bonus_token_ids = bonus_token_ids.squeeze(1)
+        output_token_ids[bonus_rows, bonus_cols] = bonus_token_ids[bonus_rows]


 def rejection_random_sample_pytorch(