From 395de6ee88d6b5a26edcf56437bee1eb20b134a6 Mon Sep 17 00:00:00 2001
From: anon189Ty
Date: Mon, 29 Sep 2025 23:16:12 +0800
Subject: [PATCH] [Feat] Make full graph mode compatible with MTP

Signed-off-by: anon189Ty
---
 vllm_ascend/attention/attention_v1.py |  1 +
 vllm_ascend/attention/mla_v1.py       | 84 +++++++++++++++++++++------
 vllm_ascend/attention/utils.py        |  4 ++
 vllm_ascend/compilation/acl_graph.py  | 54 +++++++++++++++++
 vllm_ascend/worker/model_runner_v1.py | 56 ++++++++++++++++--
 5 files changed, 176 insertions(+), 23 deletions(-)

diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index d289bb4578..98c8c57407 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -237,6 +237,7 @@ def build_for_graph_capture(
         self,
         common_attn_metadata: AscendCommonAttentionMetadata,
         attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
+        model: Optional[nn.Module] = None,
     ):
         if attn_state == AscendAttentionState.DecodeOnly:
             attn_metadata = self.build(
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 73cbae6207..c344c9883a 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -169,7 +169,7 @@ def split_metadata_for_multistream(
 class AscendMLAMetadataBuilder:
     # Does this backend/builder support ACL Graphs for attention (default: no).
     aclgraph_support: ClassVar[AttentionCGSupport] = \
-        AttentionCGSupport.NEVER
+        AttentionCGSupport.UNIFORM_BATCH
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
@@ -389,6 +389,8 @@ def build(
 
         decode_metadata = None
         if num_decodes > 0:
+            cos = common_attn_metadata.cos
+            sin = common_attn_metadata.sin
             # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
             actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
             max_seq_lens = seq_lens[:num_decodes].max().item()
@@ -397,21 +399,45 @@ def build(
             block_table = block_table[:num_decodes, ...]
             seq_lens_list = seq_lens.tolist()
 
-            cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
-            sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
-
-            decode_metadata = AscendMLADecodeMetadata(
-                input_positions=input_positions,
-                block_table=block_table,
-                seq_lens=seq_lens,
-                seq_lens_list=seq_lens_list,
-                max_seq_lens=max_seq_lens,
-                attn_mask=common_attn_metadata.spec_attn_mask,
-                actual_seq_lengths_q=actual_seq_lengths_q,
-                sin=sin,
-                cos=cos)
+            # TODO: After the full graph mode supports MTP, this if branch needs to be deleted
+            assert self.cos_cache is not None
+            assert self.sin_cache is not None
+            if cos is None and sin is None:
+                cos = self.cos_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
+                sin = self.sin_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
+
+                decode_metadata = AscendMLADecodeMetadata(
+                    input_positions=input_positions,
+                    block_table=block_table,
+                    seq_lens=seq_lens,
+                    seq_lens_list=seq_lens_list,
+                    max_seq_lens=max_seq_lens,
+                    attn_mask=common_attn_metadata.spec_attn_mask,
+                    actual_seq_lengths_q=actual_seq_lengths_q,
+                    sin=sin,
+                    cos=cos)
+            else:
+                cos[:num_decode_tokens,
+                    ...] = self.cos_cache[input_positions].unsqueeze(
+                        1).unsqueeze(2)
+                sin[:num_decode_tokens,
+                    ...] 
= self.sin_cache[input_positions].unsqueeze(
+                        1).unsqueeze(2)
+
+                decode_metadata = AscendMLADecodeMetadata(
+                    input_positions=input_positions,
+                    block_table=block_table,
+                    seq_lens=seq_lens,
+                    seq_lens_list=seq_lens_list,
+                    max_seq_lens=max_seq_lens,
+                    attn_mask=common_attn_metadata.spec_attn_mask,
+                    actual_seq_lengths_q=actual_seq_lengths_q,
+                    sin=sin[:num_decode_tokens, ...],
+                    cos=cos[:num_decode_tokens, ...])
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -431,6 +457,29 @@ def build(
             enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
         )
 
+    def build_for_graph_capture(
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+        attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
+        model: Optional[nn.Module] = None,
+    ):
+        if attn_state in {
+                AscendAttentionState.DecodeOnly,
+                AscendAttentionState.SpecDecoding
+        }:
+            attn_metadata = self.build(
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+                model=model,
+            )
+        else:
+            raise NotImplementedError(
+                "Currently we only support building dummy metadata for the DecodeOnly and SpecDecoding states"
+            )
+
+        attn_metadata.attn_state = attn_state
+        return attn_metadata
+
 
 class DecodeMLAPreprocessResult(NamedTuple):
     ql_nope: Optional[torch.Tensor] = None
@@ -814,7 +863,8 @@ def _forward_decode(
 
         if attn_metadata.attn_state in [
                 AscendAttentionState.SpecDecoding,
-                AscendAttentionState.ChunkedPrefill
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.DecodeOnly,
         ] and self.speculative_config is not None:
             # Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
             input_layout = "TND"
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index 519cde0c5a..cff3768924 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -63,6 +63,10 @@ class AscendCommonAttentionMetadata:
 
     graph_pad_size: int = -1
 
+    # NOTE: This is a temporary solution for rotary embedding in MLA
+    cos: torch.Tensor = None
+    sin: torch.Tensor = None
+
 
 def split_decodes_and_prefills(
     common_attn_metadata: AscendCommonAttentionMetadata,
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index 8a41807739..dc8671b203 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -229,6 +229,60 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
         event.record(update_stream)
 
 
+def update_mla_attn_params(update_stream, forward_context, runtime_shape,
+                           speculative_config):
+    graph_params = get_graph_params()
+    # FIXME: Behold! We are using a temporary hack here to update the args
+    # for each layer's attention op in the graph. 
+ for key, param, handle, event in zip( + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], + ): + (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout, + spec_attn_mask, sparse_mode, scale, block_table, block_size, + seq_lens_list, actual_seq_lengths, workspace, attn_output, + softmax_lse) = param + seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list + if speculative_config and speculative_config.method == "deepseek_mtp": + actual_seq_lengths = forward_context.attn_metadata[ + key].decode.actual_seq_lengths_q + spec_multiple = speculative_config.num_speculative_tokens + 1 + seq_lens_list = seq_lens_list + [0] * ( + runtime_shape // spec_multiple - len(seq_lens_list)) + actual_seq_lengths = [ + spec_multiple * (i + 1) + for i in range(runtime_shape // spec_multiple) + ] + with torch.npu.stream(update_stream): + torch.npu.graph_task_update_begin(update_stream, handle) + + torch_npu.npu_fused_infer_attention_score.out( + q_nope, + k_nope, + k_nope, + query_rope=q_pe, + key_rope=k_pe, + num_heads=num_heads, + num_key_value_heads=num_kv_heads, + input_layout=input_layout, + atten_mask=spec_attn_mask, + sparse_mode=sparse_mode, + scale=scale, + antiquant_mode=0, + antiquant_scale=None, + block_table=block_table, + block_size=block_size, + actual_seq_lengths_kv=seq_lens_list, + actual_seq_lengths=actual_seq_lengths, + workspace=workspace, + out=[attn_output, softmax_lse]) + torch.npu.graph_task_update_end(update_stream) + + event.record(update_stream) + + @dataclass class GraphParams: events: dict[int, list[torch.npu.ExternalEvent]] diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 0984e2bf63..d163dd616e 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -102,7 +102,8 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, set_graph_params, - update_attn_params) + update_attn_params, + update_mla_attn_params) from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor from vllm_ascend.eplb.core.eplb_device_transfer_loader import \ D2DExpertWeightLoader @@ -351,6 +352,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.speculative_config.method, self.vllm_config, self.device, self) self.rejection_sampler = AscendRejectionSampler() + self.actual_seq_lengths_q = list( + range(self.decode_token_per_req, self.max_num_tokens + 1, + self.decode_token_per_req)) # Persistent batch. 
self.input_ids = torch.zeros(self.max_num_tokens, @@ -369,6 +373,25 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): dtype=torch.int32, device=self.device) + if self.vllm_config.model_config.use_mla and \ + self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: + rope_dim = self.model_config.hf_text_config.qk_rope_head_dim + self.cos = torch.ones(self.max_num_reqs, + 1, + 1, + rope_dim, + dtype=self.dtype, + device=self.device) + self.sin = torch.zeros(self.max_num_reqs, + 1, + 1, + rope_dim, + dtype=self.dtype, + device=self.device) + else: + self.cos = None + self.sin = None + self.uses_mrope = self.model_config.uses_mrope # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -1518,6 +1541,8 @@ def _prepare_inputs( max_query_len=max_num_scheduled_tokens, graph_pad_size=self.graph_pad_size, decode_token_per_req=self.decode_token_per_req, + cos=self.cos, + sin=self.sin, ) if self.speculative_config and \ @@ -1547,7 +1572,7 @@ def _prepare_inputs( attn_metadata_i = builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, - model=self.model, + model=self.get_model(), **extra_attn_metadata_args) if self.vllm_config.model_config.use_mla: @@ -1582,8 +1607,14 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, forward_context = get_forward_context() if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: - update_attn_params(self.update_stream, forward_context, - positions.shape[0]) + if self.vllm_config.model_config.use_mla: + # FIXME: Try using `auto_dispatch_capture=True` + update_mla_attn_params(self.update_stream, forward_context, + positions.shape[0], + self.speculative_config) + else: + update_attn_params(self.update_stream, forward_context, + positions.shape[0]) if get_forward_context().sp_enabled: hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) @@ -2285,7 +2316,10 @@ def _build_attention_metadata(self, create_mixed_batch, num_reqs, block_table_tensor = self.input_batch.block_table[ kv_cache_group_id].get_device_tensor() common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc=torch.tensor( + [0] + self.actual_seq_lengths_q[:num_reqs], + device=self.device, + dtype=torch.int32), query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], seq_lens_cpu=self.seq_lens_cpu, @@ -2296,9 +2330,19 @@ def _build_attention_metadata(self, create_mixed_batch, num_reqs, block_table_tensor=block_table_tensor[:num_reqs], slot_mapping=self.slot_mapping, num_computed_tokens_cpu=num_computed_tokens_cpu, + positions=self.positions, + attn_mask=self.attn_mask, + spec_attn_mask=self.spec_attn_mask, + attn_state=self.attn_state, max_query_len=max_query_len, decode_token_per_req=self.decode_token_per_req, + cos=self.cos, + sin=self.sin, ) + attn_state = AscendAttentionState.DecodeOnly + if self.speculative_config and \ + self.speculative_config.method == "deepseek_mtp": + attn_state = AscendAttentionState.SpecDecoding for attn_group in self.attn_groups[kv_cache_group_id]: if vllm_version_is("0.10.2"): @@ -2306,7 +2350,7 @@ def _build_attention_metadata(self, create_mixed_batch, num_reqs, else: builder = attn_group.get_metadata_builder() attn_metadata_i = builder.build_for_graph_capture( - common_attn_metadata) + common_attn_metadata, attn_state, self.get_model()) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i
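
For reference, the MTP branch added to update_mla_attn_params above pads the per-request KV sequence lengths up to the captured graph batch and rebuilds the query lengths as running multiples of num_speculative_tokens + 1. Below is a minimal standalone sketch of that padding logic, not part of the patch; the helper name pad_mtp_seq_params is hypothetical and plain Python lists stand in for the graph metadata objects.

# Sketch of the MTP padding performed inside update_mla_attn_params.
# pad_mtp_seq_params is a hypothetical helper used only for illustration.
def pad_mtp_seq_params(seq_lens_list: list[int], runtime_shape: int,
                       num_speculative_tokens: int) -> tuple[list[int], list[int]]:
    # Each request contributes one verified token plus its speculative tokens.
    spec_multiple = num_speculative_tokens + 1
    num_padded_reqs = runtime_shape // spec_multiple
    # Pad the KV sequence lengths with zeros up to the captured batch size.
    padded_seq_lens = seq_lens_list + [0] * (num_padded_reqs - len(seq_lens_list))
    # Query lengths are cumulative multiples of spec_multiple, one per request slot.
    actual_seq_lengths = [spec_multiple * (i + 1) for i in range(num_padded_reqs)]
    return padded_seq_lens, actual_seq_lengths


if __name__ == "__main__":
    # E.g. a captured shape of 8 tokens with 1 speculative token per request
    # holds 4 request slots; two live requests leave two zero-padded slots.
    print(pad_mtp_seq_params([33, 17], runtime_shape=8, num_speculative_tokens=1))
    # -> ([33, 17, 0, 0], [2, 4, 6, 8])

Because the captured shape is always a whole multiple of decode_token_per_req, the padded slots simply receive zero-length KV sequences, which is why the patch extends seq_lens_list with zeros rather than truncating the graph.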