 import torch
 import torch_npu
+import torch.nn as nn
+from vllm.config import VllmConfig
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
@@ ... @@
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import get_graph_params
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata


 class AscendAttentionBackend(AttentionBackend):
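Note (illustration, not part of the diff): the newly imported AscendCommonAttentionMetadata is the single argument the rewritten build() consumes below. A minimal sketch of such a container, inferred only from the fields this diff accesses and not taken from the actual vllm_ascend.attention.utils definition, might look like:

    # Hypothetical sketch; field names are inferred from the accesses in the
    # hunks below, the real AscendCommonAttentionMetadata may differ.
    from dataclasses import dataclass
    from typing import Any, Optional

    import torch

    @dataclass
    class CommonAttentionMetadataSketch:
        num_reqs: int                       # requests in the current batch
        num_actual_tokens: int              # total scheduled tokens
        max_query_len: int
        query_start_loc_cpu: torch.Tensor   # prefix sums of per-request query lengths
        seq_lens_cpu: torch.Tensor          # per-request sequence lengths (host copy)
        slot_mapping_cpu: torch.Tensor      # token -> KV-cache slot indices (host copy)
        block_table_tensor: torch.Tensor    # per-request block tables (device)
        max_num_blocks_per_req: int
        seq_lens_list: Optional[list] = None
        attn_mask: Optional[torch.Tensor] = None
        attn_state: Any = None              # e.g. an AscendAttentionState value
        enable_dbo_across_dp: bool = False
        is_only_prefill: bool = False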
@@ -156,39 +159,48 @@ def split_metadata_for_multistream(

 class AscendAttentionMetadataBuilder:

-    def __init__(self, runner):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        device: torch.device,
+        runner
+    ):
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.device = device
         self.runner = runner

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False

-    def build(self,
-              num_reqs,
-              num_actual_tokens,
-              max_query_len,
-              common_attn_metadata: CommonAttentionMetadata,
-              enable_dbo_across_dp: bool = False,
-              is_only_prefill: bool = False,
-              *args,
-              **kwargs):
-
-        block_table = self.runner.input_batch.block_table[0].get_device_tensor(
-        )
-        block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
-            block_table[:num_reqs])
-
-        query_start_loc = common_attn_metadata.query_start_loc
-        seq_lens = common_attn_metadata.seq_lens
+    def build(
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+    ):
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
+                                                                       num_reqs
+                                                                       + 1]
+
+        block_table = common_attn_metadata.block_table_tensor
+        block_table[:num_reqs, :common_attn_metadata.
+                    max_num_blocks_per_req] = (block_table[:num_reqs])
+
+        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
         # TODO: Refactor these two param to common metadata in runners,
         # preparing for the hybrid KV groups feature
-        query_lens = common_attn_metadata.query_lens or self.runner.query_lens
+        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         # Since FIA for GQA is not active now, we temporarily silence it
         seq_lens_list = common_attn_metadata.seq_lens_list

-        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
-        attn_mask = self.runner.attn_mask
-        attn_state = self.runner.attn_state
+        slot_mapping = common_attn_metadata.slot_mapping_cpu[:num_actual_tokens].to(
+            self.device, non_blocking=True)
+        attn_mask = common_attn_metadata.attn_mask
+        attn_state = common_attn_metadata.attn_state
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
+        query_start_loc = query_start_loc_cpu.to(self.device, non_blocking=True)

         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -197,12 +209,12 @@ def build(self,
             query_lens=query_lens,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens_list,
-            max_query_len=max_query_len,
+            max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp,
-            is_only_prefill=is_only_prefill)
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
+            is_only_prefill=common_attn_metadata.is_only_prefill)
         return attn_metadata

     def build_dummy_metadata(self, num_actual_tokens, num_reqs,
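Note (illustration, not part of the diff): with this change build() derives query_lens from the CPU prefix sums instead of reading self.runner.query_lens. A small self-contained example of that arithmetic, with made-up numbers:

    import torch

    # query_start_loc_cpu holds prefix sums of per-request query lengths:
    # request 0 owns tokens [0, 5), request 1 owns [5, 9), request 2 owns [9, 12).
    query_start_loc_cpu = torch.tensor([0, 5, 9, 12], dtype=torch.int32)

    # Same formula as in the new build(): pairwise differences of the prefix sums.
    query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    print(query_lens.tolist())                        # -> [5, 4, 3]
    num_actual_tokens = int(query_start_loc_cpu[-1])  # 12 tokens scheduled in total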
@@ -217,14 +229,33 @@ def build_dummy_metadata(self, num_actual_tokens, num_reqs,
                 num_reqs + 1,
                 device=block_table.device,
                 dtype=block_table.dtype)
+            block_table = self.runner.input_batch.block_table[0].get_device_tensor(
+            )
+            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+                block_table[:num_reqs])
+
+            query_start_loc = common_attn_metadata.query_start_loc
+            seq_lens = common_attn_metadata.seq_lens
+            query_lens = self.runner.query_lens
+            seq_lens_list = None

-            attn_metadata = self.build(
-                num_reqs=num_reqs,
+            slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
+            attn_mask = self.runner.attn_mask
+            attn_state = self.runner.attn_state
+
+            attn_metadata = AscendMetadata(
                 num_actual_tokens=num_actual_tokens,
+                block_tables=block_table,
+                query_start_loc=query_start_loc,
+                query_lens=query_lens,
+                seq_lens=seq_lens,
+                seq_lens_list=seq_lens_list,
                 max_query_len=num_scheduled_tokens.max(),
-                common_prefix_len=0,
-                common_attn_metadata=common_attn_metadata,
-            )
+                slot_mapping=slot_mapping,
+                attn_mask=attn_mask,
+                attn_state=attn_state,
+                enable_dbo_across_dp=False,
+                is_only_prefill=False)
         else:
             raise NotImplementedError(
                 "Currently we only support building dummy metadata for DecodeOnly state"