@@ -36,6 +36,8 @@ def create_logits_tensor(output_token_ids: list[list[int]],
 def create_sampling_metadata(
     all_greedy: bool,
     temperature: Optional[torch.Tensor] = None,
+    top_k: Optional[torch.Tensor] = None,
+    top_p: Optional[torch.Tensor] = None,
     generators: Optional[dict[int, Any]] = None,
 ) -> SamplingMetadata:
     """Create a v1 sampling metadata object with all_greedy set
@@ -52,8 +54,8 @@ def create_sampling_metadata(
         temperature=temperature,
         all_greedy=all_greedy,
         all_random=not all_greedy,
-        top_p=None,
-        top_k=None,
+        top_p=top_p,
+        top_k=top_k,
         min_p=torch.empty(1, ),
         generators=generators,
         max_num_logprobs=0,
@@ -462,3 +464,147 @@ def estimate_rejection_sampling_pdf(
                            density=True)
 
     return hist.hist
+
+
+def _test_masked_logits(
+    rejection_sampler,
+    batch_size: int,
+    num_draft_tokens: int,
+    vocab_size: int,
+    target_logits: torch.Tensor,
+    unmasked_indices: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+):
+    # Set up test parameters
+    num_tokens = batch_size * num_draft_tokens
+
+    # Create random draft probabilities.
+    draft_probs = torch.rand((num_tokens, vocab_size),
+                             dtype=torch.float32,
+                             device=DEVICE)
+    draft_probs = F.softmax(draft_probs, dim=-1)
+
+    # Randomly sample draft token ids from draft probs
+    draft_token_ids = torch.multinomial(draft_probs, num_samples=1)
+    draft_token_ids = draft_token_ids.reshape(batch_size, num_draft_tokens)
+    draft_token_ids = draft_token_ids.tolist()
+
+    # Bonus tokens not used but required
+    bonus_token_ids = torch.zeros((batch_size, 1),
+                                  dtype=torch.int64,
+                                  device=DEVICE)
+
+    # Create spec decode metadata
+    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
+        draft_token_ids,
+        device=DEVICE,
+    )
+
+    # Run rejection sampling
+    output_token_ids = rejection_sampler(
+        spec_decode_metadata,
+        draft_probs=draft_probs,
+        target_logits=target_logits,
+        bonus_token_ids=bonus_token_ids,
+        sampling_metadata=sampling_metadata,
+    )
+
+    # Remove bonus tokens and reshape
+    output_token_ids = output_token_ids[:, :-1].flatten().tolist()
+
+    # Check that all sampled tokens are within the unmasked indices.
+    for i in range(num_tokens):
+        token_id = output_token_ids[i]
+        if token_id == PLACEHOLDER_TOKEN_ID:
+            continue
+        assert token_id in unmasked_indices[i]
+
+
+@pytest.mark.parametrize("top_k", [1, 5, 99])
+def test_top_k(rejection_sampler, top_k):
+    """Test rejection sampling with top-k sampling"""
+    vocab_size = 100
+    batch_size = 100
+    num_draft_tokens = 3
+    num_tokens = batch_size * num_draft_tokens
+
+    # Randomly create top-k indices.
+    top_k_indices = [
+        torch.randperm(vocab_size, device=DEVICE)[:top_k]
+        for _ in range(num_tokens)
+    ]
+    top_k_indices = torch.stack(top_k_indices)
+
+    # Create logits with the uniform distribution.
+    target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE)
+
+    # Increment the logits at the top-k indices so they are slightly larger
+    # than the rest. If the masking is effective, the non-top-k indices will
+    # never be sampled despite the small difference in logits.
+    for i in range(num_tokens):
+        target_logits[i, top_k_indices[i]] += 0.1
+
+    # Create sampling metadata
+    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
+    sampling_metadata = create_sampling_metadata(
+        all_greedy=False,
+        temperature=temperature,
+        top_k=torch.tensor([top_k] * batch_size,
+                           device=DEVICE,
+                           dtype=torch.int64),
+    )
+
+    _test_masked_logits(
+        rejection_sampler,
+        batch_size=batch_size,
+        num_draft_tokens=num_draft_tokens,
+        vocab_size=vocab_size,
+        target_logits=target_logits,
+        unmasked_indices=top_k_indices,
+        sampling_metadata=sampling_metadata,
+    )
+
+
+@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99])
+def test_top_p(rejection_sampler, top_p):
+    """Test rejection sampling with top-p sampling"""
+    vocab_size = 100
+    batch_size = 100
+    num_draft_tokens = 3
+    num_tokens = batch_size * num_draft_tokens
+
+    # Create random target logits.
+    target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE)
+    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
+    rescaled_logits = target_logits / temperature
+
+    logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
+    probs_sort = logits_sort.softmax(dim=-1)
+    probs_sum = probs_sort.cumsum(dim=-1)
+    top_p_mask = probs_sum <= 1 - top_p
+    # Always keep at least one token.
+    top_p_mask[:, -1] = False
+
+    # Get the top-p indices.
+    top_p_indices = []
+    for i in range(num_tokens):
+        top_p_indices.append(logits_idx[i][~top_p_mask[i]].tolist())
+
+    # Create sampling metadata
+    sampling_metadata = create_sampling_metadata(
+        all_greedy=False,
+        temperature=temperature,
+        top_p=torch.tensor([top_p] * batch_size,
+                           device=DEVICE,
+                           dtype=torch.float32),
+    )
+
+    _test_masked_logits(
+        rejection_sampler,
+        batch_size=batch_size,
+        num_draft_tokens=num_draft_tokens,
+        vocab_size=vocab_size,
+        target_logits=target_logits,
+        unmasked_indices=top_p_indices,
+        sampling_metadata=sampling_metadata,
+    )