Conversation

@anon189Ty anon189Ty commented Sep 29, 2025

What this PR does / why we need it?

Enable Full Graph mode to run with MTP.

Does this PR introduce any user-facing change?

How was this patch tested?

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description, to help reviewers and future developers understand the change.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request aims to make the Full Graph mode compatible with MTP (Multi-Token Prediction). The changes involve enabling graph support for the MLA attention backend, adding logic to handle rotary embeddings in graph mode, and updating the model runner to prepare inputs and execute the model correctly for MTP in graph mode.

I've found two critical bugs. One is in vllm_ascend/attention/mla_v1.py where local variables for rotary embeddings are not handled correctly, leading to a potential UnboundLocalError and incorrect behavior in graph mode. The other is a typo in vllm_ascend/worker/model_runner_v1.py that prevents the correct attention state from being set for MTP graph capture. Both issues need to be addressed to ensure the feature works as intended.

Comment on lines 400 to 440
# TODO: After the fullgraph supports MTP, the if branch needs to be deleted
assert self.cos_cache is not None
assert self.sin_cache is not None
if cos is None and sin is not None:
    cos = self.cos_cache[
        input_positions].unsqueeze(  # type: ignore
            1).unsqueeze(2)
    sin = self.sin_cache[
        input_positions].unsqueeze(  # type: ignore
            1).unsqueeze(2)

    decode_metadata = AscendMLADecodeMetadata(
        input_positions=input_positions,
        block_table=block_table,
        seq_lens=seq_lens,
        seq_lens_list=seq_lens_list,
        max_seq_lens=max_seq_lens,
        attn_mask=common_attn_metadata.spec_attn_mask,
        actual_seq_lengths_q=actual_seq_lengths_q,
        sin=sin,
        cos=cos)
else:
    cos[:num_decode_tokens,
        ...] = self.cos_cache[input_positions].unsqueeze(
            1).unsqueeze(2)
    sin[:num_decode_tokens,
        ...] = self.sin_cache[input_positions].unsqueeze(
            1).unsqueeze(2)

    decode_metadata = AscendMLADecodeMetadata(
        input_positions=input_positions,
        block_table=block_table,
        seq_lens=seq_lens,
        seq_lens_list=seq_lens_list,
        max_seq_lens=max_seq_lens,
        attn_mask=common_attn_metadata.spec_attn_mask,
        actual_seq_lengths_q=actual_seq_lengths_q,
        sin=sin[:num_decode_tokens, ...],
        cos=cos[:num_decode_tokens, ...])

critical

This block has a critical bug. The local variables cos and sin are not guaranteed to be initialized, which will lead to an UnboundLocalError if num_prefills is 0.

Additionally, the condition if cos is None and sin is not None: is incorrect: cos and sin are either both None (eager mode) or both pre-allocated (graph mode), so this branch is effectively unreachable.

It appears the intention is to use the pre-allocated cos and sin tensors from common_attn_metadata for graph mode, but they are not being used. This will cause a failure during graph capture for decode-only batches.
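The failure mode is easy to reproduce outside vLLM. Below is a minimal, self-contained sketch (all names are illustrative, not the actual mla_v1.py code) showing why the inverted condition breaks the no-buffer path, and how branching on both tensors being None fixes it:

```python
# Illustrative repro of the reviewed bug; plain lists stand in for tensors.

def build_rope_buggy(cos, sin, positions, cos_cache, sin_cache):
    # BUG: when cos and sin are both None (eager mode), the condition is
    # False, so the else branch indexes into None and raises.
    if cos is None and sin is not None:  # effectively unreachable
        cos = [cos_cache[p] for p in positions]
        sin = [sin_cache[p] for p in positions]
    else:
        for i, p in enumerate(positions):  # fails if cos/sin are None
            cos[i] = cos_cache[p]
            sin[i] = sin_cache[p]
    return cos, sin

def build_rope_fixed(cos, sin, positions, cos_cache, sin_cache):
    # Fix: allocate fresh values only when no buffers were pre-allocated;
    # otherwise fill the pre-allocated graph-mode buffers in place.
    if cos is None and sin is None:
        cos = [cos_cache[p] for p in positions]
        sin = [sin_cache[p] for p in positions]
    else:
        for i, p in enumerate(positions):
            cos[i] = cos_cache[p]
            sin[i] = sin_cache[p]
    return cos, sin
```

With no pre-allocated buffers, the buggy version raises while the fixed version computes fresh values; with pre-allocated buffers, both fill them in place.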

Comment on lines 2341 to 2354
attn_state = AscendAttentionState.DecodeOnly
if self.speculative_config and \
        self.speculative_config.method == "deepseek_mtp":
    attn_states = AscendAttentionState.SpecDecoding

for attn_group in self.attn_groups[kv_cache_group_id]:
    if vllm_version_is("0.10.2"):
        builder = attn_group.metadata_builder
    else:
        builder = attn_group.get_metadata_builder()
    attn_metadata_i = builder.build_for_graph_capture(
        common_attn_metadata,
        self.get_model())

critical

There's a typo here. attn_states is assigned but never used. It should be attn_state. Because of this, attn_state is not updated to AscendAttentionState.SpecDecoding for MTP, and the default DecodeOnly is used when calling build_for_graph_capture. This will cause incorrect behavior in MTP graph mode.

                attn_state = AscendAttentionState.DecodeOnly
                if self.speculative_config and \
                        self.speculative_config.method == "deepseek_mtp":
                    attn_state = AscendAttentionState.SpecDecoding

                for attn_group in self.attn_groups[kv_cache_group_id]:
                    if vllm_version_is("0.10.2"):
                        builder = attn_group.metadata_builder
                    else:
                        builder = attn_group.get_metadata_builder()
                    attn_metadata_i = builder.build_for_graph_capture(
                        common_attn_metadata,
                        attn_state,
                        self.get_model())


@yiz-liu yiz-liu left a comment


You can rebase this PR onto #3125 so you'll have a clean merge once #3125 is merged.

):
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
spec_attn_mask, sparse_mode, scale, block_table, block_size,
seq_lens_list, actual_seq_lengths, worlspace, attn_output,

Typo here, worlspace.

block_size=block_size,
actual_seq_lengths_kv=seq_lens_list,
actual_seq_lengths=actual_seq_lengths,
workspace=workspace,

Should retrieve workspace from graph params.
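One common pattern for this, sketched below with hypothetical names (the actual graph-params structure in vllm-ascend may differ), is to allocate the workspace once per captured graph and look it up from a registry at replay time, so the captured kernels always see the same buffer address:

```python
from dataclasses import dataclass, field
from typing import Dict

@dataclass
class GraphParams:
    # Hypothetical registry: maps a graph key (e.g. padded batch size)
    # to the workspace buffer owned by that captured graph.
    workspaces: Dict[int, bytearray] = field(default_factory=dict)

    def get_workspace(self, batch_size: int, nbytes: int) -> bytearray:
        # Allocate once at capture time; every replay for the same
        # graph key reuses the identical buffer object/address.
        if batch_size not in self.workspaces:
            self.workspaces[batch_size] = bytearray(nbytes)
        return self.workspaces[batch_size]
```

The point of the design is that a graph captured with one workspace address must be replayed with that same address, so allocating a fresh buffer on each call would silently corrupt replays.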

@anon189Ty anon189Ty force-pushed the mtp_full_graph_compalible branch 2 times, most recently from e3d7f83 to a0a497f on September 30, 2025 09:03
@anon189Ty anon189Ty force-pushed the mtp_full_graph_compalible branch from a0a497f to 395de6e on September 30, 2025 17:06