[Feat] Make full graph mode compatible with MTP #3276
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request aims to make the Full Graph mode compatible with MTP (Multi-Token Prediction). The changes involve enabling graph support for the MLA attention backend, adding logic to handle rotary embeddings in graph mode, and updating the model runner to prepare inputs and execute the model correctly for MTP in graph mode.
I've found two critical bugs. One is in vllm_ascend/attention/mla_v1.py, where local variables for rotary embeddings are not handled correctly, leading to a potential UnboundLocalError and incorrect behavior in graph mode. The other is a typo in vllm_ascend/worker/model_runner_v1.py that prevents the correct attention state from being set for MTP graph capture. Both issues need to be addressed to ensure the feature works as intended.
    # TODO: After the fullgraph supports MTP, the if branch needs to deleted
    assert self.cos_cache is not None
    assert self.sin_cache is not None
    if cos is None and sin is not None:
        cos = self.cos_cache[
            input_positions].unsqueeze(  # type: ignore
                1).unsqueeze(2)
        sin = self.sin_cache[
            input_positions].unsqueeze(  # type: ignore
                1).unsqueeze(2)

        decode_metadata = AscendMLADecodeMetadata(
            input_positions=input_positions,
            block_table=block_table,
            seq_lens=seq_lens,
            seq_lens_list=seq_lens_list,
            max_seq_lens=max_seq_lens,
            attn_mask=common_attn_metadata.spec_attn_mask,
            actual_seq_lengths_q=actual_seq_lengths_q,
            sin=sin,
            cos=cos)
    else:
        cos[:num_decode_tokens,
            ...] = self.cos_cache[input_positions].unsqueeze(
                1).unsqueeze(2)
        sin[:num_decode_tokens,
            ...] = self.sin_cache[input_positions].unsqueeze(
                1).unsqueeze(2)

        decode_metadata = AscendMLADecodeMetadata(
            input_positions=input_positions,
            block_table=block_table,
            seq_lens=seq_lens,
            seq_lens_list=seq_lens_list,
            max_seq_lens=max_seq_lens,
            attn_mask=common_attn_metadata.spec_attn_mask,
            actual_seq_lengths_q=actual_seq_lengths_q,
            sin=sin[:num_decode_tokens, ...],
            cos=cos[:num_decode_tokens, ...])
This block has a critical bug. The local variables cos and sin are not guaranteed to be initialized, which will lead to an UnboundLocalError if num_prefills is 0. Additionally, the condition if cos is None and sin is not None: seems incorrect. It appears the intention is to use the pre-allocated cos and sin tensors from common_attn_metadata for graph mode, but they are not being used. This will cause a failure during graph capture for decode-only batches.
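The failure mode flagged above can be reproduced in isolation. Below is a minimal, hypothetical sketch (the function names and scalar stand-ins are illustrative, not the actual vLLM Ascend code): a variable bound only inside a conditional branch raises UnboundLocalError when that branch is skipped, and accepting the pre-allocated values as parameters avoids it.

```python
# Minimal illustration of the UnboundLocalError pattern described in this
# review. Names and values are stand-ins, not the real vLLM Ascend code.

def build_metadata_buggy(num_prefills: int):
    if num_prefills > 0:
        cos, sin = 1.0, 0.0  # stand-ins for the rotary embedding tensors
    # Bug: when num_prefills == 0, cos and sin were never bound,
    # so reading them below raises UnboundLocalError.
    return cos, sin

def build_metadata_fixed(num_prefills: int, cos=None, sin=None):
    # Accept pre-allocated values (as graph mode would pass in);
    # otherwise fall back to computing them locally.
    if cos is None or sin is None:
        cos, sin = 1.0, 0.0
    return cos, sin

try:
    build_metadata_buggy(0)
except UnboundLocalError as exc:
    print("buggy path:", type(exc).__name__)

print("fixed path:", build_metadata_fixed(0))
```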
    attn_state = AscendAttentionState.DecodeOnly
    if self.speculative_config and \
            self.speculative_config.method == "deepseek_mtp":
        attn_states = AscendAttentionState.SpecDecoding

    for attn_group in self.attn_groups[kv_cache_group_id]:
        if vllm_version_is("0.10.2"):
            builder = attn_group.metadata_builder
        else:
            builder = attn_group.get_metadata_builder()
        attn_metadata_i = builder.build_for_graph_capture(
            common_attn_metadata,
            self.get_model())
There's a typo here. attn_states is assigned but never used. It should be attn_state. Because of this, attn_state is not updated to AscendAttentionState.SpecDecoding for MTP, and the default DecodeOnly is used when calling build_for_graph_capture. This will cause incorrect behavior in MTP graph mode.
attn_state = AscendAttentionState.DecodeOnly
if self.speculative_config and \
self.speculative_config.method == "deepseek_mtp":
attn_state = AscendAttentionState.SpecDecoding
for attn_group in self.attn_groups[kv_cache_group_id]:
if vllm_version_is("0.10.2"):
builder = attn_group.metadata_builder
else:
builder = attn_group.get_metadata_builder()
attn_metadata_i = builder.build_for_graph_capture(
common_attn_metadata,
attn_state,
self.get_model())
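This class of bug is easy to demonstrate standalone. Below is a minimal, hypothetical sketch (AttnState is a stand-in for AscendAttentionState, and the function names are illustrative): assigning to a misspelled name in Python silently creates a new local variable rather than updating the intended one, so the default value leaks through.

```python
# Illustration of the typo bug discussed above: attn_states vs. attn_state.
# The enum and function names are stand-ins, not the real vLLM Ascend code.
from enum import Enum, auto

class AttnState(Enum):  # stand-in for AscendAttentionState
    DecodeOnly = auto()
    SpecDecoding = auto()

def pick_state_buggy(is_mtp: bool) -> AttnState:
    attn_state = AttnState.DecodeOnly
    if is_mtp:
        attn_states = AttnState.SpecDecoding  # typo: new variable, never read
    return attn_state  # always DecodeOnly

def pick_state_fixed(is_mtp: bool) -> AttnState:
    attn_state = AttnState.DecodeOnly
    if is_mtp:
        attn_state = AttnState.SpecDecoding
    return attn_state

print(pick_state_buggy(True))   # AttnState.DecodeOnly -- the bug
print(pick_state_fixed(True))   # AttnState.SpecDecoding
```

No exception is raised in the buggy version, which is why this kind of typo survives until runtime behavior is inspected.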
vllm_ascend/compilation/acl_graph.py (Outdated)
    ):
        (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
         spec_attn_mask, sparse_mode, scale, block_table, block_size,
         seq_lens_list, actual_seq_lengths, worlspace, attn_output,
Typo here, worlspace.
        block_size=block_size,
        actual_seq_lengths_kv=seq_lens_list,
        actual_seq_lengths=actual_seq_lengths,
        workspace=workspace,
Should retrieve workspace from graph params.
Signed-off-by: anon189Ty <[email protected]>
What this PR does / why we need it?
Make Full Graph mode able to run with MTP.
Does this PR introduce any user-facing change?
How was this patch tested?