@@ -24,14 +24,13 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
-from vllm.config import get_current_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.gpu_input_batch import InputBatch

-from vllm_ascend.attention.utils import \
-    AscendCommonAttentionMetadata as CommonAttentionMetadata
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import get_graph_params
@@ -156,39 +155,49 @@ def split_metadata_for_multistream(

 class AscendAttentionMetadataBuilder:

-    def __init__(self, runner):
+    def __init__(self, vllm_config: VllmConfig, device: torch.device, runner):
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.device = device
         self.runner = runner

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False

-    def build(self,
-              num_reqs,
-              num_actual_tokens,
-              max_query_len,
-              common_attn_metadata: CommonAttentionMetadata,
-              enable_dbo_across_dp: bool = False,
-              is_only_prefill: bool = False,
-              *args,
-              **kwargs):
-
-        block_table = self.runner.input_batch.block_table[0].get_device_tensor(
-        )
-        block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
-            block_table[:num_reqs])
-
-        query_start_loc = common_attn_metadata.query_start_loc
-        seq_lens = common_attn_metadata.seq_lens
+    def build(
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+    ):
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
+                                                                       num_reqs
+                                                                       + 1]
+
+        block_table = common_attn_metadata.block_table_tensor
+        block_table[:num_reqs, :common_attn_metadata.
+                    max_num_blocks_per_req] = (block_table[:num_reqs])
+
+        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
         # TODO: Refactor these two param to common metadata in runners,
         # preparing for the hybrid KV groups feature
-        query_lens = common_attn_metadata.query_lens or self.runner.query_lens
+        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         # Since FIA for GQA is not active now, we temporarily silence it
         seq_lens_list = common_attn_metadata.seq_lens_list

-        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
-        attn_mask = self.runner.attn_mask
-        attn_state = self.runner.attn_state
+        slot_mapping = common_attn_metadata.slot_mapping_cpu[:
+                                                             num_actual_tokens].to(
+                                                                 self.device,
+                                                                 non_blocking=
+                                                                 True)
+        attn_mask = common_attn_metadata.attn_mask
+        attn_state = common_attn_metadata.attn_state
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
+                                                                       num_reqs
+                                                                       + 1]
+        query_start_loc = query_start_loc_cpu.to(self.device,
+                                                 non_blocking=True)

         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
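The behavioral change in this hunk is that `build()` no longer pulls `query_lens` from the runner: it derives them from the CPU-side `query_start_loc` prefix sums, and stages the slot mapping onto the device with a non-blocking copy. Below is a minimal sketch of that derivation in plain PyTorch, using a CPU device as a stand-in for the NPU device the builder actually holds; the tensor values are made-up examples, not from the PR.

```python
import torch

device = torch.device("cpu")  # stand-in; the real builder targets an NPU device

# query_start_loc_cpu holds prefix sums of scheduled tokens per request:
# request i owns tokens [query_start_loc_cpu[i], query_start_loc_cpu[i + 1]).
num_reqs = 3
query_start_loc_cpu = torch.tensor([0, 4, 6, 9], dtype=torch.int32)[:num_reqs + 1]

# Adjacent differences of the prefix sums recover per-request query lengths.
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
print(query_lens.tolist())  # [4, 2, 3]

# Stage the slot mapping on the device; non_blocking=True lets the copy
# overlap with host work (a true overlap only when the source is pinned
# host memory and the target is a real device).
num_actual_tokens = 9
slot_mapping_cpu = torch.arange(16, dtype=torch.int64)
slot_mapping = slot_mapping_cpu[:num_actual_tokens].to(device, non_blocking=True)
query_start_loc = query_start_loc_cpu.to(device, non_blocking=True)
```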
@@ -197,34 +206,50 @@ def build(self,
             query_lens=query_lens,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens_list,
-            max_query_len=max_query_len,
+            max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp,
-            is_only_prefill=is_only_prefill)
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
+            is_only_prefill=common_attn_metadata.is_only_prefill)
         return attn_metadata

     def build_dummy_metadata(self, num_actual_tokens, num_reqs,
                              num_scheduled_tokens, attn_state):
         if attn_state == AscendAttentionState.DecodeOnly:
             # NOTE: We only need to pay attention to seq_lens_list and block_table here
-            common_attn_metadata = CommonAttentionMetadata(
-                seq_lens=torch.empty_like(self.runner.seq_lens_cpu).fill_(2))
-
             block_table = self.runner.input_batch.block_table[0].block_table
             block_table[:num_reqs, 0] = torch.arange(1,
                                                      num_reqs + 1,
                                                      device=block_table.device,
                                                      dtype=block_table.dtype)
+            block_table = self.runner.input_batch.block_table[
+                0].get_device_tensor()
+            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+                block_table[:num_reqs])

-            attn_metadata = self.build(
-                num_reqs=num_reqs,
+            query_start_loc = None
+            seq_lens = torch.empty_like(self.runner.seq_lens_cpu).fill_(2)
+            query_lens = self.runner.query_lens
+            seq_lens_list = None
+
+            slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
+            attn_mask = self.runner.attn_mask
+            attn_state = self.runner.attn_state
+
+            attn_metadata = AscendMetadata(
                 num_actual_tokens=num_actual_tokens,
+                block_tables=block_table,
+                query_start_loc=query_start_loc,
+                query_lens=query_lens,
+                seq_lens=seq_lens,
+                seq_lens_list=seq_lens_list,
                 max_query_len=num_scheduled_tokens.max(),
-                common_prefix_len=0,
-                common_attn_metadata=common_attn_metadata,
-            )
+                slot_mapping=slot_mapping,
+                attn_mask=attn_mask,
+                attn_state=attn_state,
+                enable_dbo_across_dp=False,
+                is_only_prefill=False)
         else:
             raise NotImplementedError(
                 "Currently we only support building dummy metadata for DecodeOnly state"