[Refactor] optimize operators in rejection sampler #3301
base: main
@@ -3,10 +3,9 @@
 import torch
 import torch.nn as nn
 import vllm.v1.sample.rejection_sampler as rs
 from vllm.triton_utils import tl, triton
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.rejection_sampler import (RejectionSampler, compute_probs,
-                                              generate_uniform_probs)
+from vllm.v1.sample.rejection_sampler import RejectionSampler, compute_probs
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

 PLACEHOLDER_TOKEN_ID = -1
@@ -192,7 +191,7 @@ def rejection_sample(
     )

     # Rejection sampling for random sampling requests.
-    rejection_random_sample_pytorch(
+    rejection_random_sample_kernel[(batch_size, )](
         output_token_ids,
         cu_num_draft_tokens,
         draft_token_ids,
@@ -204,7 +203,7 @@ def rejection_sample(
         is_greedy,
         max_spec_len,
         vocab_size,
-        IS_NGRAM=draft_probs is None,
+        NO_DRAFT_PROBS=draft_probs is None,
         # num_warps=1,
     )
     return output_token_ids
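For readers unfamiliar with Triton, the change above swaps an ordinary Python call for a kernel launch: the subscript `[(batch_size, )]` is the launch grid, running one program instance per request, and `NO_DRAFT_PROBS` is a `tl.constexpr` flag baked in at compile time. A minimal sketch of the two call shapes (arguments elided; illustrative only, not the full call from the diff):

```python
# Before: a plain PyTorch helper that loops over the batch on the host.
rejection_random_sample_pytorch(output_token_ids, cu_num_draft_tokens, ...)

# After: a Triton kernel launched with a 1-D grid of batch_size programs;
# each program handles one request, and NO_DRAFT_PROBS selects the
# no-draft-probs code path at compile time because it is tl.constexpr.
rejection_random_sample_kernel[(batch_size, )](
    output_token_ids,
    cu_num_draft_tokens,
    ...,
    NO_DRAFT_PROBS=draft_probs is None,
)
```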
@@ -378,59 +377,66 @@ def rejection_greedy_sample_pytorch(
     output_token_ids[bonus_rows, bonus_cols] = bonus_token_ids[bonus_rows]


-def rejection_random_sample_pytorch(
-    output_token_ids,  # [batch_size, max_spec_len + 1]
-    cu_num_draft_tokens,  # [batch_size]
-    draft_token_ids,  # [num_tokens]
-    draft_probs,  # [num_tokens, vocab_size] or None
-    target_probs,  # [num_tokens, vocab_size]
-    bonus_token_ids,  # [batch_size]
-    recovered_token_ids,  # [num_tokens]
-    uniform_probs,  # [num_tokens]
-    is_greedy,  # [batch_size]
-    max_spec_len,
-    vocab_size,
-    IS_NGRAM=False,
-):
-    batch_size = output_token_ids.shape[0]
-
-    for req_idx in range(batch_size):
-        if is_greedy[req_idx]:
-            continue
-
-        if req_idx == 0:
-            start_idx = 0
-        else:
-            start_idx = cu_num_draft_tokens[req_idx - 1].item()
-        end_idx = cu_num_draft_tokens[req_idx].item()
-        num_draft_tokens = end_idx - start_idx
-
-        rejected = False
-        for pos in range(num_draft_tokens):
-            if not rejected:
-                draft_token_id = draft_token_ids[start_idx + pos].item()
-
-                if IS_NGRAM:
-                    draft_prob = 1.0
-                else:
-                    draft_prob = draft_probs[start_idx + pos,
-                                             draft_token_id].item()
-
-                target_prob = target_probs[start_idx + pos,
-                                           draft_token_id].item()
-                uniform_prob = uniform_probs[start_idx + pos].item()
-
-                if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
-                    token_id = draft_token_id
-                else:
-                    rejected = True
-                    token_id = recovered_token_ids[start_idx + pos].item()
-
-                output_token_ids[req_idx, pos] = token_id
-
-        if not rejected:
-            bonus_token_id = bonus_token_ids[req_idx].item()
-            output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_random_sample_kernel(
+    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+    cu_num_draft_tokens_ptr,  # [batch_size]
+    draft_token_ids_ptr,  # [num_tokens]
+    draft_probs_ptr,  # [num_tokens, vocab_size] or None
+    target_probs_ptr,  # [num_tokens, vocab_size]
+    bonus_token_ids_ptr,  # [batch_size]
+    recovered_token_ids_ptr,  # [num_tokens]
+    uniform_probs_ptr,  # [num_tokens]
+    is_greedy_ptr,  # [batch_size]
+    max_spec_len,
+    vocab_size,
+    NO_DRAFT_PROBS: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    is_greedy = tl.load(is_greedy_ptr + req_idx)
+    if is_greedy:
+        # Early exit for greedy sampling requests.
+        return
+
+    if req_idx == 0:
+        start_idx = 0
+    else:
+        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
+    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    num_draft_tokens = end_idx - start_idx
+
+    rejected = False
+    for pos in range(num_draft_tokens):
+        if not rejected:
+            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+            if NO_DRAFT_PROBS:
+                draft_prob = 1
+            else:
+                draft_prob = tl.load(draft_probs_ptr +
+                                     (start_idx + pos) * vocab_size +
+                                     draft_token_id)
+            target_prob = tl.load(target_probs_ptr +
+                                  (start_idx + pos) * vocab_size +
+                                  draft_token_id)
+            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
+            # NOTE(woosuk): While the draft probability should never be 0,
+            # we check it to avoid NaNs. If it happens to be 0, we reject.
+            if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
+                # Accept.
+                token_id = draft_token_id
+            else:
+                # Reject. Use recovered token.
+                rejected = True
+                token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
+            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
+                     token_id)
+
+    if not rejected:
+        # If all tokens are accepted, append the bonus token.
+        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+        tl.store(
+            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+            num_draft_tokens, bonus_token_id)
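Each kernel program applies the standard speculative-decoding acceptance test to its own request's draft tokens. As a plain-Python sketch of the per-token rule the loop above implements (the function name here is chosen for illustration, not taken from the PR):

```python
def accept_draft_token(target_prob: float, draft_prob: float,
                       uniform_sample: float) -> bool:
    # Accept the draft token with probability min(1, target_prob / draft_prob):
    # comparing the ratio against a uniform sample in [0, 1) realizes that
    # probability. A zero draft probability counts as a rejection, matching
    # the NaN guard in the kernel above.
    return draft_prob > 0 and target_prob / draft_prob >= uniform_sample
```

Once a position is rejected, the remaining draft tokens are discarded and the recovered token is written at that position; the bonus token is appended only if every draft token was accepted.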

@@ -458,6 +464,56 @@ def expand_pytorch(
         output_ptr[output_slice] = src_val

+def generate_uniform_probs(
+    num_tokens: int,
+    num_draft_tokens: list[int],
+    generators: dict[int, torch.Generator],
+    device: torch.device,
+) -> torch.Tensor:
+    """
+    Generates a batch of uniform random samples, with optional seeding
+    if available.
+
+    This method creates a tensor of shape `(num_tokens, )` filled with
+    uniform random values in the range [0, 1). If `generators` is provided,
+    the requests with their own seeds will use the provided `torch.Generator`
+    for reproducibility. The samples for the other requests will be generated
+    without a seed.
+
+    Args:
+        num_tokens : int
+            Total number of tokens.
+        num_draft_tokens : list[int]
+            Number of draft tokens per request.
+        generators : dict[int, torch.Generator]
+            A dictionary mapping indices in the batch to
+            `torch.Generator` objects.
+        device : torch.device
+            The device on which to allocate the tensor.
+    Returns:
+        uniform_rand : torch.Tensor
+            A tensor of shape `(num_tokens, )` containing uniform
+            random values in the range [0, 1).
+    """
+    uniform_probs = torch.rand(
+        (num_tokens, ),
+        dtype=torch.float32,
+        device=device,
+    )
+    start_idx = 0
+    for req_idx, n in enumerate(num_draft_tokens):
+        # Do not generate random numbers for requests with no draft tokens.
+        # This can be important for reproducibility.
+        if n == 0:
+            continue
+        end_idx = start_idx + n
+        generator = generators.get(req_idx)
+        if generator is not None:
+            uniform_probs[start_idx:end_idx].uniform_(generator=generator)
+        start_idx = end_idx
+    return uniform_probs
+
+
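A hypothetical usage of the helper added above, showing how per-request seeding interacts with the flat output (the shapes and seed here are made up for illustration):

```python
import torch

# Two requests with 2 and 3 draft tokens; request 0 carries its own seed.
generators = {0: torch.Generator(device="cpu").manual_seed(42)}
probs = generate_uniform_probs(
    num_tokens=5,
    num_draft_tokens=[2, 3],
    generators=generators,
    device=torch.device("cpu"),
)
# probs[:2] is reproducible across runs (seeded generator); probs[2:] is not.
```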
 def sample_recovered_tokens_pytorch(
     output_token_ids,  # [num_tokens]
     cu_num_draft_tokens,  # [batch_size]

@@ -468,37 +524,33 @@ def sample_recovered_tokens_pytorch(
     vocab_size,
     IS_NGRAM=False,
 ):
-    batch_size = len(cu_num_draft_tokens)
-
-    for req_idx in range(batch_size):
-        start_idx = 0 if req_idx == 0 else cu_num_draft_tokens[req_idx - 1]
-        end_idx = cu_num_draft_tokens[req_idx]
-        num_draft_tokens = end_idx - start_idx
-
-        for pos in range(num_draft_tokens):
-            token_idx = start_idx + pos
-
-            if IS_NGRAM:
-                draft_token_id = draft_token_ids[token_idx]
-                orig_prob = target_probs[token_idx, draft_token_id].item()
-                target_probs[token_idx, draft_token_id] = 0
-                prob = target_probs[token_idx].clone()
-            else:
-                draft_p = draft_probs[token_idx].clone()
-                target_p = target_probs[token_idx].clone()
-                prob = torch.maximum(target_p - draft_p,
-                                     torch.tensor(0.0, device=target_p.device))
-
-            q_values = torch.full((vocab_size, ),
-                                  float('-inf'),
-                                  device=q.device)
-            q_values[:vocab_size] = q[req_idx, :vocab_size]
-
-            recovered_id = torch.argmax(prob / q_values).item()
-            output_token_ids[token_idx] = recovered_id
-
-            if IS_NGRAM:
-                target_probs[token_idx, draft_token_id] = orig_prob
+    num_tokens = len(draft_token_ids)
+    device = output_token_ids.device
+
+    diff = torch.diff(cu_num_draft_tokens,
+                      prepend=torch.tensor([0], device=device))
+    q_value_new = torch.repeat_interleave(q, diff, dim=0)
+
+    token_positions = torch.arange(num_tokens, device=device)
+
+    if IS_NGRAM:
+        orig_prob = target_probs[token_positions, draft_token_ids]
+        target_probs[token_positions, draft_token_ids] = 0
+        prob = target_probs[token_positions].clone()
+    else:
+        draft_p = draft_probs[token_positions].clone()
+        target_p = target_probs[token_positions].clone()
+        prob = torch.maximum(target_p - draft_p,
+                             torch.tensor(0.0, device=target_p.device))
+
+    q_values = torch.full((num_tokens, vocab_size),
+                          float('-inf'),
+                          device=q.device)
+    q_values[:vocab_size] = q_value_new[token_positions, :vocab_size]

Comment on lines +546 to +549:
The construction of q_values = q_value_new
+
+    recovered_id = torch.argmax(prob / q_values, dim=-1)
+    output_token_ids[token_positions] = recovered_id.to(
+        dtype=output_token_ids.dtype)
+
+    if IS_NGRAM:
+        target_probs[token_positions, draft_token_ids] = orig_prob
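One non-obvious piece of the vectorized path is `torch.argmax(prob / q_values, dim=-1)`. Assuming `q` holds independent Exponential(1) samples per token (the parameter itself is not shown in this hunk), dividing the probabilities by that noise and taking the argmax draws an index with probability proportional to `prob`; this is the exponential-race form of the Gumbel-max trick. A standalone sketch:

```python
import torch

# Exponential-race sampling: with e_i ~ Exp(1), argmax_i(p_i / e_i) is a
# sample from the categorical distribution proportional to p_i.
p = torch.tensor([0.1, 0.7, 0.2])
e = torch.empty_like(p).exponential_(1.0)
idx = torch.argmax(p / e).item()
```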

Comment:
The loop `for pos in range(num_draft_tokens):` uses a runtime value, `num_draft_tokens`, as its bound. This will cause the Triton kernel to recompile for every different value of `num_draft_tokens`, which can lead to significant performance degradation. To avoid this, you should loop over a compile-time constant or a `do_not_specialize` argument, like `max_spec_len`, which is already configured as such, and use a mask to handle the variable number of tokens.
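A minimal sketch of the masked-loop pattern described above, reusing the kernel's existing arguments (only the inner loop is shown; this is illustrative, not code from the PR):

```python
    rejected = False
    # Bound the loop by max_spec_len (already marked do_not_specialize) and
    # mask out positions beyond this request's num_draft_tokens.
    for pos in range(max_spec_len):
        if not rejected and pos < num_draft_tokens:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            if NO_DRAFT_PROBS:
                draft_prob = 1
            else:
                draft_prob = tl.load(draft_probs_ptr +
                                     (start_idx + pos) * vocab_size +
                                     draft_token_id)
            target_prob = tl.load(target_probs_ptr +
                                  (start_idx + pos) * vocab_size +
                                  draft_token_id)
            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
            if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
                token_id = draft_token_id
            else:
                rejected = True
                token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                     token_id)
```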