
Commit eb43a47

[Feat] chunkprefill mla support torchair graph (#1772)
Chunked-prefill MLA currently supports only eager mode; we want to optimize it by supporting the torchair graph. The idea is simple: when every request in the batch is running decode, use the torchair graph to handle the step; otherwise (chunked prefill, or prefill only), fall back to eager mode.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@ebf7605

Signed-off-by: haojiangzheng <[email protected]>
Co-authored-by: haojiangzheng <[email protected]>
1 parent 881e36d commit eb43a47
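
As a rough illustration of the dispatch rule described above, the sketch below mirrors the two flags the diff sets in mla_v1.forward(). The enum and helper are stand-ins written for this note (modeled on AscendAttentionState), not code from vllm-ascend:

from enum import Enum, auto


class AttnState(Enum):
    # Stand-in for AscendAttentionState; only the states used in the diff.
    PrefillNoCache = auto()
    ChunkedPrefill = auto()
    DecodeOnly = auto()
    SpecDecoding = auto()


def choose_path(torchair_graph_enabled: bool, attn_state: AttnState) -> str:
    # Mirrors running_in_graph / running_chunkprefilll_with_torchair from the diff.
    running_in_graph = torchair_graph_enabled and attn_state in (
        AttnState.DecodeOnly, AttnState.SpecDecoding)
    running_chunkprefill_with_torchair = (
        torchair_graph_enabled and attn_state is AttnState.ChunkedPrefill)
    if running_in_graph:
        return "torchair graph (whole batch is decoding)"
    if running_chunkprefill_with_torchair:
        return "eager, decode slice handled on the graph-style KV path"
    return "eager"


assert choose_path(True, AttnState.DecodeOnly).startswith("torchair")
assert choose_path(True, AttnState.ChunkedPrefill).startswith("eager")
assert choose_path(False, AttnState.ChunkedPrefill) == "eager"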

2 files changed (+28, -18 lines)

tests/ut/attention/test_mla_v1.py

Lines changed: 1 addition & 0 deletions
@@ -664,6 +664,7 @@ def test_rope_single(self, mock_rope):
     def test_forward_decode_without_graph(self, mock_page_attention_mla,
                                            mock_up_proj):
         self.impl.running_in_graph = False
+        self.impl.running_chunkprefilll_with_torchair = False
         num_tokens = 100
         num_blocks = 256
         block_size = 4

vllm_ascend/attention/mla_v1.py

Lines changed: 27 additions & 18 deletions
@@ -998,7 +998,7 @@ def _forward_decode(
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
         num_tokens = q_nope.size(0)
-        if self.running_in_graph:
+        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
@@ -1112,6 +1112,7 @@ def forward(
         self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
+        self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(
@@ -1148,18 +1149,25 @@ def forward(
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
-            if self.running_in_graph:
+            if self.running_in_graph or self.running_chunkprefilll_with_torchair:
                 cos = attn_metadata.decode.cos
                 sin = attn_metadata.decode.sin
-                with npu_stream_switch("mla_secondary",
-                                       0,
-                                       enabled=enable_multistream_mla):
-                    npu_wait_tensor(hidden_states_or_kv_c_normed,
-                                    ckq,
-                                    enabled=enable_multistream_mla)
+                if self.running_chunkprefilll_with_torchair:
+                    decode_hs = (
+                        hidden_states_or_kv_c_normed[:num_decode_tokens])
+                    slots = attn_metadata.slot_mapping[:num_decode_tokens]
                     decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
-                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                        attn_metadata.slot_mapping)
+                        decode_hs, cos, sin, kv_cache, slots)
+                else:
+                    with npu_stream_switch("mla_secondary",
+                                           0,
+                                           enabled=enable_multistream_mla):
+                        npu_wait_tensor(hidden_states_or_kv_c_normed,
+                                        ckq,
+                                        enabled=enable_multistream_mla)
+                        decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
+                            hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                            attn_metadata.slot_mapping)
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
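
The chunked-prefill branch above relies on the batch layout where the decode tokens occupy the leading num_decode_tokens positions. A minimal standalone sketch of that split (tensor shapes and names here are illustrative only, not the module's real inputs):

import torch

# Illustrative mixed batch: 3 decode tokens followed by 5 chunked-prefill tokens.
num_decode_tokens = 3
hidden_states = torch.randn(8, 16)   # stands in for hidden_states_or_kv_c_normed
slot_mapping = torch.arange(8)       # one KV-cache slot index per token

# Decode slice: what the diff feeds to exec_kv under chunked prefill.
decode_hs = hidden_states[:num_decode_tokens]
decode_slots = slot_mapping[:num_decode_tokens]

# Prefill slice: written to the cache later (see the last hunk below).
prefill_slots = slot_mapping[num_decode_tokens:]

assert decode_hs.shape[0] + prefill_slots.shape[0] == hidden_states.shape[0]
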
@@ -1183,6 +1191,8 @@ def forward(
                                     decode_k_pe,
                                     enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+            elif self.running_chunkprefilll_with_torchair:
+                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
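
In the new elif branch the decode query's rotary embedding is applied directly with the precomputed cos/sin, skipping the multistream synchronization used on the pure graph path. A generic sketch of rotary application with precomputed cos/sin (standard RoPE arithmetic, not the actual rope_single implementation):

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap and negate the two halves of the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(q_pe: torch.Tensor, cos: torch.Tensor,
               sin: torch.Tensor) -> torch.Tensor:
    # q_pe: [num_tokens, num_heads, rope_dim]; cos/sin broadcast over the head dim.
    return q_pe * cos + rotate_half(q_pe) * sin


q_pe = torch.randn(4, 8, 64)
cos, sin = torch.randn(4, 1, 64), torch.randn(4, 1, 64)
print(apply_rope(q_pe, cos, sin).shape)  # torch.Size([4, 8, 64])
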
@@ -1221,16 +1231,15 @@ def forward(
                 kv_cache
             ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
             if self.torchair_graph_enabled:
-                if kv_cache[0].numel(
-                ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+                if kv_cache[0].numel() > 0 and has_prefill:
                     slots = attn_metadata.slot_mapping
                     # NOTE: Separate the kv cache in advance to avoid OOM or other issues
-                    torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
-                        num_tokens, self.num_kv_heads, -1),
-                                                     value=prefill_k_pe,
-                                                     key_cache=kv_cache[0],
-                                                     value_cache=kv_cache[1],
-                                                     slot_indices=slots)
+                    torch_npu._npu_reshape_and_cache(
+                        key=kv_c_normed.view(num_tokens, self.num_kv_heads, -1),
+                        value=prefill_k_pe,
+                        key_cache=kv_cache[0],
+                        value_cache=kv_cache[1],
+                        slot_indices=slots[num_decode_tokens:])
             else:
                 kv_c_normed = kv_c_normed.view(
                     [num_actual_toks, self.num_kv_heads, -1])
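
The slot_indices change in the last hunk follows the same layout: with chunked prefill under torchair, the decode slots have already been written by exec_kv, so only the prefill tail of slot_mapping is scattered here. A pure-PyTorch toy stand-in for the reshape-and-cache scatter (not the torch_npu kernel):

import torch

num_tokens, num_kv_heads, head_dim = 8, 1, 16
num_decode_tokens = 3

key = torch.randn(num_tokens, num_kv_heads, head_dim)
key_cache = torch.zeros(32, num_kv_heads, head_dim)  # flattened slot view
slot_mapping = torch.arange(num_tokens)

# Scatter only the prefill tail; decode slots were filled by exec_kv already.
prefill_slots = slot_mapping[num_decode_tokens:]
key_cache[prefill_slots] = key[num_decode_tokens:]
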
