@@ -21,6 +21,8 @@
 import numpy as np
 import torch
 import torch_npu
+import torch.nn as nn
+from vllm.config import VllmConfig
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState
@@ -30,6 +32,7 @@
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d)
 from vllm_ascend.worker.npu_input_batch import InputBatch
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata


 class AscendAttentionTorchairBackend(AttentionBackend):
@@ -145,43 +148,29 @@ class AscendTorchairMetadata:


 class AscendAttentionTorchairMetadataBuilder:

-    def __init__(self, runner):
-        self.runner = runner
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 device: torch.device,):
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.device = device

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False

     def _get_graph_runner_block_tables(
             self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
-
-        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
-        assert max_batch_size >= num_seqs, f"max_batch_size: {max_batch_size} should be bigger than cur_num_seqs: {num_seqs}"
-
-        if isinstance(self.runner.graph_block_tables, np.ndarray):
-            graph_block_tables = torch.zeros((max_batch_size, max_blocks),
-                                             dtype=block_tables.dtype,
-                                             device=block_tables.device)
-        else:
-            graph_block_tables = self.runner.graph_block_tables.to(
-                device=block_tables.device, dtype=block_tables.dtype)
-
         num_blocks = block_tables.size(1)
-        if num_blocks <= max_blocks:
-            graph_block_tables[:num_seqs, :
-                               num_blocks] = block_tables[:num_seqs, :
-                                                          num_blocks]
+        if num_blocks <= self.max_blocks:
+            return block_tables[:num_seqs, :num_blocks]
         else:
-            graph_block_tables[:num_seqs, :
-                               max_blocks] = block_tables[:num_seqs, :
-                                                          max_blocks]
-
-        return graph_block_tables[:num_seqs, :max_blocks]
+            return block_tables[:num_seqs, :self.max_blocks]

     def build_torchair_graph_dummy(
-            self, num_reqs: int,
-            num_actual_tokens: int) -> AscendTorchairMetadata:
-        device = self.runner.device
+            self, common_attn_metadata: AscendCommonAttentionMetadata) -> AscendTorchairMetadata:
+        device = self.device
+        num_reqs = common_attn_metadata.num_reqs
         _, max_blocks = self.runner.graph_block_tables.shape
         block_table = torch.zeros((num_reqs, max_blocks),
                                   dtype=torch.int32,
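With this change the builder no longer holds a handle to the model runner. A minimal construction sketch, assuming the caller already owns the engine-level VllmConfig and has resolved the target NPU device; the helper name and the "npu:0" device string are illustrative, and AscendAttentionTorchairMetadataBuilder is assumed to be imported from the module shown in this diff:

# Sketch, not part of this diff: constructing the refactored builder.
import torch
from vllm.config import VllmConfig


def make_torchair_metadata_builder(vllm_config: VllmConfig):
    # Previously the runner passed itself in; now only the engine config and
    # the target device are needed ("npu:0" is a placeholder here).
    device = torch.device("npu:0")
    return AscendAttentionTorchairMetadataBuilder(vllm_config=vllm_config,
                                                  device=device)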
@@ -208,7 +197,7 @@ def build_torchair_graph_dummy(
             max_seq_lens=1)

         attn_metadata = AscendTorchairMetadata(
-            num_actual_tokens=num_actual_tokens,
+            num_actual_tokens=common_attn_metadata.num_actual_tokens,
             block_tables=block_table,
             query_lens=0,
             query_start_loc=query_start_loc,
@@ -219,46 +208,43 @@ def build(self,
         return attn_metadata

     def build(self,
-              num_reqs,
-              num_actual_tokens,
-              max_query_len,
-              graph_pad_size: int = -1,
-              enable_dbo_across_dp: bool = False,
-              *args,
-              **kwargs):
-
-        device = self.runner.device
-
-        block_table = self.runner.input_batch.block_table[0].get_device_tensor(
-        )
-        block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+              common_attn_metadata: AscendCommonAttentionMetadata,
+              model: nn.Module,):
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+
+        block_table = common_attn_metadata.block_table_tensor
+        block_table[:num_reqs, :common_attn_metadata.max_num_blocks_per_req] = (
             block_table[:num_reqs])

-        query_lens = self.runner.query_lens
-        seq_lens = self.runner.seq_lens_cpu[:num_reqs]
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            self.runner.device, non_blocking=True)
-        attn_mask = self.runner.attn_mask
+        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
+        slot_mapping = common_attn_metadata.slot_mapping_cpu[:num_actual_tokens].to(
+            self.device, non_blocking=True)
+        attn_mask = common_attn_metadata.attn_mask

-        attn_state = self.runner.attn_state
+        attn_state = common_attn_metadata.attn_state
         if is_310p() and attn_state == AscendAttentionState.PrefillNoCache:
             mask_nz = nd_to_nz_2d(attn_mask)
             attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29)

-        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
-        query_start_loc = query_start_loc_cpu.to(self.runner.device,
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
+        query_start_loc = query_start_loc_cpu.to(self.device,
                                                  non_blocking=True)
-        input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
-            device, non_blocking=True).long()
+        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        # input_positions = common_attn_metadata.positions_cpu[:num_actual_tokens].to(
+        #     device, non_blocking=True).long()
+
+        input_positions = common_attn_metadata.positions[:num_actual_tokens].long()

         decode_metadata = None
+        graph_pad_size = common_attn_metadata.graph_pad_size
         use_torchair_graph = graph_pad_size > -1
-        if self.runner.attn_state in [
+        if common_attn_metadata.attn_state in [
                 AscendAttentionState.DecodeOnly,
         ]:
             max_seq_lens = seq_lens.max().item()
             num_seqs = len(seq_lens)
-            if use_torchair_graph and self.runner.attn_state in [
+            if use_torchair_graph and common_attn_metadata.attn_state in [
                     AscendAttentionState.DecodeOnly,
             ]:
                 num_reqs_pad_size = 0
@@ -267,7 +253,7 @@ def build(self,
                     pad_value = 0
                     num_token_pad_size = graph_pad_size - num_actual_tokens
                     num_reqs_pad_size = (
-                        graph_pad_size // self.runner.decode_token_per_req -
+                        graph_pad_size // common_attn_metadata.decode_token_per_req -
                         num_reqs)
                 pad_value = 1
                 padded_seq_lens = seq_lens.tolist() + [pad_value
@@ -308,11 +294,11 @@ def build(self,
             query_start_loc=query_start_loc,
             query_lens=query_lens,
             seq_lens=seq_lens,
-            max_query_len=max_query_len,
+            max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp)
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
         return attn_metadata


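A hedged usage sketch for the new build() entry point: every per-step value the builder used to read off self.runner now arrives bundled in AscendCommonAttentionMetadata. The wrapper function and its parameter list below are hypothetical glue; the field names mirror exactly what build() reads in this diff, and the assumption that AscendCommonAttentionMetadata accepts them as keyword arguments is not verified here.

# Sketch, not part of this diff: a hypothetical caller packing runner state
# into AscendCommonAttentionMetadata for the refactored build().
import torch
import torch.nn as nn

from vllm_ascend.attention.utils import AscendCommonAttentionMetadata


def build_torchair_attn_metadata(builder, model: nn.Module, *,
                                 num_reqs: int, num_actual_tokens: int,
                                 block_table_tensor: torch.Tensor,
                                 max_num_blocks_per_req: int,
                                 seq_lens_cpu: torch.Tensor,
                                 slot_mapping_cpu: torch.Tensor,
                                 query_start_loc_cpu: torch.Tensor,
                                 positions: torch.Tensor,
                                 attn_mask: torch.Tensor,
                                 attn_state,
                                 decode_token_per_req: int,
                                 max_query_len: int,
                                 graph_pad_size: int = -1,
                                 enable_dbo_across_dp: bool = False):
    # Bundle everything the builder previously pulled from the runner.
    common_attn_metadata = AscendCommonAttentionMetadata(
        num_reqs=num_reqs,
        num_actual_tokens=num_actual_tokens,
        block_table_tensor=block_table_tensor,
        max_num_blocks_per_req=max_num_blocks_per_req,
        seq_lens_cpu=seq_lens_cpu,
        slot_mapping_cpu=slot_mapping_cpu,
        query_start_loc_cpu=query_start_loc_cpu,
        positions=positions,
        attn_mask=attn_mask,
        attn_state=attn_state,
        graph_pad_size=graph_pad_size,  # -1 keeps the non-graph (eager) path
        decode_token_per_req=decode_token_per_req,
        max_query_len=max_query_len,
        enable_dbo_across_dp=enable_dbo_across_dp,
    )
    # build() now takes the bundled metadata plus the model module directly.
    return builder.build(common_attn_metadata, model)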