Commit 643f31e (parent: e14f2ef)

Refactor AscendAttentionTorchairMetadataBuilder for better extensibility

Signed-off-by: shen-shanshan <[email protected]>

2 files changed (+569 -27 lines)

vllm_ascend/attention/attention_v1.py (94 additions, 27 deletions)
@@ -119,7 +119,6 @@ class AscendAttentionState(Enum):
 
 @dataclass
 class AscendMetadata:
-
     # **************************** Basic Properties ************************** #
     attn_mask: Optional[torch.Tensor] = None
     # Current state of this attention run.
@@ -155,37 +154,50 @@ class AscendMetadata:
     is_only_prefill: bool = False
 
 
+@dataclass
+class AscendAttentionMetadataBuildInfo:
+    num_actual_tokens: int = 0
+    block_table: torch.Tensor = None
+    query_start_loc: torch.Tensor = None
+    query_lens: torch.Tensor = None
+    seq_lens: torch.Tensor = None
+    max_query_len: int = 0
+    slot_mapping: torch.Tensor = None
+    attn_mask: torch.Tensor = None
+    attn_state: AscendAttentionState = None
+    enable_dbo_across_dp: bool = False
+    is_only_prefill: bool = False
+
+
 class AscendAttentionMetadataBuilder:
 
     def __init__(self, runner):
         self.runner = runner
 
-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
+    def reorder_batch(
+        self,
+        input_batch: "InputBatch",
+        scheduler_output: "SchedulerOutput",
+    ) -> bool:
         return False
 
-    def build(self,
-              num_reqs,
-              num_actual_tokens,
-              max_query_len,
-              enable_dbo_across_dp: bool = False,
-              is_only_prefill: bool = False):
-
-        block_table = self.runner.input_batch.block_table[0].get_device_tensor(
-        )
-        block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
-            block_table[:num_reqs])
-
-        query_lens = self.runner.query_lens
-        seq_lens = self.runner.seq_lens_cpu[:num_reqs]
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            self.runner.device, non_blocking=True)
-        attn_mask = self.runner.attn_mask
-        attn_state = self.runner.attn_state
-        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
-        query_start_loc = query_start_loc_cpu.to(self.runner.device,
-                                                 non_blocking=True)
-
+    def _assemble_build_info(
+        self,
+        num_reqs,
+        num_actual_tokens,
+        max_query_len,
+        enable_dbo_across_dp,
+        is_only_prefill,
+        block_table,
+        query_start_loc,
+        query_lens,
+        seq_lens,
+        slot_mapping,
+        attn_mask,
+        attn_state: "AscendAttentionState",
+        *args,
+        **kwargs,
+    ) -> "AscendAttentionMetadataBuildInfo":
         if is_310p():
             if attn_state == AscendAttentionState.PrefillNoCache:
                 mask_nz = nd_to_nz_2d(attn_mask)
@@ -196,9 +208,9 @@ def build(self,
                 attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
                                                       ACL_FORMAT_FRACTAL_NZ)
 
-        attn_metadata = AscendMetadata(
+        build_info = AscendAttentionMetadataBuildInfo(
             num_actual_tokens=num_actual_tokens,
-            block_tables=block_table,
+            block_table=block_table,
             query_start_loc=query_start_loc,
             query_lens=query_lens,
             seq_lens=seq_lens,
@@ -208,6 +220,61 @@
             attn_state=attn_state,
             enable_dbo_across_dp=enable_dbo_across_dp,
             is_only_prefill=is_only_prefill)
+        return build_info
+
+    def _assemble_attn_metadata(
+        self,
+        build_info: "AscendAttentionMetadataBuildInfo",
+    ) -> "AscendMetadata":
+        attn_metadata = AscendMetadata(
+            num_actual_tokens=build_info.num_actual_tokens,
+            block_tables=build_info.block_table,
+            query_start_loc=build_info.query_start_loc,
+            query_lens=build_info.query_lens,
+            seq_lens=build_info.seq_lens,
+            max_query_len=build_info.max_query_len,
+            slot_mapping=build_info.slot_mapping,
+            attn_mask=build_info.attn_mask,
+            attn_state=build_info.attn_state,
+            enable_dbo_across_dp=build_info.enable_dbo_across_dp,
+            is_only_prefill=build_info.is_only_prefill)
+        return attn_metadata
+
+    def build(
+        self,
+        num_reqs,
+        num_actual_tokens,
+        max_query_len,
+        enable_dbo_across_dp: bool = False,
+        is_only_prefill: bool = False,
+        *args,
+        **kwargs,
+    ) -> "AscendMetadata":
+        device = self.runner.device
+
+        block_table = self.runner.input_batch.block_table[0].get_device_tensor(
+        )
+        block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+            block_table[:num_reqs])
+
+        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
+        query_start_loc = query_start_loc_cpu.to(device, non_blocking=True)
+
+        query_lens = self.runner.query_lens
+        seq_lens = self.runner.seq_lens_cpu[:num_reqs]
+        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
+            device, non_blocking=True)
+        attn_mask = self.runner.attn_mask
+        attn_state = self.runner.attn_state
+
+        build_info = self._assemble_build_info(
+            num_reqs, num_actual_tokens, max_query_len,
+            enable_dbo_across_dp, is_only_prefill,
+            block_table, query_start_loc, query_lens,
+            seq_lens, slot_mapping, attn_mask,
+            attn_state, *args, **kwargs)
+
+        attn_metadata = self._assemble_attn_metadata(build_info)
         return attn_metadata
 
 
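The refactor above is a template-method split: build() now only gathers tensors from the runner, _assemble_build_info() packs them into an AscendAttentionMetadataBuildInfo, and _assemble_attn_metadata() turns that into the final AscendMetadata, so a subclass can override either assembly step without duplicating the gathering logic. Below is a minimal sketch of how the AscendAttentionTorchairMetadataBuilder named in the commit message might hook in, assuming the classes from this diff are importable; the graph_batch_size padding is a hypothetical illustration, not part of this commit.

import torch

from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder


class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
    # Sketch only: override one assembly hook, defer to the base class for
    # the shared packing logic.

    def _assemble_build_info(self, num_reqs, num_actual_tokens, max_query_len,
                             enable_dbo_across_dp, is_only_prefill,
                             block_table, query_start_loc, query_lens,
                             seq_lens, slot_mapping, attn_mask, attn_state,
                             *args, **kwargs):
        # Hypothetical hook: pad seq_lens up to a fixed graph batch size
        # before the shared assembly runs (illustrative, not in this commit).
        graph_batch_size = kwargs.pop("graph_batch_size", num_reqs)
        if graph_batch_size > num_reqs:
            pad = torch.zeros(graph_batch_size - num_reqs,
                              dtype=seq_lens.dtype)
            seq_lens = torch.cat([seq_lens, pad])
        return super()._assemble_build_info(
            num_reqs, num_actual_tokens, max_query_len, enable_dbo_across_dp,
            is_only_prefill, block_table, query_start_loc, query_lens,
            seq_lens, slot_mapping, attn_mask, attn_state, *args, **kwargs)

Because _assemble_attn_metadata() consumes only the build-info dataclass, a subclass can likewise override just that step to emit an extended metadata type.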