
Conversation

@iKunHvv commented Nov 25, 2025

What this PR does / why we need it?

Does this PR introduce any user-facing change?

How was this patch tested?

@iKunHvv (Author) commented Nov 25, 2025

ok

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces several Triton kernels to optimize the rejection sampling process in vllm-ascend, replacing existing PyTorch implementations. The changes aim to improve performance by leveraging GPU-specific optimizations. The PR also includes refactoring of helper functions for better vectorization and clarity. My review identified a critical bug in the new rejection_greedy_sample_kernel where a condition to filter non-greedy requests is incorrect, leading to wrong behavior. I also found a high-severity issue in the refactored sample_recovered_tokens_pytorch function, where an incorrect tensor slicing could cause runtime errors. I've provided suggestions to fix both issues.
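
For readers unfamiliar with the algorithm, here is a minimal PyTorch sketch of the per-token rejection-sampling rule that kernels like these implement; the function name and shapes below are illustrative, not taken from the PR:

    import torch

    def accept_or_recover(p: torch.Tensor, q: torch.Tensor, draft_token: int) -> int:
        # p: target-model probabilities, q: draft-model probabilities, both of shape [vocab_size].
        # Accept the draft token with probability min(1, p/q); otherwise resample
        # from the normalized residual distribution max(p - q, 0).
        u = torch.rand(())
        if u <= torch.clamp(p[draft_token] / q[draft_token], max=1.0):
            return draft_token
        residual = torch.clamp(p - q, min=0.0)
        residual = residual / residual.sum()
        return int(torch.multinomial(residual, num_samples=1).item())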

Comment on lines +462 to +464
if is_greedy is None:
    # Early exit for non-greedy sampling requests.
    return

Severity: critical

The condition is_greedy is None is incorrect for a Triton kernel. The is_greedy variable, being a Triton boolean scalar (tl.int1), can never be Python's None. This bug prevents the intended early exit for non-greedy sampling requests, causing them to be processed incorrectly by this kernel designed for greedy sampling. The check should be if not is_greedy: to correctly handle non-greedy cases.

Suggested change
-if is_greedy is None:
-    # Early exit for non-greedy sampling requests.
-    return
+if not is_greedy:
+    # Early exit for non-greedy sampling requests.
+    return
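
For context, a minimal Triton-style sketch of where such a guard sits in a per-request greedy kernel; the kernel name, signature, and pointer names below are assumptions for illustration, not the PR's actual code:

    import triton
    import triton.language as tl

    @triton.jit
    def greedy_sample_sketch_kernel(is_greedy_ptr, output_ptr):
        req_idx = tl.program_id(0)
        # The loaded flag is a Triton scalar (tl.int1) and can never be Python's None,
        # so the guard must test its value rather than compare it against None.
        is_greedy = tl.load(is_greedy_ptr + req_idx)
        if not is_greedy:
            # Early exit for non-greedy sampling requests.
            return
        # ... greedy (argmax) token selection for this request would go here; the
        # store below is only a placeholder showing where the result is written.
        tl.store(output_ptr + req_idx, req_idx)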

Comment on lines +727 to +730
q_values = torch.full((num_tokens, vocab_size),
                      float('-inf'),
                      device=q.device)
q_values[:vocab_size] = q[req_idx, :vocab_size]

recovered_id = torch.argmax(prob / q_values).item()
output_token_ids[token_idx] = recovered_id
q_values[:vocab_size] = q_value_new[token_positions, :vocab_size]

Severity: high

The current initialization and population of q_values are incorrect and overly complex. The row-wise slice q_values[:vocab_size] can cause a shape mismatch error if num_tokens > vocab_size. This block can be simplified by assigning q_value_new directly to q_values, which also removes the redundant torch.full initialization.

    q_values = q_value_new
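
As a sketch only, the simplified block could then read as below; prob, q_value_new, output_token_ids, and token_idx come from the quoted snippet, while which q tensor should feed the argmax is an assumption here. The argmax-over-ratio step matches the standard exponential-race trick: if q_values were i.i.d. Exponential(1) draws, argmax(prob / q_values) would be a sample from the distribution prob.

    # Assumed context: inside the per-token loop of sample_recovered_tokens_pytorch.
    q_values = q_value_new  # direct assignment; no torch.full buffer or row slicing needed
    recovered_id = torch.argmax(prob / q_values).item()
    output_token_ids[token_idx] = recovered_id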

@github-actions

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
