21
21
22
22
import torch
23
23
import torch_npu
24
+ import torch .nn as nn
25
+ from vllm .config import VllmConfig
24
26
from vllm .attention .backends .abstract import (AttentionBackend , AttentionImpl ,
25
27
AttentionLayer , AttentionType )
26
28
from vllm .attention .backends .utils import CommonAttentionState
30
32
from vllm .v1 .core .sched .output import SchedulerOutput
31
33
from vllm .v1 .worker .gpu_input_batch import InputBatch
32
34
33
- from vllm_ascend .attention .utils import \
34
- AscendCommonAttentionMetadata as CommonAttentionMetadata
35
35
from vllm_ascend .multistream .base import MSAttentionMetadataSplitConfig
36
36
from vllm_ascend .ops .attention import vanilla_chunked_prefill
37
37
from vllm_ascend .utils import get_graph_params
38
+ from vllm_ascend .attention .utils import AscendCommonAttentionMetadata
38
39
39
40
40
41
class AscendAttentionBackend (AttentionBackend ):
@@ -156,39 +157,48 @@ def split_metadata_for_multistream(
156
157
157
158
class AscendAttentionMetadataBuilder :
158
159
159
- def __init__ (self , runner ):
160
+ def __init__ (
161
+ self ,
162
+ vllm_config : VllmConfig ,
163
+ device : torch .device ,
164
+ runner
165
+ ):
166
+ self .vllm_config = vllm_config
167
+ self .model_config = vllm_config .model_config
168
+ self .device = device
160
169
self .runner = runner
161
170
162
171
def reorder_batch (self , input_batch : "InputBatch" ,
163
172
scheduler_output : "SchedulerOutput" ) -> bool :
164
173
return False
165
174
166
- def build (self ,
167
- num_reqs ,
168
- num_actual_tokens ,
169
- max_query_len ,
170
- common_attn_metadata : CommonAttentionMetadata ,
171
- enable_dbo_across_dp : bool = False ,
172
- is_only_prefill : bool = False ,
173
- * args ,
174
- ** kwargs ):
175
-
176
- block_table = self .runner .input_batch .block_table [0 ].get_device_tensor (
177
- )
178
- block_table [:num_reqs , :self .runner .max_num_blocks_per_req ] = (
179
- block_table [:num_reqs ])
180
-
181
- query_start_loc = common_attn_metadata .query_start_loc
182
- seq_lens = common_attn_metadata .seq_lens
175
+ def build (
176
+ self ,
177
+ common_attn_metadata : AscendCommonAttentionMetadata ,
178
+ ):
179
+ num_reqs = common_attn_metadata .num_reqs
180
+ num_actual_tokens = common_attn_metadata .num_actual_tokens
181
+ query_start_loc_cpu = common_attn_metadata .query_start_loc_cpu [:
182
+ num_reqs
183
+ + 1 ]
184
+
185
+ block_table = common_attn_metadata .block_table_tensor
186
+ block_table [:num_reqs , :common_attn_metadata .
187
+ max_num_blocks_per_req ] = (block_table [:num_reqs ])
188
+
189
+ seq_lens = common_attn_metadata .seq_lens_cpu [:num_reqs ]
183
190
# TODO: Refactor these two param to common metadata in runners,
184
191
# preparing for the hybrid KV groups feature
185
- query_lens = common_attn_metadata . query_lens or self . runner . query_lens
192
+ query_lens = query_start_loc_cpu [ 1 :] - query_start_loc_cpu [: - 1 ]
186
193
# Since FIA for GQA is not active now, we temporarily silence it
187
194
seq_lens_list = common_attn_metadata .seq_lens_list
188
195
189
- slot_mapping = self .runner .slot_mapping [:num_actual_tokens ]
190
- attn_mask = self .runner .attn_mask
191
- attn_state = self .runner .attn_state
196
+ slot_mapping = common_attn_metadata .slot_mapping_cpu [:num_actual_tokens ].to (
197
+ self .device , non_blocking = True )
198
+ attn_mask = common_attn_metadata .attn_mask
199
+ attn_state = common_attn_metadata .attn_state
200
+ query_start_loc_cpu = common_attn_metadata .query_start_loc_cpu [:num_reqs + 1 ]
201
+ query_start_loc = query_start_loc_cpu .to (self .device ,non_blocking = True )
192
202
193
203
attn_metadata = AscendMetadata (
194
204
num_actual_tokens = num_actual_tokens ,
@@ -197,34 +207,53 @@ def build(self,
197
207
query_lens = query_lens ,
198
208
seq_lens = seq_lens ,
199
209
seq_lens_list = seq_lens_list ,
200
- max_query_len = max_query_len ,
210
+ max_query_len = common_attn_metadata . max_query_len ,
201
211
slot_mapping = slot_mapping ,
202
212
attn_mask = attn_mask ,
203
213
attn_state = attn_state ,
204
- enable_dbo_across_dp = enable_dbo_across_dp ,
205
- is_only_prefill = is_only_prefill )
214
+ enable_dbo_across_dp = common_attn_metadata . enable_dbo_across_dp ,
215
+ is_only_prefill = common_attn_metadata . is_only_prefill )
206
216
return attn_metadata
207
217
208
218
def build_dummy_metadata (self , num_actual_tokens , num_reqs ,
209
219
num_scheduled_tokens , attn_state ):
210
220
if attn_state == AscendAttentionState .DecodeOnly :
211
221
# NOTE: We only need to pay attention to seq_lens_list and block_table here
212
- common_attn_metadata = CommonAttentionMetadata (
222
+ common_attn_metadata = AscendCommonAttentionMetadata (
213
223
seq_lens = torch .empty_like (self .runner .seq_lens_cpu ).fill_ (2 ))
214
224
215
225
block_table = self .runner .input_batch .block_table [0 ].block_table
216
226
block_table [:num_reqs , 0 ] = torch .arange (1 ,
217
227
num_reqs + 1 ,
218
228
device = block_table .device ,
219
229
dtype = block_table .dtype )
230
+ block_table = self .runner .input_batch .block_table [0 ].get_device_tensor (
231
+ )
232
+ block_table [:num_reqs , :self .runner .max_num_blocks_per_req ] = (
233
+ block_table [:num_reqs ])
234
+
235
+ query_start_loc = common_attn_metadata .query_start_loc
236
+ seq_lens = common_attn_metadata .seq_lens
237
+ query_lens = self .runner .query_lens
238
+ seq_lens_list = None
220
239
221
- attn_metadata = self .build (
222
- num_reqs = num_reqs ,
240
+ slot_mapping = self .runner .slot_mapping [:num_actual_tokens ]
241
+ attn_mask = self .runner .attn_mask
242
+ attn_state = self .runner .attn_state
243
+
244
+ attn_metadata = AscendMetadata (
223
245
num_actual_tokens = num_actual_tokens ,
246
+ block_tables = block_table ,
247
+ query_start_loc = query_start_loc ,
248
+ query_lens = query_lens ,
249
+ seq_lens = seq_lens ,
250
+ seq_lens_list = seq_lens_list ,
224
251
max_query_len = num_scheduled_tokens .max (),
225
- common_prefix_len = 0 ,
226
- common_attn_metadata = common_attn_metadata ,
227
- )
252
+ slot_mapping = slot_mapping ,
253
+ attn_mask = attn_mask ,
254
+ attn_state = attn_state ,
255
+ enable_dbo_across_dp = False ,
256
+ is_only_prefill = False )
228
257
else :
229
258
raise NotImplementedError (
230
259
"Currently we only support building dummy metadata for DecodeOnly state"
0 commit comments