diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py
index e0d770df26..124e8c831f 100644
--- a/vllm_ascend/sample/rejection_sampler.py
+++ b/vllm_ascend/sample/rejection_sampler.py
@@ -3,10 +3,9 @@
 
 import torch
 import torch.nn as nn
-import vllm.v1.sample.rejection_sampler as rs
+from vllm.triton_utils import tl, triton
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.rejection_sampler import (RejectionSampler, compute_probs,
-                                              generate_uniform_probs)
+from vllm.v1.sample.rejection_sampler import RejectionSampler, compute_probs
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
 PLACEHOLDER_TOKEN_ID = -1
@@ -192,7 +191,7 @@ def rejection_sample(
     )
 
     # Rejection sampling for random sampling requests.
-    rejection_random_sample_pytorch(
+    rejection_random_sample_kernel[(batch_size, )](
        output_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
@@ -204,7 +203,7 @@ def rejection_sample(
        is_greedy,
        max_spec_len,
        vocab_size,
-        IS_NGRAM=draft_probs is None,
+        NO_DRAFT_PROBS=draft_probs is None,
        # num_warps=1,
     )
     return output_token_ids
@@ -378,59 +377,66 @@ def rejection_greedy_sample_pytorch(
     output_token_ids[bonus_rows, bonus_cols] = bonus_token_ids[bonus_rows]
 
 
-def rejection_random_sample_pytorch(
-    output_token_ids,  # [batch_size, max_spec_len + 1]
-    cu_num_draft_tokens,  # [batch_size]
-    draft_token_ids,  # [num_tokens]
-    draft_probs,  # [num_tokens, vocab_size] or None
-    target_probs,  # [num_tokens, vocab_size]
-    bonus_token_ids,  # [batch_size]
-    recovered_token_ids,  # [num_tokens]
-    uniform_probs,  # [num_tokens]
-    is_greedy,  # [batch_size]
+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_random_sample_kernel(
+    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+    cu_num_draft_tokens_ptr,  # [batch_size]
+    draft_token_ids_ptr,  # [num_tokens]
+    draft_probs_ptr,  # [num_tokens, vocab_size] or None
+    target_probs_ptr,  # [num_tokens, vocab_size]
+    bonus_token_ids_ptr,  # [batch_size]
+    recovered_token_ids_ptr,  # [num_tokens]
+    uniform_probs_ptr,  # [num_tokens]
+    is_greedy_ptr,  # [batch_size]
     max_spec_len,
     vocab_size,
-    IS_NGRAM=False,
+    NO_DRAFT_PROBS: tl.constexpr,
 ):
-    batch_size = output_token_ids.shape[0]
-
-    for req_idx in range(batch_size):
-        if is_greedy[req_idx]:
-            continue
-
-        if req_idx == 0:
-            start_idx = 0
-        else:
-            start_idx = cu_num_draft_tokens[req_idx - 1].item()
-        end_idx = cu_num_draft_tokens[req_idx].item()
-        num_draft_tokens = end_idx - start_idx
-
-        rejected = False
-        for pos in range(num_draft_tokens):
-            if not rejected:
-                draft_token_id = draft_token_ids[start_idx + pos].item()
-
-                if IS_NGRAM:
-                    draft_prob = 1.0
-                else:
-                    draft_prob = draft_probs[start_idx + pos,
-                                             draft_token_id].item()
-
-                target_prob = target_probs[start_idx + pos,
-                                           draft_token_id].item()
-                uniform_prob = uniform_probs[start_idx + pos].item()
-
-                if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
-                    token_id = draft_token_id
-                else:
-                    rejected = True
-                    token_id = recovered_token_ids[start_idx + pos].item()
-
-                output_token_ids[req_idx, pos] = token_id
+    req_idx = tl.program_id(0)
+    is_greedy = tl.load(is_greedy_ptr + req_idx)
+    if is_greedy:
+        # Early exit for greedy sampling requests.
+        return
+
+    if req_idx == 0:
+        start_idx = 0
+    else:
+        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
+    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    num_draft_tokens = end_idx - start_idx
 
     rejected = False
     for pos in range(num_draft_tokens):
         if not rejected:
-            bonus_token_id = bonus_token_ids[req_idx].item()
-            output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
+            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+            if NO_DRAFT_PROBS:
+                draft_prob = 1
+            else:
+                draft_prob = tl.load(draft_probs_ptr +
+                                     (start_idx + pos) * vocab_size +
+                                     draft_token_id)
+            target_prob = tl.load(target_probs_ptr +
+                                  (start_idx + pos) * vocab_size +
+                                  draft_token_id)
+            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
+            # NOTE(woosuk): While the draft probability should never be 0,
+            # we check it to avoid NaNs. If it happens to be 0, we reject.
+            if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
+                # Accept.
+                token_id = draft_token_id
+            else:
+                # Reject. Use recovered token.
+                rejected = True
+                token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
+            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
+                     token_id)
+
+    if not rejected:
+        # If all tokens are accepted, append the bonus token.
+        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+        tl.store(
+            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+            num_draft_tokens, bonus_token_id)
 
 
 def expand_pytorch(
@@ -458,6 +464,56 @@ def expand_pytorch(
         output_ptr[output_slice] = src_val
 
 
+def generate_uniform_probs(
+    num_tokens: int,
+    num_draft_tokens: list[int],
+    generators: dict[int, torch.Generator],
+    device: torch.device,
+) -> torch.Tensor:
+    """
+    Generates a batch of uniform random samples, with optional seeding
+    if available.
+
+    This method creates a tensor of shape `(num_tokens, )` filled
+    with uniform random values in the range [0, 1). If `generators` is provided,
+    the requests with their own seeds will use the provided `torch.Generator`
+    for reproducibility. The samples for the other requests will be generated
+    without a seed.
+
+    Args:
+        num_tokens : int
+            Total number of tokens.
+        num_draft_tokens : list[int]
+            Number of draft tokens per request.
+        generators : dict[int, torch.Generator]
+            A dictionary mapping indices in the batch to
+            `torch.Generator` objects.
+        device : torch.device
+            The device on which to allocate the tensor.
+    Returns:
+        uniform_probs : torch.Tensor
+            A tensor of shape `(num_tokens, )` containing uniform
+            random values in the range [0, 1).
+    """
+    uniform_probs = torch.rand(
+        (num_tokens, ),
+        dtype=torch.float32,
+        device=device,
+    )
+    start_idx = 0
+    for req_idx, n in enumerate(num_draft_tokens):
+        # Do not generate random numbers for requests with no draft tokens.
+        # This can be important for reproducibility.
+        if n == 0:
+            continue
+        end_idx = start_idx + n
+        generator = generators.get(req_idx)
+        if generator is not None:
+            uniform_probs[start_idx:end_idx].uniform_(generator=generator)
+        start_idx = end_idx
+    return uniform_probs
+
+
 def sample_recovered_tokens_pytorch(
     output_token_ids,  # [num_tokens]
     cu_num_draft_tokens,  # [batch_size]
@@ -468,37 +524,33 @@
     draft_token_ids,  # [num_tokens]
     draft_probs,  # [num_tokens, vocab_size] or None
     target_probs,  # [num_tokens, vocab_size]
     q,  # [batch_size, vocab_size]
     vocab_size,
     IS_NGRAM=False,
 ):
-    batch_size = len(cu_num_draft_tokens)
-
-    for req_idx in range(batch_size):
-        start_idx = 0 if req_idx == 0 else cu_num_draft_tokens[req_idx - 1]
-        end_idx = cu_num_draft_tokens[req_idx]
-        num_draft_tokens = end_idx - start_idx
-
-        for pos in range(num_draft_tokens):
-            token_idx = start_idx + pos
-
-            if IS_NGRAM:
-                draft_token_id = draft_token_ids[token_idx]
-                orig_prob = target_probs[token_idx, draft_token_id].item()
-                target_probs[token_idx, draft_token_id] = 0
-                prob = target_probs[token_idx].clone()
-            else:
-                draft_p = draft_probs[token_idx].clone()
-                target_p = target_probs[token_idx].clone()
-                prob = torch.maximum(target_p - draft_p,
-                                     torch.tensor(0.0, device=target_p.device))
-
-            q_values = torch.full((vocab_size, ),
-                                  float('-inf'),
-                                  device=q.device)
-            q_values[:vocab_size] = q[req_idx, :vocab_size]
-
-            recovered_id = torch.argmax(prob / q_values).item()
-            output_token_ids[token_idx] = recovered_id
+    num_tokens = len(draft_token_ids)
+    device = output_token_ids.device
 
-            if IS_NGRAM:
-                target_probs[token_idx, draft_token_id] = orig_prob
+    diff = torch.diff(cu_num_draft_tokens,
+                      prepend=torch.tensor([0], device=device))
+    q_value_new = torch.repeat_interleave(q, diff, dim=0)
+    token_positions = torch.arange(num_tokens, device=device)
 
-rs.expand_batch_to_tokens = expand_batch_to_tokens
+    if IS_NGRAM:
+        orig_prob = target_probs[token_positions, draft_token_ids]
+        target_probs[token_positions, draft_token_ids] = 0
+        prob = target_probs[token_positions].clone()
+    else:
+        draft_p = draft_probs[token_positions].clone()
+        target_p = target_probs[token_positions].clone()
+        prob = torch.maximum(target_p - draft_p,
+                             torch.tensor(0.0, device=target_p.device))
+
+    q_values = torch.full((num_tokens, vocab_size),
+                          float('-inf'),
+                          device=q.device)
+    q_values[:, :vocab_size] = q_value_new[:, :vocab_size]
+
+    recovered_id = torch.argmax(prob / q_values, dim=-1)
+    output_token_ids[token_positions] = recovered_id.to(
+        dtype=output_token_ids.dtype)
+
+    if IS_NGRAM:
+        target_probs[token_positions, draft_token_ids] = orig_prob
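
Reviewer note (not part of the patch): rejection_random_sample_kernel is launched on a (batch_size, ) grid, so each Triton program handles one request and walks its draft positions, accepting a draft token when target_prob / draft_prob >= uniform_prob. The plain-PyTorch sketch below restates that accept/reject rule for a single request; the helper name check_one_request and its argument layout are illustrative only and do not exist in this file.

import torch


def check_one_request(draft_token_ids, draft_probs, target_probs,
                      uniform_probs, recovered_token_ids, bonus_token_id):
    # Tokens the kernel would store for one non-greedy request.
    out = []
    for pos, draft_id in enumerate(draft_token_ids.tolist()):
        # NO_DRAFT_PROBS (n-gram) path: treat the draft probability as 1.0.
        p_draft = (1.0 if draft_probs is None else
                   draft_probs[pos, draft_id].item())
        p_target = target_probs[pos, draft_id].item()
        u = uniform_probs[pos].item()
        if p_draft > 0 and p_target / p_draft >= u:
            # Accept the draft token and continue to the next position.
            out.append(draft_id)
        else:
            # First rejection: emit the recovered token and stop.
            out.append(recovered_token_ids[pos].item())
            return out
    # All draft tokens accepted: append the bonus token.
    out.append(bonus_token_id)
    return out

Passing draft_probs=None here mirrors the NO_DRAFT_PROBS=True specialization used for n-gram drafts, and remaining output slots stay at PLACEHOLDER_TOKEN_ID in the kernel.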
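Reviewer note (not part of the patch): the vectorized sample_recovered_tokens_pytorch hinges on expanding the per-request Gumbel noise q to one row per draft token, so the argmax over prob / q can run across all positions in a single call. A minimal sketch of that expansion with made-up shapes:

import torch

cu_num_draft_tokens = torch.tensor([2, 5, 6])  # 3 requests with 2, 3 and 1 draft tokens
q = torch.rand(3, 8)  # [batch_size, vocab_size]

# Recover per-request draft counts from the cumulative sums.
counts = torch.diff(cu_num_draft_tokens, prepend=torch.tensor([0]))
# One q row per draft token: [num_tokens, vocab_size].
q_per_token = torch.repeat_interleave(q, counts, dim=0)
assert q_per_token.shape == (int(cu_num_draft_tokens[-1]), 8)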