
Commit a0a497f

[Feat] Make full graph mode compatible with MTP

Signed-off-by: anon189Ty <[email protected]>
1 parent 83092d9 commit a0a497f

5 files changed: +176, -23 lines changed


vllm_ascend/attention/attention_v1.py

Lines changed: 1 addition & 0 deletions
@@ -237,6 +237,7 @@ def build_for_graph_capture(
         self,
         common_attn_metadata: AscendCommonAttentionMetadata,
         attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
+        model: Optional[nn.Module] = None,
     ):
         if attn_state == AscendAttentionState.DecodeOnly:
             attn_metadata = self.build(
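
The only change here is the new `model` parameter: it aligns the base builder's `build_for_graph_capture` signature with the MLA builder added below, which forwards the model to `build()`. A minimal sketch of the unified call site this enables (`capture_dummy_metadata` is a hypothetical helper, not part of the diff):

# Minimal sketch, assuming a hypothetical helper name. With the unified
# signature, the model runner can invoke any metadata builder the same way
# during graph capture; builders that don't need `model` simply ignore it.
def capture_dummy_metadata(builder, common_attn_metadata, attn_state, model):
    return builder.build_for_graph_capture(
        common_attn_metadata=common_attn_metadata,
        attn_state=attn_state,
        model=model,
    )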

vllm_ascend/attention/mla_v1.py

Lines changed: 67 additions & 17 deletions
@@ -169,7 +169,7 @@ def split_metadata_for_multistream(
 class AscendMLAMetadataBuilder:
     # Does this backend/builder support ACL Graphs for attention (default: no).
     aclgraph_support: ClassVar[AttentionCGSupport] = \
-        AttentionCGSupport.NEVER
+        AttentionCGSupport.UNIFORM_BATCH
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class

@@ -389,6 +389,8 @@ def build(

         decode_metadata = None
         if num_decodes > 0:
+            cos = common_attn_metadata.cos
+            sin = common_attn_metadata.sin
             # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
             actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
             max_seq_lens = seq_lens[:num_decodes].max().item()

@@ -397,21 +399,45 @@ def build(
             block_table = block_table[:num_decodes, ...]
             seq_lens_list = seq_lens.tolist()

-            cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
-            sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
-
-            decode_metadata = AscendMLADecodeMetadata(
-                input_positions=input_positions,
-                block_table=block_table,
-                seq_lens=seq_lens,
-                seq_lens_list=seq_lens_list,
-                max_seq_lens=max_seq_lens,
-                attn_mask=common_attn_metadata.spec_attn_mask,
-                actual_seq_lengths_q=actual_seq_lengths_q,
-                sin=sin,
-                cos=cos)
+            # TODO: After full graph mode supports MTP, this if branch needs to be deleted
+            assert self.cos_cache is not None
+            assert self.sin_cache is not None
+            if cos is None and sin is None:
+                cos = self.cos_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
+                sin = self.sin_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
+
+                decode_metadata = AscendMLADecodeMetadata(
+                    input_positions=input_positions,
+                    block_table=block_table,
+                    seq_lens=seq_lens,
+                    seq_lens_list=seq_lens_list,
+                    max_seq_lens=max_seq_lens,
+                    attn_mask=common_attn_metadata.spec_attn_mask,
+                    actual_seq_lengths_q=actual_seq_lengths_q,
+                    sin=sin,
+                    cos=cos)
+            else:
+                cos[:num_decode_tokens,
+                    ...] = self.cos_cache[input_positions].unsqueeze(
+                        1).unsqueeze(2)
+                sin[:num_decode_tokens,
+                    ...] = self.sin_cache[input_positions].unsqueeze(
+                        1).unsqueeze(2)
+
+                decode_metadata = AscendMLADecodeMetadata(
+                    input_positions=input_positions,
+                    block_table=block_table,
+                    seq_lens=seq_lens,
+                    seq_lens_list=seq_lens_list,
+                    max_seq_lens=max_seq_lens,
+                    attn_mask=common_attn_metadata.spec_attn_mask,
+                    actual_seq_lengths_q=actual_seq_lengths_q,
+                    sin=sin[:num_decode_tokens, ...],
+                    cos=cos[:num_decode_tokens, ...])

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,

@@ -431,6 +457,29 @@ def build(
             enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
         )

+    def build_for_graph_capture(
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+        attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
+        model: Optional[nn.Module] = None,
+    ):
+        if attn_state in {
+                AscendAttentionState.DecodeOnly,
+                AscendAttentionState.SpecDecoding
+        }:
+            attn_metadata = self.build(
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+                model=model,
+            )
+        else:
+            raise NotImplementedError(
+                "Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state"
+            )
+
+        attn_metadata.attn_state = attn_state
+        return attn_metadata

 class DecodeMLAPreprocessResult(NamedTuple):
     ql_nope: Optional[torch.Tensor] = None

@@ -814,7 +863,8 @@ def _forward_decode(

         if attn_metadata.attn_state in [
                 AscendAttentionState.SpecDecoding,
-                AscendAttentionState.ChunkedPrefill
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.DecodeOnly,
         ] and self.speculative_config is not None:
             # Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
             input_layout = "TND"
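
The new `else` branch is the heart of this change: when the model runner hands in pre-allocated `cos`/`sin` buffers (the full-graph decode path), the builder writes into them in place instead of materializing fresh tensors, so the addresses captured in the ACL graph stay valid across replays. A minimal standalone sketch of that idea, with illustrative shapes (`max_tokens`, `rope_dim`, `refresh_cos` are not names from the diff):

import torch

# A captured graph replays against fixed tensor addresses, so per-step
# rotary values must be copied into a persistent buffer instead of being
# bound to freshly allocated tensors.
max_tokens, rope_dim = 8, 64
cos_buf = torch.ones(max_tokens, 1, 1, rope_dim)   # persistent, graph-visible
cos_cache = torch.randn(1024, rope_dim)            # per-position lookup table

def refresh_cos(input_positions: torch.Tensor) -> torch.Tensor:
    num_decode_tokens = input_positions.numel()
    # In-place slice assignment: the buffer's storage (and hence the address
    # recorded during capture) is unchanged; only its contents are updated.
    cos_buf[:num_decode_tokens, ...] = cos_cache[input_positions].unsqueeze(
        1).unsqueeze(2)
    return cos_buf[:num_decode_tokens, ...]

refresh_cos(torch.tensor([3, 7, 11]))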

vllm_ascend/attention/utils.py

Lines changed: 4 additions & 0 deletions
@@ -63,6 +63,10 @@ class AscendCommonAttentionMetadata:

     graph_pad_size: int = -1

+    # NOTE: This is a temporary solution for rotary embedding in MLA
+    cos: torch.Tensor = None
+    sin: torch.Tensor = None
+

 def split_decodes_and_prefills(
     common_attn_metadata: AscendCommonAttentionMetadata,
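
The defaults are `None`, so these fields behave as `Optional[torch.Tensor]`: `None` tells the MLA builder to gather rotary values from its caches, while pre-allocated tensors tell it to update them in place (see the mla_v1.py hunk above). A minimal sketch of that contract (`RopeBuffers` is a hypothetical stand-in, not the real dataclass):

from dataclasses import dataclass
from typing import Optional

import torch

# Hypothetical minimal illustration of the new fields' contract:
# None  -> "gather rotary values from the cache each step"
# tensor -> "write them into this persistent, graph-visible buffer"
@dataclass
class RopeBuffers:
    cos: Optional[torch.Tensor] = None
    sin: Optional[torch.Tensor] = None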

vllm_ascend/compilation/acl_graph.py

Lines changed: 54 additions & 0 deletions
@@ -229,6 +229,60 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
         event.record(update_stream)


+def update_mla_attn_params(update_stream, forward_context, runtime_shape,
+                           speculative_config):
+    graph_params = get_graph_params()
+    # FIXME: Behold! We are using a temporary hack here to update the args
+    # for each layer's attention op in the graph.
+    for key, param, handle, event in zip(
+            forward_context.attn_metadata,
+            graph_params.attn_params[runtime_shape],
+            graph_params.handles[runtime_shape],
+            graph_params.events[runtime_shape],
+    ):
+        (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
+         spec_attn_mask, sparse_mode, scale, block_table, block_size,
+         seq_lens_list, actual_seq_lengths, workspace, attn_output,
+         softmax_lse) = param
+        seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
+        if speculative_config and speculative_config.method == "deepseek_mtp":
+            actual_seq_lengths = forward_context.attn_metadata[
+                key].decode.actual_seq_lengths_q
+            spec_multiple = speculative_config.num_speculative_tokens + 1
+            seq_lens_list = seq_lens_list + [0] * (
+                runtime_shape // spec_multiple - len(seq_lens_list))
+            actual_seq_lengths = [
+                spec_multiple * (i + 1)
+                for i in range(runtime_shape // spec_multiple)
+            ]
+        with torch.npu.stream(update_stream):
+            torch.npu.graph_task_update_begin(update_stream, handle)
+
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                k_nope,
+                k_nope,
+                query_rope=q_pe,
+                key_rope=k_pe,
+                num_heads=num_heads,
+                num_key_value_heads=num_kv_heads,
+                input_layout=input_layout,
+                atten_mask=spec_attn_mask,
+                sparse_mode=sparse_mode,
+                scale=scale,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                block_table=block_table,
+                block_size=block_size,
+                actual_seq_lengths_kv=seq_lens_list,
+                actual_seq_lengths=actual_seq_lengths,
+                workspace=workspace,
+                out=[attn_output, softmax_lse])
+            torch.npu.graph_task_update_end(update_stream)
+
+        event.record(update_stream)
+
+
 @dataclass
 class GraphParams:
     events: dict[int, list[torch.npu.ExternalEvent]]
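
The deepseek_mtp branch re-derives the per-request lengths because each decode request now occupies `num_speculative_tokens + 1` query-token slots in the captured graph. A worked example of the padding arithmetic, assuming one speculative token and `runtime_shape = 8`:

# Worked example of the MTP padding above, assuming deepseek_mtp with one
# speculative token and a captured graph of runtime_shape = 8 tokens.
num_speculative_tokens = 1
runtime_shape = 8
spec_multiple = num_speculative_tokens + 1   # 2 query tokens per request
num_slots = runtime_shape // spec_multiple   # 4 request slots in the graph

seq_lens_list = [17, 33]                     # only 2 live requests this step
# Zero-pad so the per-request KV lengths match the captured slot count.
seq_lens_list = seq_lens_list + [0] * (num_slots - len(seq_lens_list))
assert seq_lens_list == [17, 33, 0, 0]

# Cumulative query end-offsets: every slot ends spec_multiple tokens later.
actual_seq_lengths = [spec_multiple * (i + 1) for i in range(num_slots)]
assert actual_seq_lengths == [2, 4, 6, 8]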

vllm_ascend/worker/model_runner_v1.py

Lines changed: 50 additions & 6 deletions
@@ -102,7 +102,8 @@
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                set_graph_params,
-                                               update_attn_params)
+                                               update_attn_params,
+                                               update_mla_attn_params)
 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
 from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
     D2DExpertWeightLoader

@@ -351,6 +352,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                 self.speculative_config.method, self.vllm_config,
                 self.device, self)
             self.rejection_sampler = AscendRejectionSampler()
+        self.actual_seq_lengths_q = list(
+            range(self.decode_token_per_req, self.max_num_tokens + 1,
+                  self.decode_token_per_req))

         # Persistent batch.
         self.input_ids = torch.zeros(self.max_num_tokens,

@@ -369,6 +373,25 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                      dtype=torch.int32,
                                      device=self.device)

+        if self.vllm_config.model_config.use_mla and \
+            self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+            rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
+            self.cos = torch.ones(self.max_num_reqs,
+                                  1,
+                                  1,
+                                  rope_dim,
+                                  dtype=self.dtype,
+                                  device=self.device)
+            self.sin = torch.zeros(self.max_num_reqs,
+                                   1,
+                                   1,
+                                   rope_dim,
+                                   dtype=self.dtype,
+                                   device=self.device)
+        else:
+            self.cos = None
+            self.sin = None
+
         self.uses_mrope = self.model_config.uses_mrope
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:

@@ -1518,6 +1541,8 @@ def _prepare_inputs(
             max_query_len=max_num_scheduled_tokens,
             graph_pad_size=self.graph_pad_size,
             decode_token_per_req=self.decode_token_per_req,
+            cos=self.cos,
+            sin=self.sin,
         )

         if self.speculative_config and \

@@ -1547,7 +1572,7 @@ def _prepare_inputs(
             attn_metadata_i = builder.build(
                 common_prefix_len=common_prefix_len,
                 common_attn_metadata=common_attn_metadata,
-                model=self.model,
+                model=self.get_model(),
                 **extra_attn_metadata_args)

         if self.vllm_config.model_config.use_mla:

@@ -1582,8 +1607,14 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,

         forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-            update_attn_params(self.update_stream, forward_context,
-                               positions.shape[0])
+            if self.vllm_config.model_config.use_mla:
+                # FIXME: Try using `auto_dispatch_capture=True`
+                update_mla_attn_params(self.update_stream, forward_context,
+                                       positions.shape[0],
+                                       self.speculative_config)
+            else:
+                update_attn_params(self.update_stream, forward_context,
+                                   positions.shape[0])

         if get_forward_context().sp_enabled:
             hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)

@@ -2285,7 +2316,10 @@ def _build_attention_metadata(self, create_mixed_batch, num_reqs,
         block_table_tensor = self.input_batch.block_table[
             kv_cache_group_id].get_device_tensor()
         common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=self.query_start_loc[:num_reqs + 1],
+            query_start_loc=torch.tensor(
+                [0] + self.actual_seq_lengths_q[:num_reqs],
+                device=self.device,
+                dtype=torch.int32),
             query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
             seq_lens_cpu=self.seq_lens_cpu,

@@ -2296,17 +2330,27 @@ def _build_attention_metadata(self, create_mixed_batch, num_reqs,
             block_table_tensor=block_table_tensor[:num_reqs],
             slot_mapping=self.slot_mapping,
             num_computed_tokens_cpu=num_computed_tokens_cpu,
+            positions=self.positions,
+            attn_mask=self.attn_mask,
+            spec_attn_mask=self.spec_attn_mask,
+            attn_state=self.attn_state,
             max_query_len=max_query_len,
             decode_token_per_req=self.decode_token_per_req,
+            cos=self.cos,
+            sin=self.sin,
         )
+        attn_state = AscendAttentionState.DecodeOnly
+        if self.speculative_config and \
+            self.speculative_config.method == "deepseek_mtp":
+            attn_state = AscendAttentionState.SpecDecoding

         for attn_group in self.attn_groups[kv_cache_group_id]:
             if vllm_version_is("0.10.2"):
                 builder = attn_group.metadata_builder
             else:
                 builder = attn_group.get_metadata_builder()
             attn_metadata_i = builder.build_for_graph_capture(
-                common_attn_metadata)
+                common_attn_metadata, attn_state, self.get_model())
             for layer_name in kv_cache_group_spec.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i
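
For graph capture, `query_start_loc` is rebuilt from `actual_seq_lengths_q` so that every request slot spans exactly `decode_token_per_req` tokens, matching the uniform batch shape the graph is captured with. A worked example of that layout, assuming `decode_token_per_req = 2` (MTP with one speculative token) and `max_num_tokens = 8`:

# Worked example of the capture-time query layout, assuming MTP with one
# speculative token (decode_token_per_req = 2) and max_num_tokens = 8.
decode_token_per_req = 2
max_num_tokens = 8
num_reqs = 3

actual_seq_lengths_q = list(
    range(decode_token_per_req, max_num_tokens + 1, decode_token_per_req))
assert actual_seq_lengths_q == [2, 4, 6, 8]

# Dummy query_start_loc for capturing a graph over `num_reqs` requests:
# every request is assumed to contribute exactly decode_token_per_req tokens.
query_start_loc = [0] + actual_seq_lengths_q[:num_reqs]
assert query_start_loc == [0, 2, 4, 6]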
