bugfix for mtp #3300
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request aims to fix a bug in multi-token prediction (MTP) for speculative decoding, where the cos and sin caches for RoPE were not being refreshed in each step. The core logic to update these caches is correct and necessary. However, I've identified a critical bug in the conditional logic that guards this new code. It can lead to a NameError under specific conditions (when TorchAir is enabled during a prefill phase), causing the program to crash. My review provides a code suggestion to fix this issue.
if not is_running_torchair:
    attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[1:batch_size + 1].tolist()
    attn_metadata_i.decode.cos = builder.cos_cache[positions].unsqueeze(1).unsqueeze(2)
    attn_metadata_i.decode.sin = builder.sin_cache[positions].unsqueeze(1).unsqueeze(2)
The builder variable is only defined when self.torchair_graph_enabled is False (lines 402-406). However, this block is guarded by if not is_running_torchair:.

is_running_torchair is defined as self.torchair_graph_enabled and not self.runner.with_prefill. This means not is_running_torchair is equivalent to not self.torchair_graph_enabled or self.runner.with_prefill.

When self.torchair_graph_enabled is True and self.runner.with_prefill is True, the condition not is_running_torchair is met, and the code enters this block. This will raise a NameError because builder is not defined in the TorchAir-enabled path.

To fix this, the condition should be changed to if not self.torchair_graph_enabled: to ensure this block is only executed when builder is guaranteed to be defined.
- if not is_running_torchair:
-     attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[1:batch_size + 1].tolist()
-     attn_metadata_i.decode.cos = builder.cos_cache[positions].unsqueeze(1).unsqueeze(2)
-     attn_metadata_i.decode.sin = builder.sin_cache[positions].unsqueeze(1).unsqueeze(2)
+ if not self.torchair_graph_enabled:
+     attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[1:batch_size + 1].tolist()
+     attn_metadata_i.decode.cos = builder.cos_cache[positions].unsqueeze(1).unsqueeze(2)
+     attn_metadata_i.decode.sin = builder.sin_cache[positions].unsqueeze(1).unsqueeze(2)
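The guard analysis in this review can be checked by enumerating the four flag combinations. The sketch below is standalone and does not import vLLM Ascend. The function names (guard_original, guard_suggested, builder_defined, unsafe_cases) are hypothetical, and the rule that builder exists only when torchair_graph_enabled is False is taken from the review text, not verified against the source file.

```python
from itertools import product


def guard_original(torchair_graph_enabled: bool, with_prefill: bool) -> bool:
    """Guard as written in the PR: `if not is_running_torchair:`."""
    is_running_torchair = torchair_graph_enabled and not with_prefill
    return not is_running_torchair


def guard_suggested(torchair_graph_enabled: bool, with_prefill: bool) -> bool:
    """Guard proposed in the review: `if not self.torchair_graph_enabled:`."""
    return not torchair_graph_enabled


def builder_defined(torchair_graph_enabled: bool) -> bool:
    """Assumption from the review: `builder` is only assigned on the non-TorchAir path."""
    return not torchair_graph_enabled


def unsafe_cases(guard):
    """Return (torchair_graph_enabled, with_prefill) pairs where the block
    would run even though `builder` was never assigned."""
    return [
        (torchair, prefill)
        for torchair, prefill in product([False, True], repeat=2)
        if guard(torchair, prefill) and not builder_defined(torchair)
    ]


print(unsafe_cases(guard_original))   # [(True, True)]
print(unsafe_cases(guard_suggested))  # []
```

Under these assumptions, the only failing combination is TorchAir enabled together with prefill, which matches the case described in the review, and the suggested guard leaves no combination that reaches builder while it is undefined.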
Signed-off-by: zouyida2052 <[email protected]>
What this PR does / why we need it?
When mtp > 1, we need to refresh cos and sin in each step.
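A framework-free sketch of why the refresh matters: RoPE cos and sin values are looked up by token position, so once the positions advance for the next draft step, values gathered earlier no longer match. Everything here is illustrative. build_rope_cache and gather are hypothetical helpers, the cache layout (one row per position, one column per frequency) is an assumption, and advancing each position by one per step is a simplification of what the proposer does. The actual code indexes tensor caches with builder.cos_cache[positions] and then adds two singleton dimensions via unsqueeze.

```python
import math


def build_rope_cache(max_pos: int, rotary_dim: int, base: float = 10000.0):
    """Precompute cos/sin tables: one row per position, one column per frequency."""
    half = rotary_dim // 2
    inv_freq = [base ** (-2.0 * i / rotary_dim) for i in range(half)]
    cos_cache = [[math.cos(p * f) for f in inv_freq] for p in range(max_pos)]
    sin_cache = [[math.sin(p * f) for f in inv_freq] for p in range(max_pos)]
    return cos_cache, sin_cache


def gather(cache, positions):
    """List-based stand-in for the tensor indexing `cache[positions]`."""
    return [cache[p] for p in positions]


cos_cache, sin_cache = build_rope_cache(max_pos=16, rotary_dim=4)

positions = [5, 9]  # current position of each request in the batch
stale_cos = gather(cos_cache, positions)  # values gathered before drafting starts

for step in range(2):  # two further draft steps, as when mtp > 1
    positions = [p + 1 for p in positions]  # each request advances by one token
    cos = gather(cos_cache, positions)      # refreshed for this step
    sin = gather(sin_cache, positions)
    # Reusing stale_cos here would rotate queries and keys by the wrong angle.
    assert cos != stale_cos
```

With a single draft step the initial lookup is enough, which is consistent with the problem only appearing when mtp > 1.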
Does this PR introduce any user-facing change?
no
How was this patch tested?