Skip to content

Commit fddbc04

Browse files
committed
Extract _prepare_build_info() and _assemble_build_info() from build() in AscendAttentionMetadataBuilder
Signed-off-by: shen-shanshan <[email protected]>
1 parent e14f2ef commit fddbc04

File tree

2 files changed

+566
-22
lines changed

2 files changed

+566
-22
lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 90 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,6 @@ class AscendAttentionState(Enum):
119119

120120
@dataclass
121121
class AscendMetadata:
122-
123122
# **************************** Basic Properties ************************** #
124123
attn_mask: Optional[torch.Tensor] = None
125124
# Current state of this attention run.
@@ -155,37 +154,106 @@ class AscendMetadata:
155154
is_only_prefill: bool = False
156155

157156

157+
@dataclass
158+
class AscendAttentionMetadataBuildInfo:
159+
block_tables: torch.Tensor = None
160+
query_start_loc: torch.Tensor = None
161+
query_lens: torch.Tensor = None
162+
seq_lens: torch.Tensor = None
163+
slot_mapping: torch.Tensor = None
164+
attn_mask: torch.Tensor = None
165+
attn_state: AscendAttentionState = None
166+
167+
158168
class AscendAttentionMetadataBuilder:
159169

160170
def __init__(self, runner):
161171
self.runner = runner
162172

163-
def reorder_batch(self, input_batch: "InputBatch",
164-
scheduler_output: "SchedulerOutput") -> bool:
173+
def reorder_batch(
174+
self,
175+
input_batch: "InputBatch",
176+
scheduler_output: "SchedulerOutput",
177+
) -> bool:
165178
return False
166179

167-
def build(self,
168-
num_reqs,
169-
num_actual_tokens,
170-
max_query_len,
171-
enable_dbo_across_dp: bool = False,
172-
is_only_prefill: bool = False):
180+
def _assemble_build_info(
181+
self,
182+
num_reqs,
183+
num_actual_tokens,
184+
max_query_len,
185+
block_tables,
186+
query_start_loc,
187+
query_lens,
188+
seq_lens,
189+
slot_mapping,
190+
attn_mask,
191+
attn_state: "AscendAttentionState",
192+
*args,
193+
**kwargs,
194+
) -> "AscendAttentionMetadataBuildInfo":
195+
build_info = AscendAttentionMetadataBuildInfo(
196+
block_tables=block_tables,
197+
query_start_loc=query_start_loc,
198+
query_lens=query_lens,
199+
seq_lens=seq_lens,
200+
slot_mapping=slot_mapping,
201+
attn_mask=attn_mask,
202+
attn_state=attn_state)
203+
return build_info
204+
205+
def _prepare_build_info(
206+
self,
207+
num_reqs,
208+
num_actual_tokens,
209+
max_query_len,
210+
enable_dbo_across_dp,
211+
is_only_prefill,
212+
*args,
213+
**kwargs,
214+
) -> "AscendAttentionMetadataBuildInfo":
215+
device = self.runner.device
216+
217+
block_tables = self.runner.input_batch.block_table[
218+
0].get_device_tensor()
219+
block_tables[:num_reqs, :self.runner.max_num_blocks_per_req] = (
220+
block_tables[:num_reqs])
173221

174-
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
175-
)
176-
block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
177-
block_table[:num_reqs])
222+
query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
223+
query_start_loc = query_start_loc_cpu.to(device, non_blocking=True)
178224

179225
query_lens = self.runner.query_lens
180226
seq_lens = self.runner.seq_lens_cpu[:num_reqs]
181227
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
182-
self.runner.device, non_blocking=True)
228+
device, non_blocking=True)
183229
attn_mask = self.runner.attn_mask
184230
attn_state = self.runner.attn_state
185-
query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
186-
query_start_loc = query_start_loc_cpu.to(self.runner.device,
187-
non_blocking=True)
188231

232+
build_info = self._assemble_build_info(num_reqs, num_actual_tokens,
233+
max_query_len, block_tables,
234+
query_start_loc, query_lens,
235+
seq_lens, slot_mapping,
236+
attn_mask, attn_state, args,
237+
kwargs)
238+
return build_info
239+
240+
def build(
241+
self,
242+
num_reqs,
243+
num_actual_tokens,
244+
max_query_len,
245+
enable_dbo_across_dp: bool = False,
246+
is_only_prefill: bool = False,
247+
*args,
248+
**kwargs,
249+
):
250+
build_info = self._prepare_build_info(num_reqs, num_actual_tokens,
251+
max_query_len,
252+
enable_dbo_across_dp,
253+
is_only_prefill, args, kwargs)
254+
255+
attn_mask = build_info.attn_mask
256+
attn_state = build_info.attn_state
189257
if is_310p():
190258
if attn_state == AscendAttentionState.PrefillNoCache:
191259
mask_nz = nd_to_nz_2d(attn_mask)
@@ -198,12 +266,12 @@ def build(self,
198266

199267
attn_metadata = AscendMetadata(
200268
num_actual_tokens=num_actual_tokens,
201-
block_tables=block_table,
202-
query_start_loc=query_start_loc,
203-
query_lens=query_lens,
204-
seq_lens=seq_lens,
269+
block_tables=build_info.block_tables,
270+
query_start_loc=build_info.query_start_loc,
271+
query_lens=build_info.query_lens,
272+
seq_lens=build_info.seq_lens,
205273
max_query_len=max_query_len,
206-
slot_mapping=slot_mapping,
274+
slot_mapping=build_info.slot_mapping,
207275
attn_mask=attn_mask,
208276
attn_state=attn_state,
209277
enable_dbo_across_dp=enable_dbo_across_dp,

0 commit comments

Comments
 (0)