Commit e3d7f83

committed
[Feat] Make full graph mode compatible with MTP
Signed-off-by: anon189Ty <[email protected]>
1 parent 83092d9 commit e3d7f83

File tree

5 files changed: +178 −24 lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 1 addition & 0 deletions
@@ -237,6 +237,7 @@ def build_for_graph_capture(
         self,
         common_attn_metadata: AscendCommonAttentionMetadata,
         attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
+        model: Optional[nn.Module] = None,
     ):
         if attn_state == AscendAttentionState.DecodeOnly:
             attn_metadata = self.build(

vllm_ascend/attention/mla_v1.py

Lines changed: 68 additions & 18 deletions
@@ -169,7 +169,7 @@ def split_metadata_for_multistream(
 class AscendMLAMetadataBuilder:
     # Does this backend/builder support ACL Graphs for attention (default: no).
     aclgraph_support: ClassVar[AttentionCGSupport] = \
-        AttentionCGSupport.NEVER
+        AttentionCGSupport.UNIFORM_BATCH
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
@@ -389,29 +389,55 @@ def build(
 
         decode_metadata = None
         if num_decodes > 0:
+            cos = common_attn_metadata.cos
+            sin = common_attn_metadata.sin
             # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
             actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decodes]
             input_positions = input_positions[:num_decode_tokens]
             block_table = block_table[:num_decodes, ...]
             seq_lens_list = seq_lens.tolist()
-
-            cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
-            sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
-                1).unsqueeze(2)
-
-            decode_metadata = AscendMLADecodeMetadata(
-                input_positions=input_positions,
-                block_table=block_table,
-                seq_lens=seq_lens,
-                seq_lens_list=seq_lens_list,
-                max_seq_lens=max_seq_lens,
-                attn_mask=common_attn_metadata.spec_attn_mask,
-                actual_seq_lengths_q=actual_seq_lengths_q,
-                sin=sin,
-                cos=cos)
+
+            # TODO: After the full graph supports MTP, this if branch needs to be deleted
+            assert self.cos_cache is not None
+            assert self.sin_cache is not None
+            if cos is None and sin is None:
+                cos = self.cos_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
+                sin = self.sin_cache[
+                    input_positions].unsqueeze(  # type: ignore
+                        1).unsqueeze(2)
+
+                decode_metadata = AscendMLADecodeMetadata(
+                    input_positions=input_positions,
+                    block_table=block_table,
+                    seq_lens=seq_lens,
+                    seq_lens_list=seq_lens_list,
+                    max_seq_lens=max_seq_lens,
+                    attn_mask=common_attn_metadata.spec_attn_mask,
+                    actual_seq_lengths_q=actual_seq_lengths_q,
+                    sin=sin,
+                    cos=cos)
+            else:
+                cos[:num_decode_tokens,
+                    ...] = self.cos_cache[input_positions].unsqueeze(
+                        1).unsqueeze(2)
+                sin[:num_decode_tokens,
+                    ...] = self.sin_cache[input_positions].unsqueeze(
+                        1).unsqueeze(2)
+
+                decode_metadata = AscendMLADecodeMetadata(
+                    input_positions=input_positions,
+                    block_table=block_table,
+                    seq_lens=seq_lens,
+                    seq_lens_list=seq_lens_list,
+                    max_seq_lens=max_seq_lens,
+                    attn_mask=common_attn_metadata.spec_attn_mask,
+                    actual_seq_lengths_q=actual_seq_lengths_q,
+                    sin=sin[:num_decode_tokens, ...],
+                    cos=cos[:num_decode_tokens, ...])
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -430,6 +456,29 @@ def build(
             seq_lens=seq_lens,
             enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
         )
+
+    def build_for_graph_capture(
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+        attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
+        model: Optional[nn.Module] = None,
+    ):
+        if attn_state in {
+                AscendAttentionState.DecodeOnly,
+                AscendAttentionState.SpecDecoding
+        }:
+            attn_metadata = self.build(
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+                model=model,
+            )
+        else:
+            raise NotImplementedError(
+                "Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state"
+            )
+
+        attn_metadata.attn_state = attn_state
+        return attn_metadata
 
 
 class DecodeMLAPreprocessResult(NamedTuple):
@@ -814,7 +863,8 @@ def _forward_decode(
 
         if attn_metadata.attn_state in [
                 AscendAttentionState.SpecDecoding,
-                AscendAttentionState.ChunkedPrefill
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.DecodeOnly,
         ] and self.speculative_config is not None:
             # Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
             input_layout = "TND"

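Note on the `build()` change above: when the runner hands in preallocated `cos`/`sin` buffers (the full-graph decode path), the builder now writes the per-position rotary values into slices of those buffers instead of allocating fresh tensors, because ACL graph replay needs the attention arguments to keep stable device addresses across steps. A minimal sketch of the two paths, using stand-in tensors rather than the repository's real caches (names such as `cos_cache` and `num_decode_tokens` mirror the diff; shapes and values are illustrative assumptions):

```python
import torch

# Stand-ins for the builder's rotary caches (assumed shape [max_position, rope_dim]).
rope_dim, max_position, max_num_reqs = 64, 4096, 8
cos_cache = torch.randn(max_position, rope_dim)
sin_cache = torch.randn(max_position, rope_dim)

# Buffers the runner would allocate once and reuse for every captured decode step.
cos_buf = torch.ones(max_num_reqs, 1, 1, rope_dim)
sin_buf = torch.zeros(max_num_reqs, 1, 1, rope_dim)

input_positions = torch.tensor([5, 9, 13])  # one decode token per request in this toy batch
num_decode_tokens = input_positions.numel()

# Eager path (no preallocated buffers): fresh tensors are created on every step.
cos_eager = cos_cache[input_positions].unsqueeze(1).unsqueeze(2)

# Full-graph path: gather into the persistent buffer in place and pass a view of it on,
# so the address captured in the graph stays valid on replay.
cos_buf[:num_decode_tokens, ...] = cos_cache[input_positions].unsqueeze(1).unsqueeze(2)
sin_buf[:num_decode_tokens, ...] = sin_cache[input_positions].unsqueeze(1).unsqueeze(2)

assert torch.equal(cos_buf[:num_decode_tokens], cos_eager)
```
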
vllm_ascend/attention/utils.py

Lines changed: 4 additions & 0 deletions
@@ -63,6 +63,10 @@ class AscendCommonAttentionMetadata:
 
     graph_pad_size: int = -1
 
+    # NOTE: This is a temporary solution for rotary embedding in MLA
+    cos: torch.Tensor = None
+    sin: torch.Tensor = None
+
 
 def split_decodes_and_prefills(
     common_attn_metadata: AscendCommonAttentionMetadata,

vllm_ascend/compilation/acl_graph.py

Lines changed: 54 additions & 0 deletions
@@ -229,6 +229,60 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
         event.record(update_stream)
 
 
+def update_mla_attn_params(update_stream, forward_context, runtime_shape,
+                           speculative_config):
+    graph_params = get_graph_params()
+    # FIXME: Behold! We are using a temporary hack here to update the args
+    # for each layer's attention op in the graph.
+    for key, param, handle, event in zip(
+            forward_context.attn_metadata,
+            graph_params.attn_params[runtime_shape],
+            graph_params.handles[runtime_shape],
+            graph_params.events[runtime_shape],
+    ):
+        (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
+         spec_attn_mask, sparse_mode, scale, block_table, block_size,
+         seq_lens_list, actual_seq_lengths, workspace, attn_output,
+         softmax_lse) = param
+        seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
+        if speculative_config and speculative_config.method == "deepseek_mtp":
+            actual_seq_lengths = forward_context.attn_metadata[
+                key].decode.actual_seq_lengths_q
+            spec_multiple = speculative_config.num_speculative_tokens + 1
+            seq_lens_list = seq_lens_list + [0] * (
+                runtime_shape // spec_multiple - len(seq_lens_list))
+            actual_seq_lengths = [
+                spec_multiple * (i + 1)
+                for i in range(runtime_shape // spec_multiple)
+            ]
+        with torch.npu.stream(update_stream):
+            torch.npu.graph_task_update_begin(update_stream, handle)
+
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                k_nope,
+                k_nope,
+                query_rope=q_pe,
+                key_rope=k_pe,
+                num_heads=num_heads,
+                num_key_value_heads=num_kv_heads,
+                input_layout=input_layout,
+                atten_mask=spec_attn_mask,
+                sparse_mode=sparse_mode,
+                scale=scale,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                block_table=block_table,
+                block_size=block_size,
+                actual_seq_lengths_kv=seq_lens_list,
+                actual_seq_lengths=actual_seq_lengths,
+                workspace=workspace,
+                out=[attn_output, softmax_lse])
+            torch.npu.graph_task_update_end(update_stream)
+
+        event.record(update_stream)
+
+
 @dataclass
 class GraphParams:
     events: dict[int, list[torch.npu.ExternalEvent]]

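For the `deepseek_mtp` case, `update_mla_attn_params` above has to rebuild the sequence-length arguments for the padded graph batch: `runtime_shape` is the padded token count, each request slot contributes `num_speculative_tokens + 1` query tokens, so the captured op expects `runtime_shape // spec_multiple` entries. KV lengths of slots that exist only as padding are set to 0, and `actual_seq_lengths` becomes the cumulative query end offsets of the padded batch. A small worked example of just that arithmetic (plain Python, no NPU calls; the numbers are made up):

```python
# Toy numbers: 3 live requests, graph captured for a padded batch of 8 tokens,
# MTP proposing 1 extra token per request (so 2 query tokens per request slot).
runtime_shape = 8
num_speculative_tokens = 1
spec_multiple = num_speculative_tokens + 1           # query tokens per request slot
seq_lens_list = [17, 42, 9]                          # KV lengths of the live requests

# Pad KV lengths with zeros up to the number of request slots in the captured graph.
padded_seq_lens = seq_lens_list + [0] * (
    runtime_shape // spec_multiple - len(seq_lens_list))

# Query end offsets for the padded batch: one fixed-size chunk per request slot.
actual_seq_lengths = [
    spec_multiple * (i + 1) for i in range(runtime_shape // spec_multiple)
]

print(padded_seq_lens)     # [17, 42, 9, 0]
print(actual_seq_lengths)  # [2, 4, 6, 8]
```
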
vllm_ascend/worker/model_runner_v1.py

Lines changed: 51 additions & 6 deletions
@@ -102,7 +102,8 @@
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                set_graph_params,
-                                               update_attn_params)
+                                               update_attn_params,
+                                               update_mla_attn_params)
 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
 from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
     D2DExpertWeightLoader
@@ -351,6 +352,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                 self.speculative_config.method, self.vllm_config,
                 self.device, self)
             self.rejection_sampler = AscendRejectionSampler()
+        self.actual_seq_lengths_q = list(
+            range(self.decode_token_per_req, self.max_num_tokens + 1, self.decode_token_per_req))
 
         # Persistent batch.
         self.input_ids = torch.zeros(self.max_num_tokens,
@@ -369,6 +372,25 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                      dtype=torch.int32,
                                      device=self.device)
 
+        if self.vllm_config.model_config.use_mla and \
+                self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+            rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
+            self.cos = torch.ones(self.max_num_reqs,
+                                  1,
+                                  1,
+                                  rope_dim,
+                                  dtype=self.dtype,
+                                  device=self.device)
+            self.sin = torch.zeros(self.max_num_reqs,
+                                   1,
+                                   1,
+                                   rope_dim,
+                                   dtype=self.dtype,
+                                   device=self.device)
+        else:
+            self.cos = None
+            self.sin = None
+
         self.uses_mrope = self.model_config.uses_mrope
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
@@ -1518,6 +1540,8 @@ def _prepare_inputs(
             max_query_len=max_num_scheduled_tokens,
             graph_pad_size=self.graph_pad_size,
             decode_token_per_req=self.decode_token_per_req,
+            cos=self.cos,
+            sin=self.sin,
         )
 
         if self.speculative_config and \
@@ -1547,7 +1571,7 @@ def _prepare_inputs(
             attn_metadata_i = builder.build(
                 common_prefix_len=common_prefix_len,
                 common_attn_metadata=common_attn_metadata,
-                model=self.model,
+                model=self.get_model(),
                 **extra_attn_metadata_args)
 
             if self.vllm_config.model_config.use_mla:
@@ -1582,8 +1606,14 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
 
         forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-            update_attn_params(self.update_stream, forward_context,
-                               positions.shape[0])
+            if self.vllm_config.model_config.use_mla:
+                # FIXME: Try using `auto_dispatch_capture=True`
+                update_mla_attn_params(self.update_stream, forward_context,
+                                       positions.shape[0],
+                                       self.speculative_config)
+            else:
+                update_attn_params(self.update_stream, forward_context,
+                                   positions.shape[0])
 
         if get_forward_context().sp_enabled:
             hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
@@ -2285,7 +2315,10 @@ def _build_attention_metadata(self, create_mixed_batch, num_reqs,
             block_table_tensor = self.input_batch.block_table[
                 kv_cache_group_id].get_device_tensor()
             common_attn_metadata = AscendCommonAttentionMetadata(
-                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc=torch.tensor(
+                    [0] + self.actual_seq_lengths_q[:num_reqs],
+                    device=self.device,
+                    dtype=torch.int32),
                 query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
                                                               1],
                 seq_lens_cpu=self.seq_lens_cpu,
@@ -2296,17 +2329,29 @@ def _build_attention_metadata(self, create_mixed_batch, num_reqs,
                 block_table_tensor=block_table_tensor[:num_reqs],
                 slot_mapping=self.slot_mapping,
                 num_computed_tokens_cpu=num_computed_tokens_cpu,
+                positions=self.positions,
+                attn_mask=self.attn_mask,
+                spec_attn_mask=self.spec_attn_mask,
+                attn_state=self.attn_state,
                 max_query_len=max_query_len,
                 decode_token_per_req=self.decode_token_per_req,
+                cos=self.cos,
+                sin=self.sin,
             )
+            attn_state = AscendAttentionState.DecodeOnly
+            if self.speculative_config and \
+                    self.speculative_config.method == "deepseek_mtp":
+                attn_state = AscendAttentionState.SpecDecoding
 
             for attn_group in self.attn_groups[kv_cache_group_id]:
                 if vllm_version_is("0.10.2"):
                     builder = attn_group.metadata_builder
                 else:
                     builder = attn_group.get_metadata_builder()
                 attn_metadata_i = builder.build_for_graph_capture(
-                    common_attn_metadata)
+                    common_attn_metadata,
+                    attn_state,
+                    self.get_model())
                 for layer_name in kv_cache_group_spec.layer_names:
                     attn_metadata[layer_name] = attn_metadata_i

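In `_build_attention_metadata`, the dummy `query_start_loc` used for graph capture is now derived from `self.actual_seq_lengths_q`, which `__init__` precomputes as every multiple of `decode_token_per_req` up to `max_num_tokens`: with MTP, each captured request slot always occupies exactly `decode_token_per_req` query tokens, so the dummy batch is made of uniform chunks. A short sketch of how these two pieces fit together under assumed sizes (pure Python/Torch on CPU, not the runner itself):

```python
import torch

# Assumed toy configuration: MTP proposes 1 token per request, capture up to 8 tokens.
decode_token_per_req = 2          # num_speculative_tokens + 1
max_num_tokens = 8
device = "cpu"                    # the runner would use the NPU device

# Precomputed once in __init__: cumulative query ends for fully packed dummy batches.
actual_seq_lengths_q = list(
    range(decode_token_per_req, max_num_tokens + 1, decode_token_per_req))
# -> [2, 4, 6, 8]

# During capture of a dummy batch with `num_reqs` requests, query_start_loc becomes
# [0, 2, 4, ...]: every request slot contributes exactly decode_token_per_req tokens.
num_reqs = 3
query_start_loc = torch.tensor([0] + actual_seq_lengths_q[:num_reqs],
                               device=device,
                               dtype=torch.int32)
print(query_start_loc)  # tensor([0, 2, 4, 6], dtype=torch.int32)
```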