 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                set_graph_params,
-                                               update_attn_params)
+                                               update_attn_params,
+                                               update_mla_attn_params)
 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
 from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
     D2DExpertWeightLoader
@@ -351,6 +352,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                 self.speculative_config.method, self.vllm_config,
                 self.device, self)
             self.rejection_sampler = AscendRejectionSampler()
+        self.actual_seq_lengths_q = list(
+            range(self.decode_token_per_req, self.max_num_tokens + 1,
+                  self.decode_token_per_req))
 
         # Persistent batch.
         self.input_ids = torch.zeros(self.max_num_tokens,
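Reviewer note: `actual_seq_lengths_q` precomputes the padded per-request query end-offsets used for full-graph capture, where every request contributes exactly `decode_token_per_req` tokens (one decode token, plus draft tokens under speculative decoding). A minimal sketch with made-up sizes:

```python
# Illustration only; sizes are hypothetical, not from this PR.
decode_token_per_req = 2  # e.g. one decode token + one speculative token
max_num_tokens = 8

actual_seq_lengths_q = list(
    range(decode_token_per_req, max_num_tokens + 1, decode_token_per_req))
assert actual_seq_lengths_q == [2, 4, 6, 8]
# Prepending 0 turns these end-offsets into a query_start_loc tensor for a
# fully padded batch; _build_attention_metadata below does exactly that.
```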
@@ -369,6 +373,25 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                      dtype=torch.int32,
                                      device=self.device)
 
+        if self.vllm_config.model_config.use_mla and \
+            self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+            rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
+            self.cos = torch.ones(self.max_num_reqs,
+                                  1,
+                                  1,
+                                  rope_dim,
+                                  dtype=self.dtype,
+                                  device=self.device)
+            self.sin = torch.zeros(self.max_num_reqs,
+                                   1,
+                                   1,
+                                   rope_dim,
+                                   dtype=self.dtype,
+                                   device=self.device)
+        else:
+            self.cos = None
+            self.sin = None
+
         self.uses_mrope = self.model_config.uses_mrope
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
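These buffers are allocated once so that, under MLA with `FULL_DECODE_ONLY` graph mode, the captured graph always reads rotary-embedding values from fixed device addresses; each step only refreshes them in place. A sketch of that in-place refresh pattern, with a hypothetical helper that is not part of this PR:

```python
import torch

max_num_reqs, rope_dim = 4, 64
cos = torch.ones(max_num_reqs, 1, 1, rope_dim)   # persistent buffer
sin = torch.zeros(max_num_reqs, 1, 1, rope_dim)  # persistent buffer

def refresh_rope_cache(positions: torch.Tensor, base: float = 10000.0):
    # Recompute rotary values for the new positions and copy them into the
    # pre-allocated buffers, so the tensor addresses seen by the captured
    # graph never change.
    inv_freq = 1.0 / (base**(
        torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim))
    freqs = positions.to(torch.float32)[:, None] * inv_freq[None, :]
    emb = torch.cat((freqs, freqs), dim=-1)
    cos.copy_(emb.cos().view(max_num_reqs, 1, 1, rope_dim))
    sin.copy_(emb.sin().view(max_num_reqs, 1, 1, rope_dim))

refresh_rope_cache(torch.arange(max_num_reqs))
```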
@@ -1518,6 +1541,8 @@ def _prepare_inputs(
             max_query_len=max_num_scheduled_tokens,
             graph_pad_size=self.graph_pad_size,
             decode_token_per_req=self.decode_token_per_req,
+            cos=self.cos,
+            sin=self.sin,
         )
 
         if self.speculative_config and \
@@ -1547,7 +1572,7 @@ def _prepare_inputs(
                 attn_metadata_i = builder.build(
                     common_prefix_len=common_prefix_len,
                     common_attn_metadata=common_attn_metadata,
-                    model=self.model,
+                    model=self.get_model(),
                     **extra_attn_metadata_args)
 
         if self.vllm_config.model_config.use_mla:
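On `model=self.get_model()`: once graph capture is enabled, the `model` attribute can hold a wrapper such as `ACLGraphWrapper`, while the metadata builder wants the underlying module, so the accessor is the safer choice. A hedged sketch of that pattern, with class and attribute names invented for illustration:

```python
class GraphWrapper:
    """Stand-in for a capture/replay wrapper such as ACLGraphWrapper."""

    def __init__(self, runnable):
        self.runnable = runnable  # the real nn.Module

class Runner:
    def __init__(self, model, wrap: bool):
        # After capture is enabled, self.model is no longer the bare module.
        self.model = GraphWrapper(model) if wrap else model

    def get_model(self):
        # Always hand callers the underlying module, wrapped or not.
        if isinstance(self.model, GraphWrapper):
            return self.model.runnable
        return self.model
```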
@@ -1582,8 +1607,14 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
 
         forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-            update_attn_params(self.update_stream, forward_context,
-                               positions.shape[0])
+            if self.vllm_config.model_config.use_mla:
+                # FIXME: Try using `auto_dispatch_capture=True`
+                update_mla_attn_params(self.update_stream, forward_context,
+                                       positions.shape[0],
+                                       self.speculative_config)
+            else:
+                update_attn_params(self.update_stream, forward_context,
+                                   positions.shape[0])
 
         if get_forward_context().sp_enabled:
             hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
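Both helpers patch the attention parameters of an already-captured graph in place before replay; the MLA variant additionally takes `speculative_config` so it can account for draft tokens per request. A rough sketch of the side-stream update pattern this relies on (assumed shape, not the real helper; `torch.npu` mirrors the `torch.cuda` stream API on Ascend):

```python
import torch

def update_params_sketch(update_stream: "torch.cuda.Stream",
                         seq_lens_buf: torch.Tensor,
                         num_tokens: int) -> None:
    # Mutate the persistent tensors the captured graph reads from on a side
    # stream, then make the default stream wait so replay observes the new
    # values.
    with torch.cuda.stream(update_stream):
        seq_lens_buf[:num_tokens] += 1  # e.g. advance per-request positions
    torch.cuda.current_stream().wait_stream(update_stream)
```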
@@ -2285,7 +2316,10 @@ def _build_attention_metadata(self, create_mixed_batch, num_reqs,
             block_table_tensor = self.input_batch.block_table[
                 kv_cache_group_id].get_device_tensor()
             common_attn_metadata = AscendCommonAttentionMetadata(
-                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc=torch.tensor(
+                    [0] + self.actual_seq_lengths_q[:num_reqs],
+                    device=self.device,
+                    dtype=torch.int32),
                 query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
                                                              1],
                 seq_lens_cpu=self.seq_lens_cpu,
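During dummy-run capture every request is padded to exactly `decode_token_per_req` tokens, so the live `self.query_start_loc` (which reflects actually scheduled tokens) is swapped for offsets derived from the precomputed `actual_seq_lengths_q`. Continuing the hypothetical sizes from the earlier note:

```python
import torch

actual_seq_lengths_q = [2, 4, 6, 8]  # decode_token_per_req = 2
num_reqs = 3

query_start_loc = torch.tensor([0] + actual_seq_lengths_q[:num_reqs],
                               dtype=torch.int32)
# tensor([0, 2, 4, 6], dtype=torch.int32): request i owns the token slice
# query_start_loc[i]:query_start_loc[i + 1], exactly 2 tokens each.
```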
@@ -2296,17 +2330,27 @@ def _build_attention_metadata(self, create_mixed_batch, num_reqs,
                 block_table_tensor=block_table_tensor[:num_reqs],
                 slot_mapping=self.slot_mapping,
                 num_computed_tokens_cpu=num_computed_tokens_cpu,
+                positions=self.positions,
+                attn_mask=self.attn_mask,
+                spec_attn_mask=self.spec_attn_mask,
+                attn_state=self.attn_state,
                 max_query_len=max_query_len,
                 decode_token_per_req=self.decode_token_per_req,
+                cos=self.cos,
+                sin=self.sin,
             )
+            attn_state = AscendAttentionState.DecodeOnly
+            if self.speculative_config and \
+                    self.speculative_config.method == "deepseek_mtp":
+                attn_state = AscendAttentionState.SpecDecoding
 
             for attn_group in self.attn_groups[kv_cache_group_id]:
                 if vllm_version_is("0.10.2"):
                     builder = attn_group.metadata_builder
                 else:
                     builder = attn_group.get_metadata_builder()
                 attn_metadata_i = builder.build_for_graph_capture(
-                    common_attn_metadata)
+                    common_attn_metadata, attn_state, self.get_model())
                 for layer_name in kv_cache_group_spec.layer_names:
                     attn_metadata[layer_name] = attn_metadata_i
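Capture always assumes a pure-decode batch, so the capture-time state is forced to `DecodeOnly` and widened to `SpecDecoding` only when the DeepSeek MTP proposer adds draft tokens; the runtime `self.attn_state` is not what the captured graph will see. A toy mirror of that selection (the enum stands in for `AscendAttentionState`):

```python
from enum import Enum, auto
from typing import Optional

class AttnState(Enum):
    # Stand-ins for the AscendAttentionState members used at capture time.
    DecodeOnly = auto()
    SpecDecoding = auto()

def capture_attn_state(speculative_method: Optional[str]) -> AttnState:
    # Mirrors the selection above: pure decode unless deepseek_mtp proposes
    # extra tokens per request.
    if speculative_method == "deepseek_mtp":
        return AttnState.SpecDecoding
    return AttnState.DecodeOnly

assert capture_attn_state(None) is AttnState.DecodeOnly
assert capture_attn_state("deepseek_mtp") is AttnState.SpecDecoding
```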