[v0.9.1] MTP supports V1 scheduler #2371
Conversation
Code Review

This pull request refactors the attention metadata handling by introducing a centralized `AscendCommonAttentionMetadata` dataclass. This is a good architectural improvement that centralizes logic and reduces code duplication. However, the review identified several critical issues related to this refactoring, including incorrect tensor slicing and initialization that could lead to runtime errors or incorrect behavior. Specifically, there are bugs in `build_dummy_metadata` in `attention_v1.py` and in the `build` method of `mla_v1.py`'s metadata builder. There are also some redundant code assignments that should be cleaned up.
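For reference, a minimal sketch of what such a centralized metadata dataclass might look like, inferred only from the fields referenced later in this review (`query_start_loc`, `query_start_loc_cpu`, `attn_mask`, `attn_state`); the actual `AscendCommonAttentionMetadata` in the PR may define more fields and different defaults:

```python
# Hypothetical sketch, not the PR's actual definition.
from dataclasses import dataclass
from typing import Any, Optional

import torch


@dataclass
class AscendCommonAttentionMetadata:
    # Cumulative query offsets per request, shape [num_reqs + 1].
    query_start_loc: Optional[torch.Tensor] = None
    # Host-side copy, sliced before transfer to the device.
    query_start_loc_cpu: Optional[torch.Tensor] = None
    # Attention mask shared across attention backends.
    attn_mask: Optional[torch.Tensor] = None
    # Backend-specific attention state (e.g. decode-only vs. prefill).
    attn_state: Optional[Any] = None
```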
block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
    block_table[:num_reqs])
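The summary's "incorrect tensor slicing" point can be seen in this assignment: the left-hand side selects only `max_num_blocks_per_req` columns while the right-hand side keeps the full column width, so the statement either raises a shape mismatch when the widths differ or is a redundant self-assignment when they agree. A standalone illustration with made-up shapes (the real builder's sizes may differ):

```python
import torch

# Made-up shapes for illustration only.
num_reqs, table_cols, max_num_blocks_per_req = 4, 8, 6
block_table = torch.zeros(16, table_cols, dtype=torch.int32)

# LHS view is (num_reqs, max_num_blocks_per_req); RHS is (num_reqs, table_cols).
try:
    block_table[:num_reqs, :max_num_blocks_per_req] = block_table[:num_reqs]
except RuntimeError as exc:
    print(exc)  # shape mismatch whenever table_cols != max_num_blocks_per_req
```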
query_start_loc = common_attn_metadata.query_start_loc
The `common_attn_metadata` used here is initialized without a `query_start_loc`, so `common_attn_metadata.query_start_loc` is `None`. This `None` value is then passed to the `AscendMetadata` constructor, which will lead to a runtime error because `query_start_loc` is a required tensor. To fix this, a dummy `query_start_loc` tensor should be created for the decode-only state.
Suggested change:

- query_start_loc = common_attn_metadata.query_start_loc
+ query_start_loc = torch.arange(0, num_reqs + 1, dtype=torch.int32, device=block_table.device)
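A quick sanity check on why `torch.arange` is a valid dummy here (standalone snippet, arbitrary `num_reqs`): in a decode-only batch every request contributes exactly one query token, so the cumulative offsets are simply 0 through `num_reqs`.

```python
import torch

# Decode-only dummy batch: one query token per request, so the
# cumulative offsets are 0, 1, ..., num_reqs.
num_reqs = 4
query_start_loc = torch.arange(0, num_reqs + 1, dtype=torch.int32)
print(query_start_loc.tolist())  # [0, 1, 2, 3, 4]
# Request i's query tokens span query_start_loc[i]:query_start_loc[i + 1].
```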
vllm_ascend/attention/mla_v1.py (outdated)
actual_seq_lengths_q = query_start_loc[1:num_decodes+1].tolist()
max_seq_lens = seq_lens[:num_decodes].max().item()
seq_lens = seq_lens[:num_decodes]
input_positions = input_positions[:num_decodes]
The `input_positions` tensor is incorrectly sliced with `num_decodes` (the number of decode requests). It should be sliced with `num_decode_tokens` (the number of tokens in decode requests) to select the correct positions for the decode phase. This is a critical bug that will lead to incorrect behavior.
Suggested change:

- input_positions = input_positions[:num_decodes]
+ input_positions = input_positions[:num_decode_tokens]
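To make the distinction concrete (illustrative numbers only; real batches come from the scheduler): with multi-token prediction a decode request can carry several query tokens per step, so the request count and the token count diverge.

```python
import torch

# Illustrative: 2 decode requests, each carrying e.g. 1 target token
# plus 2 MTP draft tokens, so 3 tokens per request.
num_decodes = 2
tokens_per_req = 3
num_decode_tokens = num_decodes * tokens_per_req  # 6

input_positions = torch.arange(10)
print(input_positions[:num_decode_tokens].tolist())  # 6 positions, correct
print(input_positions[:num_decodes].tolist())        # only 2, drops 4 tokens
```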
    self.device, non_blocking=True)
attn_mask = common_attn_metadata.attn_mask
attn_state = common_attn_metadata.attn_state
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
input_positions = self.runner.positions_cpu[:num_tokens].to(
    device, non_blocking=True).long()
This pull request has conflicts; please resolve those before we can evaluate the pull request.
):
    self.vllm_config = vllm_config
    self.model_config = vllm_config.model_config
    self.device = device
    self.runner = runner
remove runner
Force-pushed from 619fff4 to 96a8bfb
Signed-off-by: xuyexiong <[email protected]>
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?