[Refactor] optimize operators in rejection sampler #3301
base: main
Conversation
Signed-off-by: lio <[email protected]>
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request optimizes the rejection sampler by replacing a Python-based random sampling function with a Triton kernel and vectorizing another PyTorch function. The changes are aimed at improving performance. My review focuses on two points: a performance issue in the new Triton kernel due to potential recompilations, and a critical bug in the vectorized PyTorch function that could lead to runtime errors. I've provided suggestions to fix both issues.
q_values = torch.full((num_tokens, vocab_size),
                      float('-inf'),
                      device=q.device)
q_values[:vocab_size] = q_value_new[token_positions, :vocab_size]
The construction of `q_values` is incorrect. `q_values[:vocab_size]` slices the first `vocab_size` rows of the `q_values` tensor, but you are attempting to assign `q_value_new`, which has `num_tokens` rows. This will cause a shape mismatch error at runtime if `num_tokens` is not equal to `vocab_size`. Since `q_value_new` already has the desired shape and values, you can directly assign it to `q_values` and remove the unnecessary initialization with `torch.full`:
q_values = q_value_new
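For illustration, here is a minimal, self-contained sketch of the failure mode described above. The sizes and `token_positions` below are hypothetical, chosen so that `num_tokens` exceeds `vocab_size`:

import torch

# Hypothetical sizes; the mismatch surfaces when num_tokens > vocab_size.
num_tokens, vocab_size = 8, 4
token_positions = torch.arange(num_tokens)
q_value_new = torch.randn(num_tokens, vocab_size)

q_values = torch.full((num_tokens, vocab_size), float('-inf'))
# q_values[:vocab_size] selects only the first vocab_size (4) rows,
# while the right-hand side has num_tokens (8) rows -> RuntimeError.
q_values[:vocab_size] = q_value_new[token_positions, :vocab_size]

# The suggested direct assignment avoids both the mismatch and the
# wasted torch.full allocation:
q_values = q_value_new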
for pos in range(num_draft_tokens):
    if not rejected:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
        if NO_DRAFT_PROBS:
            draft_prob = 1
        else:
            draft_prob = tl.load(draft_probs_ptr +
                                 (start_idx + pos) * vocab_size +
                                 draft_token_id)
        target_prob = tl.load(target_probs_ptr +
                              (start_idx + pos) * vocab_size +
                              draft_token_id)
        uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
        # NOTE(woosuk): While the draft probability should never be 0,
        # we check it to avoid NaNs. If it happens to be 0, we reject.
        if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
            # Accept.
            token_id = draft_token_id
        else:
            # Reject. Use recovered token.
            rejected = True
            token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
        tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                 token_id)

if not rejected:
    # All draft tokens were accepted, so append the bonus token.
    bonus_token_id = bonus_token_ids[req_idx].item()
    output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
The loop `for pos in range(num_draft_tokens):` uses the runtime value `num_draft_tokens` as its bound. This will cause the Triton kernel to recompile for every different value of `num_draft_tokens`, which can lead to significant performance degradation. To avoid this, loop over a compile-time constant or a `do_not_specialize` argument such as `max_spec_len`, which is already configured as such, and use a mask to handle the variable number of tokens:
for pos in range(max_spec_len):
if pos < num_draft_tokens:
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
if NO_DRAFT_PROBS:
draft_prob = 1
else:
draft_prob = tl.load(draft_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
# NOTE(woosuk): While the draft probability should never be 0,
# we check it to avoid NaNs. If it happens to be 0, we reject.
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
# Accept.
token_id = draft_token_id
else:
# Reject. Use recovered token.
rejected = True
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
token_id)
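As context for the `do_not_specialize` reference above: Triton specializes on certain integer argument values unless the argument is listed in `do_not_specialize` in the `@triton.jit` decorator. Below is a rough sketch of the masked-loop pattern with the bounds excluded from specialization; the kernel and argument names are illustrative, not the PR's actual code:

import triton
import triton.language as tl

@triton.jit(do_not_specialize=["num_draft_tokens", "max_spec_len"])
def masked_copy_kernel(src_ptr, out_ptr, num_draft_tokens, max_spec_len):
    req_idx = tl.program_id(0)
    # The loop bound is a non-specialized runtime scalar, so new values of
    # max_spec_len or num_draft_tokens do not force a recompilation.
    for pos in range(max_spec_len):
        # Runtime mask: only the first num_draft_tokens slots are written.
        if pos < num_draft_tokens:
            val = tl.load(src_ptr + req_idx * max_spec_len + pos)
            tl.store(out_ptr + req_idx * max_spec_len + pos, val)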
What this PR does / why we need it?
We optimized the operators in the rejection sampler with Triton, improving the performance of EAGLE-3 speculative decoding.
Does this PR introduce any user-facing change?
No
How was this patch tested?
NA
Co-authored-by: QilaiZhang ([email protected])