 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                set_graph_params,
-                                               update_attn_params)
+                                               update_attn_params,
+                                               update_mla_attn_params)
 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
 from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
     D2DExpertWeightLoader
@@ -351,6 +352,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                 self.speculative_config.method, self.vllm_config,
                 self.device, self)
             self.rejection_sampler = AscendRejectionSampler()
+        self.actual_seq_lengths_q = list(
+            range(self.decode_token_per_req, self.max_num_tokens + 1, self.decode_token_per_req))
 
         # Persistent batch.
         self.input_ids = torch.zeros(self.max_num_tokens,
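Note: `actual_seq_lengths_q` precomputes the cumulative (padded) query end offsets for a pure-decode batch, stepping by `decode_token_per_req`. A minimal sketch of the values it holds, with `decode_token_per_req` and `max_num_tokens` assumed for illustration only:

# Illustrative values, not taken from the patch.
decode_token_per_req = 2   # e.g. one verified token plus one speculative token
max_num_tokens = 8
actual_seq_lengths_q = list(
    range(decode_token_per_req, max_num_tokens + 1, decode_token_per_req))
# -> [2, 4, 6, 8]: query end offset of each padded request in the batch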
@@ -369,6 +372,25 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                      dtype=torch.int32,
                                      device=self.device)
 
+        if self.vllm_config.model_config.use_mla and \
+                self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+            rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
+            self.cos = torch.ones(self.max_num_reqs,
+                                  1,
+                                  1,
+                                  rope_dim,
+                                  dtype=self.dtype,
+                                  device=self.device)
+            self.sin = torch.zeros(self.max_num_reqs,
+                                   1,
+                                   1,
+                                   rope_dim,
+                                   dtype=self.dtype,
+                                   device=self.device)
+        else:
+            self.cos = None
+            self.sin = None
+
         self.uses_mrope = self.model_config.uses_mrope
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
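Note: the `cos`/`sin` buffers are only allocated for MLA models captured with FULL_DECODE_ONLY ACL graphs, presumably so per-request rotary values can be written into persistent tensors rather than reallocated between graph replays. A rough shape sketch, with `max_num_reqs` and `qk_rope_head_dim` assumed for illustration:

import torch

# Hypothetical sizes; the real values come from the scheduler and model config.
max_num_reqs, rope_dim = 4, 64
cos = torch.ones(max_num_reqs, 1, 1, rope_dim, dtype=torch.bfloat16)
sin = torch.zeros(max_num_reqs, 1, 1, rope_dim, dtype=torch.bfloat16)
print(cos.shape)  # torch.Size([4, 1, 1, 64]) -- one rope slot per request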
@@ -1518,6 +1540,8 @@ def _prepare_inputs(
             max_query_len=max_num_scheduled_tokens,
             graph_pad_size=self.graph_pad_size,
             decode_token_per_req=self.decode_token_per_req,
+            cos=self.cos,
+            sin=self.sin,
         )
 
         if self.speculative_config and \
@@ -1547,7 +1571,7 @@ def _prepare_inputs(
             attn_metadata_i = builder.build(
                 common_prefix_len=common_prefix_len,
                 common_attn_metadata=common_attn_metadata,
-                model=self.model,
+                model=self.get_model(),
                 **extra_attn_metadata_args)
 
         if self.vllm_config.model_config.use_mla:
@@ -1582,8 +1606,14 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
 
         forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-            update_attn_params(self.update_stream, forward_context,
-                               positions.shape[0])
+            if self.vllm_config.model_config.use_mla:
+                # FIXME: Try using `auto_dispatch_capture=True`
+                update_mla_attn_params(self.update_stream, forward_context,
+                                       positions.shape[0],
+                                       self.speculative_config)
+            else:
+                update_attn_params(self.update_stream, forward_context,
+                                   positions.shape[0])
 
         if get_forward_context().sp_enabled:
             hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
@@ -2285,7 +2315,10 @@ def _build_attention_metadata(self, create_mixed_batch, num_reqs,
             block_table_tensor = self.input_batch.block_table[
                 kv_cache_group_id].get_device_tensor()
             common_attn_metadata = AscendCommonAttentionMetadata(
-                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc=torch.tensor(
+                    [0] + self.actual_seq_lengths_q[:num_reqs],
+                    device=self.device,
+                    dtype=torch.int32),
                 query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
                                                              1],
                 seq_lens_cpu=self.seq_lens_cpu,
@@ -2296,17 +2329,29 @@ def _build_attention_metadata(self, create_mixed_batch, num_reqs,
                 block_table_tensor=block_table_tensor[:num_reqs],
                 slot_mapping=self.slot_mapping,
                 num_computed_tokens_cpu=num_computed_tokens_cpu,
+                positions=self.positions,
+                attn_mask=self.attn_mask,
+                spec_attn_mask=self.spec_attn_mask,
+                attn_state=self.attn_state,
                 max_query_len=max_query_len,
                 decode_token_per_req=self.decode_token_per_req,
+                cos=self.cos,
+                sin=self.sin,
             )
+            attn_state = AscendAttentionState.DecodeOnly
+            if self.speculative_config and \
+                    self.speculative_config.method == "deepseek_mtp":
+                attn_state = AscendAttentionState.SpecDecoding
 
             for attn_group in self.attn_groups[kv_cache_group_id]:
                 if vllm_version_is("0.10.2"):
                     builder = attn_group.metadata_builder
                 else:
                     builder = attn_group.get_metadata_builder()
                 attn_metadata_i = builder.build_for_graph_capture(
-                    common_attn_metadata)
+                    common_attn_metadata,
+                    attn_state,
+                    self.get_model())
                 for layer_name in kv_cache_group_spec.layer_names:
                     attn_metadata[layer_name] = attn_metadata_i
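Note: during graph capture the builder now receives a `query_start_loc` derived from the precomputed `actual_seq_lengths_q` rather than the live `query_start_loc` buffer, so every captured request is padded to `decode_token_per_req` query tokens. A small sketch of the resulting tensor, with the counts assumed for illustration:

import torch

# Assumed values for illustration only.
actual_seq_lengths_q = [2, 4, 6, 8]   # from the __init__ change above
num_reqs = 3
query_start_loc = torch.tensor([0] + actual_seq_lengths_q[:num_reqs],
                               dtype=torch.int32)
# -> tensor([0, 2, 4, 6], dtype=torch.int32)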