
Commit 2de99bb

decrease number of attention ops in prefill and decode

Signed-off-by: whx-sjtu <[email protected]>
1 parent 9cf6d76 · commit 2de99bb

File tree

5 files changed: +87, -185 lines changed

docs/source/locale/zh_CN/LC_MESSAGES/user_guide/configuration/additional_config.po

Lines changed: 0 additions & 3 deletions
@@ -148,9 +148,6 @@ msgid ""
 " to be passed in."
 msgstr "在为MOE模型使用专家负载均衡时,需要传入专家映射路径。"
 
-#: ../../user_guide/configuration/additional_config.md
-msgid "`chunked_prefill_for_mla`"
-msgstr "`chunked_prefill_for_mla`"
 
 #: ../../user_guide/configuration/additional_config.md
 msgid "`False`"

docs/source/user_guide/configuration/additional_config.md

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@ The following table lists the additional configuration options available in vLLM
 | `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler |
 | `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. |
 | `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
-| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
 | `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
 | `enable_shared_expert_dp` | bool | `True` | When the shared expert in DP, it has better performance but consumes more memory. When the memory is sensitive, this switch can be turned off manually. |
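For quick reference, here is an illustrative sketch (not an authoritative list) of the remaining documented `additional_config` options assembled as the JSON payload that `--additional-config` accepts. Defaults are copied from the table rows above; `chunked_prefill_for_mla` is intentionally absent.

```python
import json

# Defaults copied from the documentation table above (illustrative only);
# "chunked_prefill_for_mla" is no longer a documented key after this commit.
additional_config = {
    "ascend_scheduler_config": {},
    "refresh": False,
    "expert_map_path": None,
    "kv_cache_dtype": None,  # e.g. "int8" to enable kv cache quantization
    "enable_shared_expert_dp": True,
}
print(json.dumps(additional_config))
```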

examples/disaggregated_prefill_v1/README.md

Lines changed: 0 additions & 4 deletions
@@ -71,8 +71,6 @@ vllm serve /models/deepseek_r1_w8a8 \
     "engine_id": "0",
     "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
   }' \
-  --additional-config \
-  '{"chunked_prefill_for_mla":true}'
 ```
 
 Run prefill server P2 on second node:
@@ -115,8 +113,6 @@ vllm serve /models/deepseek_r1_w8a8 \
     "engine_id": "0",
     "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
   }' \
-  --additional-config \
-  '{"chunked_prefill_for_mla":true}'
 ```
 
 Run decode server d1 on third node:
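As a side note, the kv-transfer JSON shown in the commands above can be generated with `json.dumps` instead of being hand-quoted in the shell. This is only an illustrative sketch: the full config contains more keys than the two visible in this excerpt.

```python
import json

# Only the keys visible in the diff excerpt are included; the real
# kv-transfer config for llmdatadist carries additional fields.
kv_transfer_config = {
    "engine_id": "0",
    "kv_connector_module_path":
        "vllm_ascend.distributed.llmdatadist_c_mgr_connector",
}
print(json.dumps(kv_transfer_config))
# The printed string is what gets quoted on the vllm serve command line;
# no extra --additional-config '{"chunked_prefill_for_mla":true}' is needed anymore.
```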

vllm_ascend/ascend_config.py

Lines changed: 0 additions & 2 deletions
@@ -45,8 +45,6 @@ def __init__(self, vllm_config):
             ascend_scheduler_config)
 
         self.expert_map_path = additional_config.get("expert_map_path", None)
-        self.chunked_prefill_for_mla = additional_config.get(
-            "chunked_prefill_for_mla", False)
         self.enable_shared_expert_dp = additional_config.get(
             "enable_shared_expert_dp", True
         ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
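For context, a minimal sketch of the `.get(...)`-with-default pattern used above (a toy stand-in, not the real `AscendConfig` class): once the lookup is deleted, a leftover `chunked_prefill_for_mla` entry in `--additional-config` no longer toggles anything in this config object.

```python
# Toy stand-in for AscendConfig (illustrative only).
class MiniAscendConfig:

    def __init__(self, additional_config: dict):
        # Same pattern as above: read each known key with a default.
        self.expert_map_path = additional_config.get("expert_map_path", None)
        self.enable_shared_expert_dp = additional_config.get(
            "enable_shared_expert_dp", True)
        # There is no self.chunked_prefill_for_mla after this commit, so the
        # key passed below does not set any attribute on the config object.


cfg = MiniAscendConfig({"chunked_prefill_for_mla": True})
print(hasattr(cfg, "chunked_prefill_for_mla"))  # False
```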

vllm_ascend/attention/mla_v1.py

Lines changed: 87 additions & 175 deletions
@@ -812,101 +812,43 @@ def _forward_prefill(
             -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
                 [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
-        # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
-        ascend_config = get_ascend_config()
+        attn_lse = torch.empty(self.num_heads,
+                               num_tokens,
+                               dtype=torch.float32,
+                               device=query.device)
+        q_pe = query[..., self.qk_nope_head_dim:]
+        q_nope = query[..., :self.qk_nope_head_dim]
+        mask = torch.triu(
+            torch.ones(512, 512, device=query.device, dtype=query.dtype),
+            1)  # 512: mask only support 512
+        if attn_metadata.num_prefills > 1:
+            mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
+                                            1)
+        torch_npu.atb.npu_ring_mla(
+            q_nope=q_nope,
+            q_rope=q_pe,
+            k_nope=k_nope,
+            k_rope=k_pe,
+            value=value,
+            mask=mask,
+            seqlen=torch.tensor(attn_metadata.prefill.query_lens,
+                                dtype=torch.int32),
+            head_num=self.num_heads,
+            kv_head_num=self.num_heads,
+            pre_out=None,
+            prev_lse=None,
+            qk_scale=self.scale,
+            kernel_type="kernel_type_high_precision",
+            mask_type="mask_type_triu",
+            input_layout="type_bsnd",
+            calc_type="calc_type_first_ring",
+            output=attn_output,
+            softmax_lse=attn_lse)
+        attn_output, attn_lse = self._compute_prefill_context( \
+            query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
 
-        if attn_metadata.attn_state in [
-                AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding,
-                AscendAttentionState.PrefillCacheHit
-        ] and not ascend_config.chunked_prefill_for_mla:
-            attn_output_torch = torch.empty(num_tokens,
-                                            self.num_heads * self.v_head_dim,
-                                            dtype=query.dtype,
-                                            device=query.device)
-            # current requests is chunked in prefill, disable flash attention with chunked prefill
-            vanilla_chunked_prefill_mla(
-                output=attn_output_torch,
-                query=query,
-                kv_cache=kv_c_and_k_pe_cache,
-                block_tables=attn_metadata.prefill.block_table,
-                query_lens=attn_metadata.prefill.query_lens,
-                context_lens=attn_metadata.prefill.context_lens,
-                kv_b_proj=self.kv_b_proj,
-                max_query_len=attn_metadata.prefill.max_query_len,
-                max_context_len=attn_metadata.prefill.max_seq_lens,
-                nope_dim=self.qk_nope_head_dim,
-                rope_dim=self.qk_rope_head_dim,
-                v_head_dim=self.v_head_dim,
-                scale=self.scale,
-                alibi_slopes=None,
-                causal=True)
-        elif attn_metadata.attn_state in [
-                AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding,
-                AscendAttentionState.PrefillCacheHit
-        ]:
-            attn_lse = torch.empty(self.num_heads,
-                                   num_tokens,
-                                   dtype=torch.float32,
-                                   device=query.device)
-            q_pe = query[..., self.qk_nope_head_dim:]
-            q_nope = query[..., :self.qk_nope_head_dim]
-            mask = torch.triu(
-                torch.ones(512, 512, device=query.device, dtype=query.dtype),
-                1)  # 512: mask only support 512
-            if attn_metadata.num_prefills > 1:
-                mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
-                                                1)
-            torch_npu.atb.npu_ring_mla(
-                q_nope=q_nope,
-                q_rope=q_pe,
-                k_nope=k_nope,
-                k_rope=k_pe,
-                value=value,
-                mask=mask,
-                seqlen=torch.tensor(attn_metadata.prefill.query_lens,
-                                    dtype=torch.int32),
-                head_num=self.num_heads,
-                kv_head_num=self.num_heads,
-                pre_out=None,
-                prev_lse=None,
-                qk_scale=self.scale,
-                kernel_type="kernel_type_high_precision",
-                mask_type="mask_type_triu",
-                input_layout="type_bsnd",
-                calc_type="calc_type_first_ring",
-                output=attn_output,
-                softmax_lse=attn_lse)
-            attn_output, attn_lse = self._compute_prefill_context( \
-                query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
-
-        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-            key = torch.cat((k_nope, k_pe), dim=-1)
-            torch_npu._npu_flash_attention(
-                query=query,
-                key=key,
-                value=value,
-                mask=attn_metadata.attn_mask,
-                seq_len=attn_metadata.prefill.context_lens,
-                scale_value=self.scale,
-                num_heads=self.num_heads,
-                num_kv_heads=self.num_heads,
-                out=attn_output)
-            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
-        else:
-            raise RuntimeError(
-                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
-            )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
-        if attn_metadata.attn_state in [
-                AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding,
-                AscendAttentionState.PrefillCacheHit
-        ] and not ascend_config.chunked_prefill_for_mla:
-            attn_output = attn_output_torch
-
         return attn_output
 
     def exec_kv(
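To make the new unified prefill path easier to follow, here is a small CPU-only PyTorch sketch of how the fixed 512x512 upper-triangular mask is built and replicated per prefill request before being handed to `torch_npu.atb.npu_ring_mla` (sizes and dtype are illustrative, not the production tensors):

```python
import torch

# Illustrative stand-ins; the real values come from attn_metadata.
num_prefills = 3
dtype = torch.float16

# The ring-MLA kernel only supports a 512x512 mask, so a strictly
# upper-triangular causal mask of that fixed size is built once...
mask = torch.triu(torch.ones(512, 512, dtype=dtype), 1)

# ...and replicated along a leading batch dimension when the step
# contains more than one prefill request.
if num_prefills > 1:
    mask = mask.unsqueeze(0).repeat(num_prefills, 1, 1)

print(mask.shape)  # torch.Size([3, 512, 512])
```

With this single path, the previous branching between `vanilla_chunked_prefill_mla`, `npu_ring_mla` and `_npu_flash_attention` collapses into one ring-MLA call plus `_compute_prefill_context`, which is where the reduction in prefill attention ops comes from.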
@@ -991,91 +933,61 @@ def _forward_decode(
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
         num_tokens = q_nope.size(0)
-        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
-            # shape of knope/k_pe for npu graph mode should be:
-            # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
-            block_size = kv_c_and_k_pe_cache[0].shape[1]
-            actual_seq_lengths = None
-            if self.enable_kv_nz:
-                k_nope = k_nope.view(-1, self.num_kv_heads,
-                                     self.kv_lora_rank // 16, block_size, 16)
-                k_pe = k_pe.view(-1, self.num_kv_heads,
-                                 self.qk_rope_head_dim // 16, block_size, 16)
-                input_layout = "BSND"
-            else:
-                k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
-                                     self.kv_lora_rank)
-                k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
-                                 self.qk_rope_head_dim)
-                input_layout = "BNSD"
-
-            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
-                assert num_tokens % self.spec_token_num == 0
-                input_layout = "TND"
-                # [bs * q_seq_len, num_heads_per_rank, dim]
-                q_nope = q_nope.view(num_tokens, self.num_heads, -1)
-                q_pe = q_pe.view(num_tokens, self.num_heads, -1)
-                sparse_mode = 3
-                spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
-                actual_seq_lengths = decode_meta.actual_seq_lengths_q
-            else:
-                if self.enable_kv_nz:
-                    q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
-                    q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
-                else:
-                    q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-                    q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
-                sparse_mode = 0
-                spec_attn_mask = None
-
-            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
-                q_nope,
-                k_nope,
-                k_nope,
-                query_rope=q_pe,
-                key_rope=k_pe,
-                num_heads=self.num_heads,
-                num_key_value_heads=self.num_kv_heads,
-                input_layout=input_layout,
-                atten_mask=spec_attn_mask,
-                sparse_mode=sparse_mode,
-                scale=self.scale,
-                antiquant_mode=0,
-                antiquant_scale=None,
-                block_table=decode_meta.block_table,
-                block_size=block_size,
-                actual_seq_lengths_kv=decode_meta.seq_lens_list,
-                actual_seq_lengths=actual_seq_lengths)
+        # shape of knope/k_pe for npu graph mode should be:
+        # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
+        block_size = kv_c_and_k_pe_cache[0].shape[1]
+        actual_seq_lengths = None
+        if self.enable_kv_nz:
+            k_nope = k_nope.view(-1, self.num_kv_heads,
+                                 self.kv_lora_rank // 16, block_size, 16)
+            k_pe = k_pe.view(-1, self.num_kv_heads,
+                             self.qk_rope_head_dim // 16, block_size, 16)
+            input_layout = "BSND"
         else:
-            # The MLA_PA path will be used as default path in the future, `_npu_paged_attention_mla` will
-            # be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become
-            # public available
-            assert len(kv_c_and_k_pe_cache) > 1
-            if envs.VLLM_ASCEND_MLA_PA:
-                attn_output = torch_npu.atb.npu_multi_head_latent_attention(
-                    q_nope, q_pe, kv_c_and_k_pe_cache[0],
-                    kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table,
-                    attn_metadata.decode.seq_lens, self.num_heads, self.scale,
-                    self.num_kv_heads)
+            k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
+                                 self.kv_lora_rank)
+            k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
+                             self.qk_rope_head_dim)
+            input_layout = "BNSD"
+
+        if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
+            assert num_tokens % self.spec_token_num == 0
+            input_layout = "TND"
+            # [bs * q_seq_len, num_heads_per_rank, dim]
+            q_nope = q_nope.view(num_tokens, self.num_heads, -1)
+            q_pe = q_pe.view(num_tokens, self.num_heads, -1)
+            sparse_mode = 3
+            spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
+            actual_seq_lengths = decode_meta.actual_seq_lengths_q
+        else:
+            if self.enable_kv_nz:
+                q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
+                q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
             else:
-                q = torch.cat([q_nope, q_pe], dim=-1)
-                attn_output = torch.empty(
-                    [num_tokens, self.num_heads, self.kv_lora_rank],
-                    dtype=q.dtype,
-                    device=q.device)
-                k_cache = torch.cat(
-                    [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
-                torch_npu._npu_paged_attention_mla(
-                    query=q,
-                    key_cache=k_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.decode.
-                    block_table,  # type:ignore
-                    context_lens=attn_metadata.decode.seq_lens,  # type:ignore
-                    mla_vheadsize=self.kv_lora_rank,
-                    out=attn_output)
+                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            sparse_mode = 0
+            spec_attn_mask = None
+
+        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
+            q_nope,
+            k_nope,
+            k_nope,
+            query_rope=q_pe,
+            key_rope=k_pe,
+            num_heads=self.num_heads,
+            num_key_value_heads=self.num_kv_heads,
+            input_layout=input_layout,
+            atten_mask=spec_attn_mask,
+            sparse_mode=sparse_mode,
+            scale=self.scale,
+            antiquant_mode=0,
+            antiquant_scale=None,
+            block_table=decode_meta.block_table,
+            block_size=block_size,
+            actual_seq_lengths_kv=decode_meta.seq_lens_list,
+            actual_seq_lengths=actual_seq_lengths)
+
         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
             return self._v_up_proj(attn_output,
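Similarly, the three query layouts selected in the decode path above reduce to plain `view` calls. Here is a rough CPU-only illustration with made-up sizes (not the production tensors or head counts):

```python
import torch

# Hypothetical sizes standing in for the real head count and dims.
num_tokens, num_heads, kv_lora_rank = 4, 16, 512
q_nope = torch.randn(num_tokens, num_heads, kv_lora_rank)

# "BSND" (enable_kv_nz): batch, seq=1, heads, dim.
q_bsnd = q_nope.view(num_tokens, 1, num_heads, -1)

# "BNSD" (default): batch, heads, seq=1, dim.
q_bnsd = q_nope.view(num_tokens, num_heads, 1, -1)

# "TND" (speculative decoding): flattened tokens, heads, dim.
q_tnd = q_nope.view(num_tokens, num_heads, -1)

print(q_bsnd.shape, q_bnsd.shape, q_tnd.shape)
# torch.Size([4, 1, 16, 512]) torch.Size([4, 16, 1, 512]) torch.Size([4, 16, 512])
```

Because `npu_fused_infer_attention_score` is now called unconditionally for decode, the old `_npu_paged_attention_mla` / `npu_multi_head_latent_attention` fallback branch disappears, matching the commit's goal of fewer attention ops in decode.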
