Commit 37b500a

Refactor AscendAttentionMetadataBuilder for better extensibility and make the torchair builder class extend it
Signed-off-by: shen-shanshan <[email protected]>
1 parent 60ac4fb commit 37b500a

File tree

2 files changed (+111 −58 lines):
  vllm_ascend/attention/attention_v1.py
  vllm_ascend/torchair/torchair_attention.py

2 files changed

+111
-58
lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 70 additions & 23 deletions
@@ -158,6 +158,18 @@ class AscendMetadata:
     is_only_prefill: bool = False
 
 
+@dataclass
+class AscendAttentionMetadataBuildInfo:
+    num_actual_tokens: int
+    block_table: torch.Tensor
+    query_start_loc: torch.Tensor
+    query_lens: torch.Tensor
+    seq_lens: torch.Tensor
+    slot_mapping: torch.Tensor
+    attn_mask: torch.Tensor
+    attn_state: AscendAttentionState
+
+
 class AscendAttentionMetadataBuilder:
 
     def __init__(
@@ -175,9 +187,60 @@ def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
 
+    def _assemble_build_info(
+        self,
+        num_actual_tokens,
+        block_table,
+        query_start_loc,
+        query_lens,
+        seq_lens,
+        slot_mapping,
+        attn_mask,
+        attn_state: "AscendAttentionState",
+    ) -> "AscendAttentionMetadataBuildInfo":
+        if is_310p():
+            if attn_state == AscendAttentionState.PrefillNoCache:
+                mask_nz = nd_to_nz_2d(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+            elif attn_state == AscendAttentionState.ChunkedPrefill:
+                mask_nz = nd_to_nz_spec(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+
+        build_info = AscendAttentionMetadataBuildInfo(
+            num_actual_tokens=num_actual_tokens,
+            block_table=block_table,
+            query_start_loc=query_start_loc,
+            query_lens=query_lens,
+            seq_lens=seq_lens,
+            slot_mapping=slot_mapping,
+            attn_mask=attn_mask,
+            attn_state=attn_state)
+        return build_info
+
+    def _assemble_attn_metadata(
+        self,
+        build_info: "AscendAttentionMetadataBuildInfo",
+        common_attn_metadata: "AscendCommonAttentionMetadata",
+    ) -> "AscendMetadata":
+        attn_metadata = AscendMetadata(
+            num_actual_tokens=build_info.num_actual_tokens,
+            block_tables=build_info.block_table,
+            query_start_loc=build_info.query_start_loc,
+            query_lens=build_info.query_lens,
+            seq_lens=build_info.seq_lens,
+            max_query_len=common_attn_metadata.max_query_len,
+            slot_mapping=build_info.slot_mapping,
+            attn_mask=build_info.attn_mask,
+            attn_state=build_info.attn_state,
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
+            is_only_prefill=common_attn_metadata.is_only_prefill)
+        return attn_metadata
+
     def build(
         self,
-        common_attn_metadata: AscendCommonAttentionMetadata,
+        common_attn_metadata: "AscendCommonAttentionMetadata",
         model: nn.Module,
     ):
         num_reqs = common_attn_metadata.num_reqs
@@ -205,28 +268,12 @@ def build(
         query_start_loc = query_start_loc_cpu.to(self.device,
                                                  non_blocking=True)
 
-        if is_310p():
-            if attn_state == AscendAttentionState.PrefillNoCache:
-                mask_nz = nd_to_nz_2d(attn_mask)
-                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
-                                                      ACL_FORMAT_FRACTAL_NZ)
-            elif attn_state == AscendAttentionState.ChunkedPrefill:
-                mask_nz = nd_to_nz_spec(attn_mask)
-                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
-                                                      ACL_FORMAT_FRACTAL_NZ)
-
-        attn_metadata = AscendMetadata(
-            num_actual_tokens=num_actual_tokens,
-            block_tables=block_table,
-            query_start_loc=query_start_loc,
-            query_lens=query_lens,
-            seq_lens=seq_lens,
-            max_query_len=common_attn_metadata.max_query_len,
-            slot_mapping=slot_mapping,
-            attn_mask=attn_mask,
-            attn_state=attn_state,
-            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
-            is_only_prefill=common_attn_metadata.is_only_prefill)
+        build_info = self._assemble_build_info(num_actual_tokens, block_table,
+                                               query_start_loc, query_lens,
+                                               seq_lens, slot_mapping,
+                                               attn_mask, attn_state)
+        attn_metadata = self._assemble_attn_metadata(build_info,
+                                                     common_attn_metadata)
         return attn_metadata
 
 
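The net effect in attention_v1.py is that build() becomes a template method: it still slices the common tensors, but the NZ-format mask handling and the metadata construction now live behind the two _assemble_* hooks, so a device- or graph-specific builder can customize those steps without copying build(). Below is a minimal sketch of that extension pattern; the subclass name is hypothetical and only the hook names and signatures come from the diff above.

from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder


class CustomAscendMetadataBuilder(AscendAttentionMetadataBuilder):
    """Hypothetical subclass; shown only to illustrate the hook points."""

    def _assemble_build_info(self, num_actual_tokens, block_table,
                             query_start_loc, query_lens, seq_lens,
                             slot_mapping, attn_mask, attn_state):
        # Insert backend-specific preprocessing (e.g. a custom mask cast)
        # before delegating to the base class to pack the BuildInfo.
        return super()._assemble_build_info(num_actual_tokens, block_table,
                                            query_start_loc, query_lens,
                                            seq_lens, slot_mapping,
                                            attn_mask, attn_state)

    def _assemble_attn_metadata(self, build_info, common_attn_metadata):
        # Build a metadata variant from the shared BuildInfo; the inherited
        # build() calls this hook automatically.
        return super()._assemble_attn_metadata(build_info,
                                               common_attn_metadata)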
vllm_ascend/torchair/torchair_attention.py

Lines changed: 41 additions & 35 deletions
@@ -20,18 +20,16 @@
 
 import numpy as np
 import torch
-import torch.nn as nn
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
                                               AttentionType)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
 from vllm.utils import cdiv
 
-from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
-                                                AscendAttentionMetadataBuilder,
-                                                AscendAttentionState,
-                                                AscendMetadata)
+from vllm_ascend.attention.attention_v1 import (
+    AscendAttentionBackend, AscendAttentionMetadataBuilder,
+    AscendAttentionMetadataBuildInfo, AscendAttentionState, AscendMetadata)
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
@@ -169,44 +167,52 @@ def build_torchair_graph_dummy(
             decode=decode_metadata)
         return attn_metadata
 
-    def build(
+    def _assemble_build_info(
         self,
-        common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
-    ):
-        num_reqs = common_attn_metadata.num_reqs
-        num_actual_tokens = common_attn_metadata.num_actual_tokens
-
-        block_table = common_attn_metadata.block_table_tensor
-        block_table[:num_reqs, :self.max_num_blocks_per_req] = (
-            block_table[:num_reqs])
-
-        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
-        slot_mapping = common_attn_metadata.slot_mapping_cpu[:
-                                                             num_actual_tokens].to(
-                                                                 self.device,
-                                                                 non_blocking=
-                                                                 True)
-        attn_mask = common_attn_metadata.attn_mask
-
-        attn_state = common_attn_metadata.attn_state
+        num_actual_tokens,
+        block_table,
+        query_start_loc,
+        query_lens,
+        seq_lens,
+        slot_mapping,
+        attn_mask,
+        attn_state: "AscendAttentionState",
+    ) -> "AscendAttentionMetadataBuildInfo":
         if is_310p() and attn_state == AscendAttentionState.PrefillNoCache:
             mask_nz = nd_to_nz_2d(attn_mask)
             attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29)
 
-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
-                                                                       num_reqs
-                                                                       + 1]
-        query_start_loc = query_start_loc_cpu.to(self.device,
-                                                 non_blocking=True)
-        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        build_info = AscendAttentionMetadataBuildInfo(
+            num_actual_tokens=num_actual_tokens,
+            block_table=block_table,
+            query_start_loc=query_start_loc,
+            query_lens=query_lens,
+            seq_lens=seq_lens,
+            slot_mapping=slot_mapping,
+            attn_mask=attn_mask,
+            attn_state=attn_state)
+        return build_info
+
+    def _assemble_attn_metadata(
+        self,
+        build_info: "AscendAttentionMetadataBuildInfo",
+        common_attn_metadata: "AscendCommonAttentionMetadata",
+    ) -> "AscendMetadata":
+        num_actual_tokens = build_info.num_actual_tokens
+        block_table = build_info.block_table
+        seq_lens = build_info.seq_lens
+        slot_mapping = build_info.slot_mapping
+        attn_state = build_info.attn_state
+
+        num_reqs = common_attn_metadata.num_reqs
         input_positions = common_attn_metadata.positions[:
                                                          num_actual_tokens].long(
                                                          )
+        graph_pad_size = common_attn_metadata.graph_pad_size
 
         decode_metadata = None
-        graph_pad_size = common_attn_metadata.graph_pad_size
         use_torchair_graph = graph_pad_size > -1
+
         if common_attn_metadata.attn_state in [
                 AscendAttentionState.DecodeOnly,
         ]:
@@ -259,12 +265,12 @@ def build(
             decode=decode_metadata,
             num_actual_tokens=num_actual_tokens,
             block_tables=block_table,
-            query_start_loc=query_start_loc,
-            query_lens=query_lens,
+            query_start_loc=build_info.query_start_loc,
+            query_lens=build_info.query_lens,
             seq_lens=seq_lens,
             max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
-            attn_mask=attn_mask,
+            attn_mask=build_info.attn_mask,
             attn_state=attn_state,
             enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
         return attn_metadata

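With the torchair builder reduced to the two hooks, build() itself is inherited from AscendAttentionMetadataBuilder, so the shared flow lives in one place. A small illustrative helper (not part of the commit) that captures the intended pattern, usable for example in a test; only the base class and hook names come from the diffs above.

from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder


def follows_hook_pattern(builder_cls) -> bool:
    # True when a builder subclasses the Ascend base builder, reuses its
    # build(), and customizes behaviour only via the _assemble_* hooks.
    overrides_hook = ("_assemble_build_info" in vars(builder_cls)
                      or "_assemble_attn_metadata" in vars(builder_cls))
    return (issubclass(builder_cls, AscendAttentionMetadataBuilder)
            and "build" not in vars(builder_cls) and overrides_hook)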