
Commit ebcebee

[V1][Spec Decode] Enable spec decode for top-p & top-k sampling (#15063)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent f533b58 commit ebcebee

3 files changed: 219 additions & 19 deletions

tests/v1/sample/test_rejection_sampler.py

Lines changed: 148 additions & 2 deletions
@@ -36,6 +36,8 @@ def create_logits_tensor(output_token_ids: list[list[int]],
 def create_sampling_metadata(
     all_greedy: bool,
     temperature: Optional[torch.Tensor] = None,
+    top_k: Optional[torch.Tensor] = None,
+    top_p: Optional[torch.Tensor] = None,
     generators: Optional[dict[int, Any]] = None,
 ) -> SamplingMetadata:
     """Create a v1 sampling metadata object with all_greedy set
@@ -52,8 +54,8 @@ def create_sampling_metadata(
         temperature=temperature,
         all_greedy=all_greedy,
         all_random=not all_greedy,
-        top_p=None,
-        top_k=None,
+        top_p=top_p,
+        top_k=top_k,
         min_p=torch.empty(1, ),
         generators=generators,
         max_num_logprobs=0,
@@ -462,3 +464,147 @@ def estimate_rejection_sampling_pdf(
                           density=True)
 
     return hist.hist
+
+
+def _test_masked_logits(
+    rejection_sampler,
+    batch_size: int,
+    num_draft_tokens: int,
+    vocab_size: int,
+    target_logits: torch.Tensor,
+    unmasked_indices: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+):
+    # Set up test parameters.
+    num_tokens = batch_size * num_draft_tokens
+
+    # Create random draft probabilities.
+    draft_probs = torch.rand((num_tokens, vocab_size),
+                             dtype=torch.float32,
+                             device=DEVICE)
+    draft_probs = F.softmax(draft_probs, dim=-1)
+
+    # Randomly sample draft token ids from the draft probabilities.
+    draft_token_ids = torch.multinomial(draft_probs, num_samples=1)
+    draft_token_ids = draft_token_ids.reshape(batch_size, num_draft_tokens)
+    draft_token_ids = draft_token_ids.tolist()
+
+    # Bonus tokens are not used here but are required by the sampler.
+    bonus_token_ids = torch.zeros((batch_size, 1),
+                                  dtype=torch.int64,
+                                  device=DEVICE)
+
+    # Create spec decode metadata.
+    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
+        draft_token_ids,
+        device=DEVICE,
+    )
+
+    # Run rejection sampling.
+    output_token_ids = rejection_sampler(
+        spec_decode_metadata,
+        draft_probs=draft_probs,
+        target_logits=target_logits,
+        bonus_token_ids=bonus_token_ids,
+        sampling_metadata=sampling_metadata,
+    )
+
+    # Remove bonus tokens and reshape.
+    output_token_ids = output_token_ids[:, :-1].flatten().tolist()
+
+    # Check that all sampled tokens are within the unmasked indices.
+    for i in range(num_tokens):
+        token_id = output_token_ids[i]
+        if token_id == PLACEHOLDER_TOKEN_ID:
+            continue
+        assert token_id in unmasked_indices[i]
+
+
+@pytest.mark.parametrize("top_k", [1, 5, 99])
+def test_top_k(rejection_sampler, top_k):
+    """Test rejection sampling with top-k sampling."""
+    vocab_size = 100
+    batch_size = 100
+    num_draft_tokens = 3
+    num_tokens = batch_size * num_draft_tokens
+
+    # Randomly create top-k indices.
+    top_k_indices = [
+        torch.randperm(vocab_size, device=DEVICE)[:top_k]
+        for _ in range(num_tokens)
+    ]
+    top_k_indices = torch.stack(top_k_indices)
+
+    # Create logits corresponding to a uniform distribution.
+    target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE)
+
+    # Slightly increase the logits at the top-k indices. If the masking is
+    # effective, the non-top-k indices will never be sampled despite the
+    # small difference in logits.
+    for i in range(num_tokens):
+        target_logits[i, top_k_indices[i]] += 0.1
+
+    # Create sampling metadata.
+    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
+    sampling_metadata = create_sampling_metadata(
+        all_greedy=False,
+        temperature=temperature,
+        top_k=torch.tensor([top_k] * batch_size,
+                           device=DEVICE,
+                           dtype=torch.int64),
+    )
+
+    _test_masked_logits(
+        rejection_sampler,
+        batch_size=batch_size,
+        num_draft_tokens=num_draft_tokens,
+        vocab_size=vocab_size,
+        target_logits=target_logits,
+        unmasked_indices=top_k_indices,
+        sampling_metadata=sampling_metadata,
+    )
+
+
+@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99])
+def test_top_p(rejection_sampler, top_p):
+    """Test rejection sampling with top-p sampling."""
+    vocab_size = 100
+    batch_size = 100
+    num_draft_tokens = 3
+    num_tokens = batch_size * num_draft_tokens
+
+    # Create random target logits.
+    target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE)
+    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
+    rescaled_logits = target_logits / temperature
+
+    logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
+    probs_sort = logits_sort.softmax(dim=-1)
+    probs_sum = probs_sort.cumsum(dim=-1)
+    top_p_mask = probs_sum <= 1 - top_p
+    # Keep at least one token.
+    top_p_mask[:, -1] = False
+
+    # Get the top-p indices.
+    top_p_indices = []
+    for i in range(num_tokens):
+        top_p_indices.append(logits_idx[i][~top_p_mask[i]].tolist())
+
+    # Create sampling metadata.
+    sampling_metadata = create_sampling_metadata(
+        all_greedy=False,
+        temperature=temperature,
+        top_p=torch.tensor([top_p] * batch_size,
+                           device=DEVICE,
+                           dtype=torch.float32),
+    )
+
+    _test_masked_logits(
+        rejection_sampler,
+        batch_size=batch_size,
+        num_draft_tokens=num_draft_tokens,
+        vocab_size=vocab_size,
+        target_logits=target_logits,
+        unmasked_indices=top_p_indices,
+        sampling_metadata=sampling_metadata,
+    )
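
A note on the reference computation in test_top_p above: it sorts the rescaled logits in ascending order, takes the cumulative softmax, and masks out every token whose cumulative probability is still at or below 1 - top_p, so the surviving indices are the smallest set of highest-probability tokens whose total mass reaches top_p. The following standalone sketch (not part of the commit; the 4-token vocabulary and values are made up) walks through that computation:

    import torch

    top_p = 0.5
    logits = torch.tensor([2.0, 1.0, 0.5, -1.0])

    # Sort ascending and accumulate probability mass from the smallest logit up.
    logits_sort, logits_idx = logits.sort(descending=False)
    probs_sum = logits_sort.softmax(dim=-1).cumsum(dim=-1)

    # Drop tokens whose cumulative mass is still <= 1 - top_p;
    # always keep the largest token.
    mask = probs_sum <= 1 - top_p
    mask[-1] = False
    top_p_indices = logits_idx[~mask].tolist()
    print(top_p_indices)  # [0]: token 0 alone (p ~ 0.61) already covers top_p = 0.5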

vllm/v1/sample/rejection_sampler.py

Lines changed: 70 additions & 13 deletions
@@ -8,6 +8,7 @@
 
 from vllm.logger import init_logger
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 from vllm.v1.sample.ops.utils import compiled_softmax
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
@@ -245,25 +246,81 @@ def compute_probs(
         return logits
 
     num_tokens = logits.shape[0]
-    batch_size = cu_num_draft_tokens.shape[0]
-    expanded_temperature = torch.empty(
-        (num_tokens, 1),
-        dtype=torch.float32,
-        device=logits.device,
-    )
-    expand_kernel[(batch_size, )](
-        expanded_temperature,
+    temperature = expand_batch_to_tokens(
         sampling_metadata.temperature,
         cu_num_draft_tokens,
-        GREEDY_TEMPERATURE,  # replace_from
-        1,  # replace_to
-        MAX_NUM_TOKENS=MAX_SPEC_LEN,
-        num_warps=1,
+        num_tokens,
+        replace_from=GREEDY_TEMPERATURE,
+        replace_to=1,
     )
-    output_prob = compiled_softmax(logits, expanded_temperature)
+    # TODO(woosuk): Consider using in-place op to reduce memory usage.
+    logits = logits / temperature.unsqueeze(-1)
+
+    # Get expanded top_k and top_p tensors.
+    top_k = None
+    if sampling_metadata.top_k is not None:
+        top_k = expand_batch_to_tokens(
+            sampling_metadata.top_k,
+            cu_num_draft_tokens,
+            num_tokens,
+        )
+    top_p = None
+    if sampling_metadata.top_p is not None:
+        top_p = expand_batch_to_tokens(
+            sampling_metadata.top_p,
+            cu_num_draft_tokens,
+            num_tokens,
+        )
+
+    # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
+    # which is slow for large vocab sizes. This may cause performance issues.
+    logits = apply_top_k_top_p(logits, top_k, top_p)
+
+    output_prob = compiled_softmax(logits)
     return output_prob
 
 
+def expand_batch_to_tokens(
+    x: torch.Tensor,  # [batch_size]
+    cu_num_tokens: torch.Tensor,  # [batch_size]
+    num_tokens: int,
+    replace_from: int = 0,
+    replace_to: int = 0,
+) -> torch.Tensor:
+    """Expand [batch_size] tensor to [num_tokens] tensor based on the number of
+    tokens per batch in cu_num_tokens.
+
+    For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then
+    num_tokens = 6, and expanded_x = [a, a, b, b, b, c].
+
+    Args:
+        x: [batch_size] tensor to expand.
+        cu_num_tokens: [batch_size] tensor containing the cumulative number of
+            tokens per batch. Each element represents the total number of
+            tokens up to and including that batch.
+        num_tokens: Total number of tokens.
+        replace_from: int = 0
+            Value to be replaced if it is found in x.
+        replace_to: int = 0
+            Value to replace with when replace_from is found.
+    Returns:
+        expanded_x: [num_tokens] tensor.
+    """
+    batch_size = x.shape[0]
+    assert cu_num_tokens.shape[0] == batch_size
+    expanded_x = x.new_empty(num_tokens)
+    expand_kernel[(batch_size, )](
+        expanded_x,
+        x,
+        cu_num_tokens,
+        replace_from,
+        replace_to,
+        MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
+        num_warps=1,
+    )
+    return expanded_x
+
+
 def generate_uniform_probs(
     num_tokens: int,
     num_draft_tokens: list[int],
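
The new expand_batch_to_tokens helper broadcasts one value per request to one value per draft token: per its docstring, x = [a, b, c] with cu_num_tokens = [2, 5, 6] expands to [a, a, b, b, b, c], optionally substituting replace_from with replace_to (used above to map GREEDY_TEMPERATURE to 1). A rough pure-PyTorch equivalent is sketched below; it is not part of the commit (the real implementation uses the Triton expand_kernel), and the function name is made up for illustration:

    import torch

    def expand_batch_to_tokens_reference(x: torch.Tensor,
                                         cu_num_tokens: torch.Tensor,
                                         replace_from: int = 0,
                                         replace_to: int = 0) -> torch.Tensor:
        # Recover per-request token counts from the cumulative counts,
        # e.g. [2, 5, 6] -> [2, 3, 1], then repeat each element that many times.
        counts = torch.diff(cu_num_tokens, prepend=cu_num_tokens.new_zeros(1))
        expanded = torch.repeat_interleave(x, counts)
        # Mirror the kernel's value substitution (e.g. GREEDY_TEMPERATURE -> 1).
        return torch.where(expanded == replace_from,
                           torch.full_like(expanded, replace_to), expanded)

    x = torch.tensor([10.0, 20.0, 30.0])
    cu_num_tokens = torch.tensor([2, 5, 6])
    print(expand_batch_to_tokens_reference(x, cu_num_tokens))
    # tensor([10., 10., 20., 20., 20., 30.])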

vllm/v1/spec_decode/utils.py

Lines changed: 1 addition & 4 deletions
@@ -3,10 +3,7 @@
 
 
 def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
-    if req_id in input_batch.top_k_reqs or req_id in input_batch.top_p_reqs:
-        # Spec decode doesn't support top_p/top_k sampling.
-        return False
-    elif req_id in input_batch.min_p_reqs:
+    if req_id in input_batch.min_p_reqs:
         # Spec decode doesn't support min_p sampling.
         return False
     elif (req_id in input_batch.frequency_penalties_reqs
