From 0f6bbc83bd8ccef812194912d70308b294f7957a Mon Sep 17 00:00:00 2001 From: shen-shanshan <467638484@qq.com> Date: Tue, 26 Aug 2025 08:33:24 +0000 Subject: [PATCH] Refactor AscendAttentionMetadataBuilder for better extensibility and make the builder class of torchair extend from it Signed-off-by: shen-shanshan <467638484@qq.com> --- vllm_ascend/attention/attention_v1.py | 93 ++++++++++++++++------ vllm_ascend/torchair/torchair_attention.py | 76 ++++++++++-------- 2 files changed, 111 insertions(+), 58 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 5460b9403f..f56fe8e248 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -158,6 +158,18 @@ class AscendMetadata: is_only_prefill: bool = False +@dataclass +class AscendAttentionMetadataBuildInfo: + num_actual_tokens: int + block_table: torch.Tensor + query_start_loc: torch.Tensor + query_lens: torch.Tensor + seq_lens: torch.Tensor + slot_mapping: torch.Tensor + attn_mask: torch.Tensor + attn_state: AscendAttentionState + + class AscendAttentionMetadataBuilder: def __init__( @@ -175,9 +187,60 @@ def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: return False + def _assemble_build_info( + self, + num_actual_tokens, + block_table, + query_start_loc, + query_lens, + seq_lens, + slot_mapping, + attn_mask, + attn_state: "AscendAttentionState", + ) -> "AscendAttentionMetadataBuildInfo": + if is_310p(): + if attn_state == AscendAttentionState.PrefillNoCache: + mask_nz = nd_to_nz_2d(attn_mask) + attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), + ACL_FORMAT_FRACTAL_NZ) + elif attn_state == AscendAttentionState.ChunkedPrefill: + mask_nz = nd_to_nz_spec(attn_mask) + attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), + ACL_FORMAT_FRACTAL_NZ) + + build_info = AscendAttentionMetadataBuildInfo( + num_actual_tokens=num_actual_tokens, + block_table=block_table, + query_start_loc=query_start_loc, + query_lens=query_lens, + seq_lens=seq_lens, + slot_mapping=slot_mapping, + attn_mask=attn_mask, + attn_state=attn_state) + return build_info + + def _assemble_attn_metadata( + self, + build_info: "AscendAttentionMetadataBuildInfo", + common_attn_metadata: "AscendCommonAttentionMetadata", + ) -> "AscendMetadata": + attn_metadata = AscendMetadata( + num_actual_tokens=build_info.num_actual_tokens, + block_tables=build_info.block_table, + query_start_loc=build_info.query_start_loc, + query_lens=build_info.query_lens, + seq_lens=build_info.seq_lens, + max_query_len=common_attn_metadata.max_query_len, + slot_mapping=build_info.slot_mapping, + attn_mask=build_info.attn_mask, + attn_state=build_info.attn_state, + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, + is_only_prefill=common_attn_metadata.is_only_prefill) + return attn_metadata + def build( self, - common_attn_metadata: AscendCommonAttentionMetadata, + common_attn_metadata: "AscendCommonAttentionMetadata", model: nn.Module, ): num_reqs = common_attn_metadata.num_reqs @@ -205,28 +268,12 @@ def build( query_start_loc = query_start_loc_cpu.to(self.device, non_blocking=True) - if is_310p(): - if attn_state == AscendAttentionState.PrefillNoCache: - mask_nz = nd_to_nz_2d(attn_mask) - attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), - ACL_FORMAT_FRACTAL_NZ) - elif attn_state == AscendAttentionState.ChunkedPrefill: - mask_nz = nd_to_nz_spec(attn_mask) - attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), - 
ACL_FORMAT_FRACTAL_NZ) - - attn_metadata = AscendMetadata( - num_actual_tokens=num_actual_tokens, - block_tables=block_table, - query_start_loc=query_start_loc, - query_lens=query_lens, - seq_lens=seq_lens, - max_query_len=common_attn_metadata.max_query_len, - slot_mapping=slot_mapping, - attn_mask=attn_mask, - attn_state=attn_state, - enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, - is_only_prefill=common_attn_metadata.is_only_prefill) + build_info = self._assemble_build_info(num_actual_tokens, block_table, + query_start_loc, query_lens, + seq_lens, slot_mapping, + attn_mask, attn_state) + attn_metadata = self._assemble_attn_metadata(build_info, + common_attn_metadata) return attn_metadata diff --git a/vllm_ascend/torchair/torchair_attention.py b/vllm_ascend/torchair/torchair_attention.py index 81f2968a8e..085e160eaf 100644 --- a/vllm_ascend/torchair/torchair_attention.py +++ b/vllm_ascend/torchair/torchair_attention.py @@ -20,7 +20,6 @@ import numpy as np import torch -import torch.nn as nn import torch_npu from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer, AttentionType) @@ -28,10 +27,9 @@ from vllm.config import VllmConfig from vllm.utils import cdiv -from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend, - AscendAttentionMetadataBuilder, - AscendAttentionState, - AscendMetadata) +from vllm_ascend.attention.attention_v1 import ( + AscendAttentionBackend, AscendAttentionMetadataBuilder, + AscendAttentionMetadataBuildInfo, AscendAttentionState, AscendMetadata) from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, @@ -169,44 +167,52 @@ def build_torchair_graph_dummy( decode=decode_metadata) return attn_metadata - def build( + def _assemble_build_info( self, - common_attn_metadata: AscendCommonAttentionMetadata, - model: nn.Module, - ): - num_reqs = common_attn_metadata.num_reqs - num_actual_tokens = common_attn_metadata.num_actual_tokens - - block_table = common_attn_metadata.block_table_tensor - block_table[:num_reqs, :self.max_num_blocks_per_req] = ( - block_table[:num_reqs]) - - seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] - slot_mapping = common_attn_metadata.slot_mapping_cpu[: - num_actual_tokens].to( - self.device, - non_blocking= - True) - attn_mask = common_attn_metadata.attn_mask - - attn_state = common_attn_metadata.attn_state + num_actual_tokens, + block_table, + query_start_loc, + query_lens, + seq_lens, + slot_mapping, + attn_mask, + attn_state: "AscendAttentionState", + ) -> "AscendAttentionMetadataBuildInfo": if is_310p() and attn_state == AscendAttentionState.PrefillNoCache: mask_nz = nd_to_nz_2d(attn_mask) attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29) - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: - num_reqs - + 1] - query_start_loc = query_start_loc_cpu.to(self.device, - non_blocking=True) - query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + build_info = AscendAttentionMetadataBuildInfo( + num_actual_tokens=num_actual_tokens, + block_table=block_table, + query_start_loc=query_start_loc, + query_lens=query_lens, + seq_lens=seq_lens, + slot_mapping=slot_mapping, + attn_mask=attn_mask, + attn_state=attn_state) + return build_info + + def _assemble_attn_metadata( + self, + build_info: "AscendAttentionMetadataBuildInfo", + common_attn_metadata: "AscendCommonAttentionMetadata", + ) -> 
"AscendMetadata": + num_actual_tokens = build_info.num_actual_tokens + block_table = build_info.block_table + seq_lens = build_info.seq_lens + slot_mapping = build_info.slot_mapping + attn_state = build_info.attn_state + + num_reqs = common_attn_metadata.num_reqs input_positions = common_attn_metadata.positions[: num_actual_tokens].long( ) + graph_pad_size = common_attn_metadata.graph_pad_size decode_metadata = None - graph_pad_size = common_attn_metadata.graph_pad_size use_torchair_graph = graph_pad_size > -1 + if common_attn_metadata.attn_state in [ AscendAttentionState.DecodeOnly, ]: @@ -259,12 +265,12 @@ def build( decode=decode_metadata, num_actual_tokens=num_actual_tokens, block_tables=block_table, - query_start_loc=query_start_loc, - query_lens=query_lens, + query_start_loc=build_info.query_start_loc, + query_lens=build_info.query_lens, seq_lens=seq_lens, max_query_len=common_attn_metadata.max_query_len, slot_mapping=slot_mapping, - attn_mask=attn_mask, + attn_mask=build_info.attn_mask, attn_state=attn_state, enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp) return attn_metadata