220 changes: 136 additions & 84 deletions vllm_ascend/sample/rejection_sampler.py
@@ -3,10 +3,9 @@

import torch
import torch.nn as nn
import vllm.v1.sample.rejection_sampler as rs
from vllm.triton_utils import tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import (RejectionSampler, compute_probs,
generate_uniform_probs)
from vllm.v1.sample.rejection_sampler import RejectionSampler, compute_probs
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

PLACEHOLDER_TOKEN_ID = -1
@@ -192,7 +191,7 @@ def rejection_sample(
)

# Rejection sampling for random sampling requests.
rejection_random_sample_pytorch(
rejection_random_sample_kernel[(batch_size, )](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
@@ -204,7 +203,7 @@
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=draft_probs is None,
NO_DRAFT_PROBS=draft_probs is None,
# num_warps=1,
)
return output_token_ids
@@ -378,59 +377,66 @@ def rejection_greedy_sample_pytorch(
output_token_ids[bonus_rows, bonus_cols] = bonus_token_ids[bonus_rows]


def rejection_random_sample_pytorch(
output_token_ids, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens, # [batch_size]
draft_token_ids, # [num_tokens]
draft_probs, # [num_tokens, vocab_size] or None
target_probs, # [num_tokens, vocab_size]
bonus_token_ids, # [batch_size]
recovered_token_ids, # [num_tokens]
uniform_probs, # [num_tokens]
is_greedy, # [batch_size]
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
bonus_token_ids_ptr, # [batch_size]
recovered_token_ids_ptr, # [num_tokens]
uniform_probs_ptr, # [num_tokens]
is_greedy_ptr, # [batch_size]
max_spec_len,
vocab_size,
IS_NGRAM=False,
NO_DRAFT_PROBS: tl.constexpr,
):
batch_size = output_token_ids.shape[0]

for req_idx in range(batch_size):
if is_greedy[req_idx]:
continue

if req_idx == 0:
start_idx = 0
else:
start_idx = cu_num_draft_tokens[req_idx - 1].item()
end_idx = cu_num_draft_tokens[req_idx].item()
num_draft_tokens = end_idx - start_idx

rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = draft_token_ids[start_idx + pos].item()

if IS_NGRAM:
draft_prob = 1.0
else:
draft_prob = draft_probs[start_idx + pos,
draft_token_id].item()

target_prob = target_probs[start_idx + pos,
draft_token_id].item()
uniform_prob = uniform_probs[start_idx + pos].item()

if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
token_id = draft_token_id
else:
rejected = True
token_id = recovered_token_ids[start_idx + pos].item()

output_token_ids[req_idx, pos] = token_id

if not rejected:
bonus_token_id = bonus_token_ids[req_idx].item()
output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
req_idx = tl.program_id(0)
is_greedy = tl.load(is_greedy_ptr + req_idx)
if is_greedy:
# Early exit for greedy sampling requests.
return

if req_idx == 0:
start_idx = 0
else:
start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx

rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
if NO_DRAFT_PROBS:
draft_prob = 1
else:
draft_prob = tl.load(draft_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
# NOTE(woosuk): While the draft probability should never be 0,
# we check it to avoid NaNs. If it happens to be 0, we reject.
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
# Accept.
token_id = draft_token_id
else:
# Reject. Use recovered token.
rejected = True
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
token_id)
Comment on lines +409 to +432 (Contributor)

high

The loop for pos in range(num_draft_tokens): uses a runtime value, num_draft_tokens, as its bound. This can cause the Triton kernel to recompile for every different value of num_draft_tokens, which can lead to significant performance degradation. To avoid this, loop over a compile-time constant or a do_not_specialize argument such as max_spec_len, which is already configured as such, and guard positions beyond num_draft_tokens with a bounds check.

    for pos in range(max_spec_len):
        if pos < num_draft_tokens:
            if not rejected:
                draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
                if NO_DRAFT_PROBS:
                    draft_prob = 1
                else:
                    draft_prob = tl.load(draft_probs_ptr +
                                         (start_idx + pos) * vocab_size +
                                         draft_token_id)
                target_prob = tl.load(target_probs_ptr +
                                      (start_idx + pos) * vocab_size +
                                      draft_token_id)
                uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
                # NOTE(woosuk): While the draft probability should never be 0,
                # we check it to avoid NaNs. If it happens to be 0, we reject.
                if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
                    # Accept.
                    token_id = draft_token_id
                else:
                    # Reject. Use recovered token.
                    rejected = True
                    token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
                tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                         token_id)


if not rejected:
# If all tokens are accepted, append the bonus token.
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
num_draft_tokens, bonus_token_id)
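For reference, the accept/reject test the kernel applies at each position can be sketched in plain Python (an illustrative sketch, not part of the diff; the function name and the scalar inputs are hypothetical):

    def accept_draft_token(draft_prob: float, target_prob: float,
                           uniform_prob: float) -> bool:
        # Speculative-decoding acceptance test: accept the draft token when
        # uniform_prob <= min(1, target_prob / draft_prob). A zero draft
        # probability is rejected outright to avoid division by zero.
        return draft_prob > 0 and target_prob / draft_prob >= uniform_prob

    # The target model assigns 0.3 to a token the draft proposed with 0.6,
    # so the acceptance ratio is 0.5: accept only when the uniform sample
    # falls at or below 0.5.
    print(accept_draft_token(0.6, 0.3, 0.4))  # True  (accepted)
    print(accept_draft_token(0.6, 0.3, 0.7))  # False (rejected)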


def expand_pytorch(
@@ -458,6 +464,56 @@ def expand_pytorch(
output_ptr[output_slice] = src_val


def generate_uniform_probs(
num_tokens: int,
num_draft_tokens: list[int],
generators: dict[int, torch.Generator],
device: torch.device,
) -> torch.Tensor:
"""
Generates a batch of uniform random samples, with optional per-request
seeding.

This function creates a tensor of shape `(num_tokens, )` filled
with uniform random values in the range [0, 1). Requests that have an
entry in `generators` use their own `torch.Generator` for
reproducibility; samples for the remaining requests are generated
without a seed.

Args:
num_tokens : int
Total number of tokens.
num_draft_tokens : list[int]
Number of draft tokens per request.
generators : dict[int, torch.Generator]
A dictionary mapping indices in the batch to
`torch.Generator` objects.
device : torch.device
The device on which to allocate the tensor.
Returns:
uniform_rand : torch.Tensor
A tensor of shape `(num_tokens, )` containing uniform
random values in the range [0, 1).
"""
uniform_probs = torch.rand(
(num_tokens, ),
dtype=torch.float32,
device=device,
)
start_idx = 0
for req_idx, n in enumerate(num_draft_tokens):
# Do not generate random numbers for requests with no draft tokens.
# This can be important for reproducibility.
if n == 0:
continue
end_idx = start_idx + n
generator = generators.get(req_idx)
if generator is not None:
uniform_probs[start_idx:end_idx].uniform_(generator=generator)
start_idx = end_idx
return uniform_probs
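A minimal usage sketch of generate_uniform_probs (the request layout and seed are hypothetical; request 0 is seeded, request 1 is not):

    import torch

    num_draft_tokens = [2, 3]  # two requests with 2 and 3 draft tokens
    generators = {0: torch.Generator().manual_seed(42)}  # seed request 0 only

    probs = generate_uniform_probs(
        num_tokens=sum(num_draft_tokens),
        num_draft_tokens=num_draft_tokens,
        generators=generators,
        device=torch.device("cpu"),
    )
    # probs[0:2] is drawn from the seeded generator and is reproducible;
    # probs[2:5] is drawn from the default RNG state.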


def sample_recovered_tokens_pytorch(
output_token_ids, # [num_tokens]
cu_num_draft_tokens, # [batch_size]
@@ -468,37 +524,33 @@ def sample_recovered_tokens_pytorch(
vocab_size,
IS_NGRAM=False,
):
batch_size = len(cu_num_draft_tokens)

for req_idx in range(batch_size):
start_idx = 0 if req_idx == 0 else cu_num_draft_tokens[req_idx - 1]
end_idx = cu_num_draft_tokens[req_idx]
num_draft_tokens = end_idx - start_idx

for pos in range(num_draft_tokens):
token_idx = start_idx + pos

if IS_NGRAM:
draft_token_id = draft_token_ids[token_idx]
orig_prob = target_probs[token_idx, draft_token_id].item()
target_probs[token_idx, draft_token_id] = 0
prob = target_probs[token_idx].clone()
else:
draft_p = draft_probs[token_idx].clone()
target_p = target_probs[token_idx].clone()
prob = torch.maximum(target_p - draft_p,
torch.tensor(0.0, device=target_p.device))

q_values = torch.full((vocab_size, ),
float('-inf'),
device=q.device)
q_values[:vocab_size] = q[req_idx, :vocab_size]

recovered_id = torch.argmax(prob / q_values).item()
output_token_ids[token_idx] = recovered_id
if IS_NGRAM:
target_probs[token_idx, draft_token_id] = orig_prob

num_tokens = len(draft_token_ids)
device = output_token_ids.device

diff = torch.diff(cu_num_draft_tokens,
prepend=torch.tensor([0], device=device))
q_value_new = torch.repeat_interleave(q, diff, dim=0)

token_positions = torch.arange(num_tokens, device=device)
if IS_NGRAM:
orig_prob = target_probs[token_positions, draft_token_ids]
target_probs[token_positions, draft_token_ids] = 0
prob = target_probs[token_positions].clone()
else:
draft_p = draft_probs[token_positions].clone()
target_p = target_probs[token_positions].clone()
prob = torch.maximum(target_p - draft_p,
torch.tensor(0.0, device=target_p.device))

q_values = torch.full((num_tokens, vocab_size),
float('-inf'),
device=q.device)
q_values[:vocab_size] = q_value_new[token_positions, :vocab_size]
Comment on lines +546 to +549 (Contributor)

critical

The construction of q_values is incorrect. q_values[:vocab_size] slices the first vocab_size rows of q_values, which has only num_tokens rows: the assignment fails whenever num_tokens exceeds vocab_size, and otherwise silently overwrites the entire tensor, making the torch.full initialization dead code. Since q_value_new already has the desired shape and values, assign it directly and remove the unnecessary initialization with torch.full.

    q_values = q_value_new


recovered_id = torch.argmax(prob / q_values, dim=-1)
output_token_ids[token_positions] = recovered_id.to(
dtype=output_token_ids.dtype)

if IS_NGRAM:
target_probs[token_positions, draft_token_ids] = orig_prob

rs.expand_batch_to_tokens = expand_batch_to_tokens
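The torch.diff/repeat_interleave pair in sample_recovered_tokens_pytorch expands q from one row per request to one row per draft token. A small shape check of that pattern (hypothetical sizes):

    import torch

    cu_num_draft_tokens = torch.tensor([2, 5])  # request 0: 2 tokens, request 1: 3
    q = torch.rand(2, 4)                        # [batch_size, vocab_size]

    counts = torch.diff(cu_num_draft_tokens, prepend=torch.tensor([0]))
    q_per_token = torch.repeat_interleave(q, counts, dim=0)

    assert q_per_token.shape == (5, 4)          # [num_tokens, vocab_size]
    assert torch.equal(q_per_token[0], q[0])    # rows 0-1 copy q[0]
    assert torch.equal(q_per_token[2], q[1])    # rows 2-4 copy q[1]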