Conversation


@lio1226 commented Sep 30, 2025

What this PR does / why we need it?

We optimized the operators in the rejection sampler with Triton kernels, improving the performance of EAGLE-3.

Does this PR introduce any user-facing change?

No

How was this patch tested?

NA

Co-authored-by: QilaiZhang ([email protected])


👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description, to help reviewers and future developers understand the change.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.


@gemini-code-assist bot left a comment


Code Review

This pull request optimizes the rejection sampler by replacing a Python-based random sampling function with a Triton kernel and vectorizing another PyTorch function. The changes are aimed at improving performance. My review focuses on two points: a performance issue in the new Triton kernel due to potential recompilations, and a critical bug in the vectorized PyTorch function that could lead to runtime errors. I've provided suggestions to fix both issues.

Comment on lines +546 to +549
q_values = torch.full((num_tokens, vocab_size),
                      float('-inf'),
                      device=q.device)
q_values[:vocab_size] = q_value_new[token_positions, :vocab_size]

critical

The construction of q_values is incorrect. q_values[:vocab_size] slices the first vocab_size rows of the q_values tensor, but you are attempting to assign q_value_new, which has num_tokens rows. This raises a shape mismatch error at runtime whenever num_tokens exceeds vocab_size; otherwise the assignment succeeds but overwrites every row, making the torch.full initialization redundant. Since q_value_new already has the desired shape and values, you can assign it to q_values directly and remove the unnecessary initialization with torch.full.

    q_values = q_value_new
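
For reference, here is a minimal standalone repro of the failure mode; the sizes below are invented for illustration and are not from the PR, and the error path requires num_tokens to exceed vocab_size:

    import torch

    num_tokens, vocab_size = 8, 4  # hypothetical sizes; the error needs num_tokens > vocab_size
    token_positions = torch.arange(num_tokens)
    q_value_new = torch.rand(num_tokens, vocab_size)  # stands in for the gathered result

    q_values = torch.full((num_tokens, vocab_size), float('-inf'))
    # q_values[:vocab_size] slices ROWS, giving a (4, 4) view here, while the
    # right-hand side is (8, 4), so the assignment raises a RuntimeError.
    try:
        q_values[:vocab_size] = q_value_new[token_positions, :vocab_size]
    except RuntimeError as e:
        print(e)

    # The suggested fix: q_value_new already has the desired shape and values.
    q_values = q_value_new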

Comment on lines +409 to +432
for pos in range(num_draft_tokens):
    if not rejected:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
        if NO_DRAFT_PROBS:
            draft_prob = 1
        else:
            draft_prob = tl.load(draft_probs_ptr +
                                 (start_idx + pos) * vocab_size +
                                 draft_token_id)
        target_prob = tl.load(target_probs_ptr +
                              (start_idx + pos) * vocab_size +
                              draft_token_id)
        uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
        # NOTE(woosuk): While the draft probability should never be 0,
        # we check it to avoid NaNs. If it happens to be 0, we reject.
        if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
            # Accept.
            token_id = draft_token_id
        else:
            # Reject. Use recovered token.
            rejected = True
            token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
        tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                 token_id)

high

The loop for pos in range(num_draft_tokens): uses a runtime value num_draft_tokens as its bound. This will cause the Triton kernel to recompile for every different value of num_draft_tokens, which can lead to significant performance degradation. To avoid this, you should loop over a compile-time constant or a do_not_specialize argument, such as max_spec_len, which is already configured as such, and use a mask to handle the variable number of tokens.

    for pos in range(max_spec_len):
        if pos < num_draft_tokens:
            if not rejected:
                draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
                if NO_DRAFT_PROBS:
                    draft_prob = 1
                else:
                    draft_prob = tl.load(draft_probs_ptr +
                                         (start_idx + pos) * vocab_size +
                                         draft_token_id)
                target_prob = tl.load(target_probs_ptr +
                                      (start_idx + pos) * vocab_size +
                                      draft_token_id)
                uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
                # NOTE(woosuk): While the draft probability should never be 0,
                # we check it to avoid NaNs. If it happens to be 0, we reject.
                if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
                    # Accept.
                    token_id = draft_token_id
                else:
                    # Reject. Use recovered token.
                    rejected = True
                    token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
                tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                         token_id)
