Commit c4c07a0

refact attn metadata build
Signed-off-by: weiguihua2 <[email protected]>
1 parent 992271b commit c4c07a0

File tree

8 files changed: +367 −217 lines changed


vllm_ascend/attention/attention_v1.py

Lines changed: 28 additions & 22 deletions
@@ -21,6 +21,8 @@
 
 import torch
 import torch_npu
+import torch.nn as nn
+from vllm.config import VllmConfig
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
@@ -32,6 +34,7 @@
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d, nd_to_nz_spec)
 from vllm_ascend.worker.npu_input_batch import InputBatch
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 
 
 class AscendAttentionBackend(AttentionBackend):
@@ -156,33 +159,36 @@ class AscendMetadata:
 
 class AscendAttentionMetadataBuilder:
 
-    def __init__(self, runner):
-        self.runner = runner
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 device: torch.device,):
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.device = device
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
 
     def build(self,
-              num_reqs,
-              num_actual_tokens,
-              max_query_len,
-              enable_dbo_across_dp: bool = False,
-              is_only_prefill: bool = False):
-
-        block_table = self.runner.input_batch.block_table[0].get_device_tensor(
-        )
-        block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+              common_attn_metadata: AscendCommonAttentionMetadata,
+              model: nn.Module,):
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
+
+        block_table = common_attn_metadata.block_table_tensor
+        block_table[:num_reqs, :common_attn_metadata.max_num_blocks_per_req] = (
             block_table[:num_reqs])
 
-        query_lens = self.runner.query_lens
-        seq_lens = self.runner.seq_lens_cpu[:num_reqs]
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            self.runner.device, non_blocking=True)
-        attn_mask = self.runner.attn_mask
-        attn_state = self.runner.attn_state
-        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
-        query_start_loc = query_start_loc_cpu.to(self.runner.device,
+        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
+        slot_mapping = common_attn_metadata.slot_mapping_cpu[:num_actual_tokens].to(
+            self.device, non_blocking=True)
+        attn_mask = common_attn_metadata.attn_mask
+        attn_state = common_attn_metadata.attn_state
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
+        query_start_loc = query_start_loc_cpu.to(self.device,
                                                  non_blocking=True)
 
         if is_310p():
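With this change the builder no longer reaches into the runner; everything build() needs arrives in a single AscendCommonAttentionMetadata bundle. The sketch below is reconstructed purely from the attributes this diff reads off that object; the real definition lives in vllm_ascend/attention/utils.py (added in this series) and may declare more fields, different types, or different defaults.

```python
# Hypothetical sketch of the metadata bundle, inferred only from the fields
# read in this diff; the actual dataclass in vllm_ascend/attention/utils.py
# may differ.
from dataclasses import dataclass

import torch


@dataclass
class AscendCommonAttentionMetadataSketch:
    num_reqs: int                      # number of scheduled requests
    num_actual_tokens: int             # tokens actually present (no graph padding)
    max_query_len: int
    query_start_loc_cpu: torch.Tensor  # cumulative query offsets, on CPU
    seq_lens_cpu: torch.Tensor         # per-request sequence lengths, on CPU
    slot_mapping_cpu: torch.Tensor     # KV-cache slot per token, on CPU
    block_table_tensor: torch.Tensor   # device-resident block table
    max_num_blocks_per_req: int
    positions: torch.Tensor            # token positions on device
    attn_mask: torch.Tensor
    attn_state: int                    # AscendAttentionState enum value
    decode_token_per_req: int = 1
    graph_pad_size: int = -1           # torchair graph padding target
    enable_dbo_across_dp: bool = False
    is_only_prefill: bool = False
```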
@@ -201,12 +207,12 @@ def build(self,
             query_start_loc=query_start_loc,
             query_lens=query_lens,
             seq_lens=seq_lens,
-            max_query_len=max_query_len,
+            max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp,
-            is_only_prefill=is_only_prefill)
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
+            is_only_prefill=common_attn_metadata.is_only_prefill)
         return attn_metadata
 
 
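For orientation, a minimal sketch of how a runner-side caller would drive the refactored interface: the builder is constructed once with the vLLM config and target device, and build() is then invoked per scheduling step with the pre-filled metadata bundle and the model. The wrapper function and its parameters are illustrative, not part of this commit.

```python
# Illustrative only: the NPU model runner is assumed to produce vllm_config,
# device, the metadata bundle, and the model; none of these are defined here.
import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata


def build_attn_metadata(vllm_config: VllmConfig, device: torch.device,
                        common_attn_metadata: AscendCommonAttentionMetadata,
                        model: nn.Module):
    # The builder is created once per runner with config and target device...
    builder = AscendAttentionMetadataBuilder(vllm_config=vllm_config,
                                             device=device)
    # ...and called each step with the pre-populated metadata bundle.
    return builder.build(common_attn_metadata, model)
```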

vllm_ascend/attention/attention_v1_torchair.py

Lines changed: 41 additions & 55 deletions
@@ -21,6 +21,8 @@
 import numpy as np
 import torch
 import torch_npu
+import torch.nn as nn
+from vllm.config import VllmConfig
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState
@@ -30,6 +32,7 @@
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d)
 from vllm_ascend.worker.npu_input_batch import InputBatch
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 
 
 class AscendAttentionTorchairBackend(AttentionBackend):
@@ -145,43 +148,29 @@ class AscendTorchairMetadata:
 
 class AscendAttentionTorchairMetadataBuilder:
 
-    def __init__(self, runner):
-        self.runner = runner
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 device: torch.device,):
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.device = device
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
 
     def _get_graph_runner_block_tables(
             self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
-
-        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
-        assert max_batch_size >= num_seqs, f"max_batch_size: {max_batch_size} should be bigger than cur_num_seqs: {num_seqs}"
-
-        if isinstance(self.runner.graph_block_tables, np.ndarray):
-            graph_block_tables = torch.zeros((max_batch_size, max_blocks),
-                                             dtype=block_tables.dtype,
-                                             device=block_tables.device)
-        else:
-            graph_block_tables = self.runner.graph_block_tables.to(
-                device=block_tables.device, dtype=block_tables.dtype)
-
         num_blocks = block_tables.size(1)
-        if num_blocks <= max_blocks:
-            graph_block_tables[:num_seqs, :
-                               num_blocks] = block_tables[:num_seqs, :
-                                                          num_blocks]
+        if num_blocks <= self.max_blocks:
+            return block_tables[:num_seqs, :num_blocks]
         else:
-            graph_block_tables[:num_seqs, :
-                               max_blocks] = block_tables[:num_seqs, :
-                                                          max_blocks]
-
-        return graph_block_tables[:num_seqs, :max_blocks]
+            return block_tables[:num_seqs, :self.max_blocks]
 
     def build_torchair_graph_dummy(
-            self, num_reqs: int,
-            num_actual_tokens: int) -> AscendTorchairMetadata:
-        device = self.runner.device
+            self, common_attn_metadata: AscendCommonAttentionMetadata) -> AscendTorchairMetadata:
+        device = self.device
+        num_reqs = common_attn_metadata.num_reqs
         _, max_blocks = self.runner.graph_block_tables.shape
         block_table = torch.zeros((num_reqs, max_blocks),
                                   dtype=torch.int32,
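The helper above now simply truncates the incoming block table to self.max_blocks instead of copying it into a preallocated graph_block_tables buffer (how self.max_blocks is initialized is not visible in this excerpt). A standalone sketch of the new behaviour:

```python
# Minimal demonstration of the simplified truncation logic. `max_blocks`
# stands in for the builder's `self.max_blocks`, whose setup is not shown
# in this commit excerpt.
import torch


def get_graph_runner_block_tables(num_seqs: int, block_tables: torch.Tensor,
                                  max_blocks: int) -> torch.Tensor:
    # Keep at most `max_blocks` block IDs per sequence; no preallocated
    # buffer and no copy are needed any more.
    num_blocks = block_tables.size(1)
    if num_blocks <= max_blocks:
        return block_tables[:num_seqs, :num_blocks]
    return block_tables[:num_seqs, :max_blocks]


block_tables = torch.arange(12, dtype=torch.int32).reshape(3, 4)
print(get_graph_runner_block_tables(2, block_tables, max_blocks=3))
# tensor([[0, 1, 2],
#         [4, 5, 6]], dtype=torch.int32)
```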
@@ -208,7 +197,7 @@ def build_torchair_graph_dummy(
             max_seq_lens=1)
 
         attn_metadata = AscendTorchairMetadata(
-            num_actual_tokens=num_actual_tokens,
+            num_actual_tokens=common_attn_metadata.num_actual_tokens,
             block_tables=block_table,
             query_lens=0,
             query_start_loc=query_start_loc,
@@ -219,46 +208,43 @@
         return attn_metadata
 
     def build(self,
-              num_reqs,
-              num_actual_tokens,
-              max_query_len,
-              graph_pad_size: int = -1,
-              enable_dbo_across_dp: bool = False,
-              *args,
-              **kwargs):
-
-        device = self.runner.device
-
-        block_table = self.runner.input_batch.block_table[0].get_device_tensor(
-        )
-        block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+              common_attn_metadata: AscendCommonAttentionMetadata,
+              model: nn.Module,):
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+
+        block_table = common_attn_metadata.block_table_tensor
+        block_table[:num_reqs, :common_attn_metadata.max_num_blocks_per_req] = (
             block_table[:num_reqs])
 
-        query_lens = self.runner.query_lens
-        seq_lens = self.runner.seq_lens_cpu[:num_reqs]
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            self.runner.device, non_blocking=True)
-        attn_mask = self.runner.attn_mask
+        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
+        slot_mapping = common_attn_metadata.slot_mapping_cpu[:num_actual_tokens].to(
+            self.device, non_blocking=True)
+        attn_mask = common_attn_metadata.attn_mask
 
-        attn_state = self.runner.attn_state
+        attn_state = common_attn_metadata.attn_state
         if is_310p() and attn_state == AscendAttentionState.PrefillNoCache:
             mask_nz = nd_to_nz_2d(attn_mask)
             attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29)
 
-        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
-        query_start_loc = query_start_loc_cpu.to(self.runner.device,
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
+        query_start_loc = query_start_loc_cpu.to(self.device,
                                                  non_blocking=True)
-        input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
-            device, non_blocking=True).long()
+        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        # input_positions = common_attn_metadata.positions_cpu[:num_actual_tokens].to(
+        #     device, non_blocking=True).long()
+
+        input_positions = common_attn_metadata.positions[:num_actual_tokens].long()
 
         decode_metadata = None
+        graph_pad_size = common_attn_metadata.graph_pad_size
         use_torchair_graph = graph_pad_size > -1
-        if self.runner.attn_state in [
+        if common_attn_metadata.attn_state in [
                 AscendAttentionState.DecodeOnly,
         ]:
             max_seq_lens = seq_lens.max().item()
             num_seqs = len(seq_lens)
-            if use_torchair_graph and self.runner.attn_state in [
+            if use_torchair_graph and common_attn_metadata.attn_state in [
                     AscendAttentionState.DecodeOnly,
             ]:
                 num_reqs_pad_size = 0
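Both refactored build() methods now derive query_lens from the cumulative query_start_loc_cpu offsets rather than reading self.runner.query_lens. A tiny example of that computation:

```python
# Per-request query lengths fall out of the cumulative start offsets,
# as the refactored build() methods now compute them.
import torch

query_start_loc_cpu = torch.tensor([0, 3, 5, 9])  # 3 requests
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
print(query_lens)  # tensor([3, 2, 4])
```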
@@ -267,7 +253,7 @@ def build(self,
                     pad_value = 0
                     num_token_pad_size = graph_pad_size - num_actual_tokens
                     num_reqs_pad_size = (
-                        graph_pad_size // self.runner.decode_token_per_req -
+                        graph_pad_size // common_attn_metadata.decode_token_per_req -
                         num_reqs)
                 pad_value = 1
                 padded_seq_lens = seq_lens.tolist() + [pad_value
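decode_token_per_req now comes from the metadata bundle instead of the runner. A worked example of the padding arithmetic above, with illustrative numbers:

```python
# Illustrative values only: pad a decode-only batch up to the token count
# captured by the torchair graph.
graph_pad_size = 16        # target token count of the captured graph
num_actual_tokens = 10
num_reqs = 10
decode_token_per_req = 1   # decode-only: one new token per request

num_token_pad_size = graph_pad_size - num_actual_tokens                # 6
num_reqs_pad_size = graph_pad_size // decode_token_per_req - num_reqs  # 6
print(num_token_pad_size, num_reqs_pad_size)  # 6 6
```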
@@ -308,11 +294,11 @@ def build(self,
             query_start_loc=query_start_loc,
             query_lens=query_lens,
             seq_lens=seq_lens,
-            max_query_len=max_query_len,
+            max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp)
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
         return attn_metadata
 
 