Commit 6fad01b

decrease number of attention ops in prefill and decode
Signed-off-by: whx-sjtu <[email protected]>
1 parent e1cef4e commit 6fad01b

5 files changed: 89 additions & 194 deletions

docs/source/locale/zh_CN/LC_MESSAGES/user_guide/configuration/additional_config.po

Lines changed: 0 additions & 3 deletions

@@ -148,9 +148,6 @@ msgid ""
 " to be passed in."
 msgstr "在为MOE模型使用专家负载均衡时,需要传入专家映射路径。"
 
-#: ../../user_guide/configuration/additional_config.md
-msgid "`chunked_prefill_for_mla`"
-msgstr "`chunked_prefill_for_mla`"
 
 #: ../../user_guide/configuration/additional_config.md
 msgid "`False`"

docs/source/user_guide/configuration/additional_config.md

Lines changed: 0 additions & 1 deletion

@@ -30,7 +30,6 @@ The following table lists the additional configuration options available in vLLM
 | `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler |
 | `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. |
 | `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
-| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
 | `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
 | `enable_shared_expert_dp` | bool | `True` | When the shared expert in DP, it has better performance but consumes more memory. When the memory is sensitive, this switch can be turned off manually. |
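Note: with `chunked_prefill_for_mla` removed, an `--additional-config` payload built from the remaining keys in this table might look like the sketch below. The key names come from the table above; the values are placeholders only, not recommendations.

```python
import json

# Illustrative only: keys are taken from the table above, values are placeholders.
additional_config = {
    "ascend_scheduler_config": {},
    "expert_map_path": "/path/to/expert_map.json",  # hypothetical path
    "kv_cache_dtype": "int8",
    "enable_shared_expert_dp": True,
}

# Passed on the command line as: vllm serve <model> --additional-config '<json>'
print(json.dumps(additional_config))
```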

examples/disaggregated_prefill_v1/README.md

Lines changed: 0 additions & 4 deletions

@@ -71,8 +71,6 @@ vllm serve /models/deepseek_r1_w8a8 \
   "engine_id": "0",
   "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
   }' \
-  --additional-config \
-  '{"chunked_prefill_for_mla":true}'
 ```
 
 Run prefill server P2 on second node:
@@ -115,8 +113,6 @@ vllm serve /models/deepseek_r1_w8a8 \
   "engine_id": "0",
   "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
   }' \
-  --additional-config \
-  '{"chunked_prefill_for_mla":true}'
 ```
 
 Run decode server d1 on third node:
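Note: the serve commands above now end at the kv-connector JSON; the `--additional-config '{"chunked_prefill_for_mla":true}'` flag is no longer needed. As a small sketch, the kv-connector fragment visible in these hunks could be built like this (only the keys shown in the diff are included; the surrounding `vllm serve` flags are elided here):

```python
import json

# Only the keys visible in the hunks above; the rest of the serve command is elided.
kv_config = {
    "engine_id": "0",
    "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector",
}
print(json.dumps(kv_config))
```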

vllm_ascend/ascend_config.py

Lines changed: 0 additions & 2 deletions

@@ -45,8 +45,6 @@ def __init__(self, vllm_config):
             ascend_scheduler_config)
 
         self.expert_map_path = additional_config.get("expert_map_path", None)
-        self.chunked_prefill_for_mla = additional_config.get(
-            "chunked_prefill_for_mla", False)
         self.enable_shared_expert_dp = additional_config.get(
             "enable_shared_expert_dp", True
         ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
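Note: the surviving options keep the same `dict.get`-with-default pattern shown above. A standalone sketch of that pattern (simplified; this is not the real `AscendConfig` class, and the extra torchair/expert-parallel conditions are omitted):

```python
# Simplified sketch of the dict.get defaulting pattern used in AscendConfig above.
# A missing key simply falls back to its documented default.
def parse_additional_config(additional_config: dict) -> dict:
    return {
        "expert_map_path": additional_config.get("expert_map_path", None),
        "enable_shared_expert_dp": additional_config.get(
            "enable_shared_expert_dp", True),
    }


# With an empty dict, every option resolves to its default.
assert parse_additional_config({}) == {
    "expert_map_path": None,
    "enable_shared_expert_dp": True,
}
```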

vllm_ascend/attention/mla_v1.py

Lines changed: 89 additions & 184 deletions

@@ -4,8 +4,7 @@
 import numpy as np
 import torch
 import torch_npu
-from torch import nn
-from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
+from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
@@ -15,14 +14,11 @@
     UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
 
-import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
-from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
 from vllm_ascend.utils import npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
@@ -812,101 +808,43 @@ def _forward_prefill(
             -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
                 [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
-        # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
-        ascend_config = get_ascend_config()
+        attn_lse = torch.empty(self.num_heads,
+                               num_tokens,
+                               dtype=torch.float32,
+                               device=query.device)
+        q_pe = query[..., self.qk_nope_head_dim:]
+        q_nope = query[..., :self.qk_nope_head_dim]
+        mask = torch.triu(
+            torch.ones(512, 512, device=query.device, dtype=query.dtype),
+            1)  # 512: mask only support 512
+        if attn_metadata.num_prefills > 1:
+            mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
+                                            1)
+        torch_npu.atb.npu_ring_mla(
+            q_nope=q_nope,
+            q_rope=q_pe,
+            k_nope=k_nope,
+            k_rope=k_pe,
+            value=value,
+            mask=mask,
+            seqlen=torch.tensor(attn_metadata.prefill.query_lens,
+                                dtype=torch.int32),
+            head_num=self.num_heads,
+            kv_head_num=self.num_heads,
+            pre_out=None,
+            prev_lse=None,
+            qk_scale=self.scale,
+            kernel_type="kernel_type_high_precision",
+            mask_type="mask_type_triu",
+            input_layout="type_bsnd",
+            calc_type="calc_type_first_ring",
+            output=attn_output,
+            softmax_lse=attn_lse)
+        attn_output, attn_lse = self._compute_prefill_context( \
+            query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
 
-        if attn_metadata.attn_state in [
-                AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding,
-                AscendAttentionState.PrefillCacheHit
-        ] and not ascend_config.chunked_prefill_for_mla:
-            attn_output_torch = torch.empty(num_tokens,
-                                            self.num_heads * self.v_head_dim,
-                                            dtype=query.dtype,
-                                            device=query.device)
-            # current requests is chunked in prefill, disable flash attention with chunked prefill
-            vanilla_chunked_prefill_mla(
-                output=attn_output_torch,
-                query=query,
-                kv_cache=kv_c_and_k_pe_cache,
-                block_tables=attn_metadata.prefill.block_table,
-                query_lens=attn_metadata.prefill.query_lens,
-                context_lens=attn_metadata.prefill.context_lens,
-                kv_b_proj=self.kv_b_proj,
-                max_query_len=attn_metadata.prefill.max_query_len,
-                max_context_len=attn_metadata.prefill.max_seq_lens,
-                nope_dim=self.qk_nope_head_dim,
-                rope_dim=self.qk_rope_head_dim,
-                v_head_dim=self.v_head_dim,
-                scale=self.scale,
-                alibi_slopes=None,
-                causal=True)
-        elif attn_metadata.attn_state in [
-                AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding,
-                AscendAttentionState.PrefillCacheHit
-        ]:
-            attn_lse = torch.empty(self.num_heads,
-                                   num_tokens,
-                                   dtype=torch.float32,
-                                   device=query.device)
-            q_pe = query[..., self.qk_nope_head_dim:]
-            q_nope = query[..., :self.qk_nope_head_dim]
-            mask = torch.triu(
-                torch.ones(512, 512, device=query.device, dtype=query.dtype),
-                1)  # 512: mask only support 512
-            if attn_metadata.num_prefills > 1:
-                mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
-                                                1)
-            torch_npu.atb.npu_ring_mla(
-                q_nope=q_nope,
-                q_rope=q_pe,
-                k_nope=k_nope,
-                k_rope=k_pe,
-                value=value,
-                mask=mask,
-                seqlen=torch.tensor(attn_metadata.prefill.query_lens,
-                                    dtype=torch.int32),
-                head_num=self.num_heads,
-                kv_head_num=self.num_heads,
-                pre_out=None,
-                prev_lse=None,
-                qk_scale=self.scale,
-                kernel_type="kernel_type_high_precision",
-                mask_type="mask_type_triu",
-                input_layout="type_bsnd",
-                calc_type="calc_type_first_ring",
-                output=attn_output,
-                softmax_lse=attn_lse)
-            attn_output, attn_lse = self._compute_prefill_context( \
-                query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
-
-        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-            key = torch.cat((k_nope, k_pe), dim=-1)
-            torch_npu._npu_flash_attention(
-                query=query,
-                key=key,
-                value=value,
-                mask=attn_metadata.attn_mask,
-                seq_len=attn_metadata.prefill.context_lens,
-                scale_value=self.scale,
-                num_heads=self.num_heads,
-                num_kv_heads=self.num_heads,
-                out=attn_output)
-            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
-        else:
-            raise RuntimeError(
-                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
-            )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
-        if attn_metadata.attn_state in [
-                AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding,
-                AscendAttentionState.PrefillCacheHit
-        ] and not ascend_config.chunked_prefill_for_mla:
-            attn_output = attn_output_torch
-
         return attn_output
 
     def exec_kv(
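Note: the prefill path above now always calls `torch_npu.atb.npu_ring_mla`, so it always builds the 512x512 upper-triangular mask (batched per prefill request when `num_prefills > 1`). A torch-only sketch of just that mask construction, with an illustrative `num_prefills` and the default dtype (no NPU required):

```python
import torch

# Torch-only sketch of the mask setup used by the new prefill path above.
# num_prefills is illustrative; 512 mirrors the "mask only support 512" comment.
num_prefills = 2
mask = torch.triu(torch.ones(512, 512), 1)  # ones strictly above the diagonal
if num_prefills > 1:
    # One mask slice per prefill request: [num_prefills, 512, 512]
    mask = mask.unsqueeze(0).repeat(num_prefills, 1, 1)
print(mask.shape)  # torch.Size([2, 512, 512])
```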
@@ -984,102 +922,69 @@ def _forward_decode(
         q_pe: torch.Tensor,
         k_nope: torch.Tensor,
         k_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
+        block_size: int,
         attn_metadata: AscendMLAMetadata,
-        enable_multistream_mla: bool = False,
     ) -> torch.Tensor:
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
         num_tokens = q_nope.size(0)
-        if self.running_in_graph or self.running_chunkprefilll_with_torchair:
-            # shape of knope/k_pe for npu graph mode should be:
-            # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
-            block_size = kv_c_and_k_pe_cache[0].shape[1]
-            actual_seq_lengths = None
-            if self.enable_kv_nz:
-                k_nope = k_nope.view(-1, self.num_kv_heads,
-                                     self.kv_lora_rank // 16, block_size, 16)
-                k_pe = k_pe.view(-1, self.num_kv_heads,
-                                 self.qk_rope_head_dim // 16, block_size, 16)
-                input_layout = "BSND"
-            else:
-                k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
-                                     self.kv_lora_rank)
-                k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
-                                 self.qk_rope_head_dim)
-                input_layout = "BNSD"
-
-            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
-                assert num_tokens % self.spec_token_num == 0
-                input_layout = "TND"
-                # [bs * q_seq_len, num_heads_per_rank, dim]
-                q_nope = q_nope.view(num_tokens, self.num_heads, -1)
-                q_pe = q_pe.view(num_tokens, self.num_heads, -1)
-                sparse_mode = 3
-                spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
-                actual_seq_lengths = decode_meta.actual_seq_lengths_q
-            else:
-                if self.enable_kv_nz:
-                    q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
-                    q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
-                else:
-                    q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-                    q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
-                sparse_mode = 0
-                spec_attn_mask = None
-
-            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
-                q_nope,
-                k_nope,
-                k_nope,
-                query_rope=q_pe,
-                key_rope=k_pe,
-                num_heads=self.num_heads,
-                num_key_value_heads=self.num_kv_heads,
-                input_layout=input_layout,
-                atten_mask=spec_attn_mask,
-                sparse_mode=sparse_mode,
-                scale=self.scale,
-                antiquant_mode=0,
-                antiquant_scale=None,
-                block_table=decode_meta.block_table,
-                block_size=block_size,
-                actual_seq_lengths_kv=decode_meta.seq_lens_list,
-                actual_seq_lengths=actual_seq_lengths)
+        # shape of knope/k_pe for npu graph mode should be:
+        # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
+        actual_seq_lengths = None
+        if self.enable_kv_nz:
+            k_nope = k_nope.view(-1, self.num_kv_heads,
+                                 self.kv_lora_rank // 16, block_size, 16)
+            k_pe = k_pe.view(-1, self.num_kv_heads,
+                             self.qk_rope_head_dim // 16, block_size, 16)
+            input_layout = "BSND"
         else:
-            # The MLA_PA path will be used as default path in the future, `_npu_paged_attention_mla` will
-            # be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become
-            # public available
-            assert len(kv_c_and_k_pe_cache) > 1
-            if envs_ascend.VLLM_ASCEND_MLA_PA:
-                attn_output = torch_npu.atb.npu_multi_head_latent_attention(
-                    q_nope, q_pe, kv_c_and_k_pe_cache[0],
-                    kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table,
-                    attn_metadata.decode.seq_lens, self.num_heads, self.scale,
-                    self.num_kv_heads)
+            k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
+                                 self.kv_lora_rank)
+            k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
+                             self.qk_rope_head_dim)
+            input_layout = "BNSD"
+
+        if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
+            assert num_tokens % self.spec_token_num == 0
+            input_layout = "TND"
+            # [bs * q_seq_len, num_heads_per_rank, dim]
+            q_nope = q_nope.view(num_tokens, self.num_heads, -1)
+            q_pe = q_pe.view(num_tokens, self.num_heads, -1)
+            sparse_mode = 3
+            spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
+            actual_seq_lengths = decode_meta.actual_seq_lengths_q
+        else:
+            if self.enable_kv_nz:
+                q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
+                q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
             else:
-                q = torch.cat([q_nope, q_pe], dim=-1)
-                attn_output = torch.empty(
-                    [num_tokens, self.num_heads, self.kv_lora_rank],
-                    dtype=q.dtype,
-                    device=q.device)
-                k_cache = torch.cat(
-                    [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
-                torch_npu._npu_paged_attention_mla(
-                    query=q,
-                    key_cache=k_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.decode.
-                    block_table,  # type:ignore
-                    context_lens=attn_metadata.decode.seq_lens,  # type:ignore
-                    mla_vheadsize=self.kv_lora_rank,
-                    out=attn_output)
+                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            sparse_mode = 0
+            spec_attn_mask = None
+
+        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
+            q_nope,
+            k_nope,
+            k_nope,
+            query_rope=q_pe,
+            key_rope=k_pe,
+            num_heads=self.num_heads,
+            num_key_value_heads=self.num_kv_heads,
+            input_layout=input_layout,
+            atten_mask=spec_attn_mask,
+            sparse_mode=sparse_mode,
+            scale=self.scale,
+            antiquant_mode=0,
+            antiquant_scale=None,
+            block_table=decode_meta.block_table,
+            block_size=block_size,
+            actual_seq_lengths_kv=decode_meta.seq_lens_list,
+            actual_seq_lengths=actual_seq_lengths)
+
         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
-            return self._v_up_proj(attn_output,
-                                   enable_multistream_mla)
+            return self._v_up_proj(attn_output)
         else:
             current_ms_metadata.before_comm_event.record()
             with torch.npu.stream(current_ms_metadata.comm_stream):
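Note: decode now always goes through `torch_npu.npu_fused_infer_attention_score`; the remaining branching only picks the tensor layout (`BSND`, `BNSD`, or `TND` for spec decoding) via reshapes. A torch-only sketch of the non-spec-decode query reshape, with made-up sizes (no NPU required):

```python
import torch

# Torch-only sketch of the decode-side query reshape above; sizes are made up.
num_tokens, num_heads, head_dim = 4, 8, 128
enable_kv_nz = False  # mirrors self.enable_kv_nz in the diff

q_nope = torch.randn(num_tokens, num_heads * head_dim)
if enable_kv_nz:
    # BSND layout: [batch, seq(=1), heads, dim]
    q_nope = q_nope.view(num_tokens, 1, num_heads, -1)
else:
    # BNSD layout: [batch, heads, seq(=1), dim]
    q_nope = q_nope.view(num_tokens, num_heads, 1, -1)
print(q_nope.shape)  # torch.Size([4, 8, 1, 128])
```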
