
Commit 6d9e5f6

[0.9.1]remove chunked_prefill_for_mla (#2177)
### What this PR does / why we need it?
remove chunked_prefill_for_mla

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
Processed prompts: 100%|██████████| 4/4 [00:02<00:00, 1.92it/s, est. speed input: 12.46 toks/s, output: 38.34 toks/s]
DP rank 2, Generated text: ' [Your Name] and I am a professional carpenter with over 10 years of experience in the industry'
DP rank 2, Generated text: ' the head of state and head of government of the United States, indirectly elected to a four-year term'
DP rank 2, Generated text: ' Paris, a city that is renowned for its rich history, culture, and influence on art, fashion'
DP rank 2, Generated text: ' a topic of much speculation and debate. Some experts believe that AI will eventually surpass human intelligence, while'
Processed prompts: 100%|██████████| 4/4 [00:02<00:00, 1.95it/s, est. speed input: 12.65 toks/s, output: 38.93 toks/s]
DP rank 0, Generated text: " Dr. David Hill and today we're going to be talking about how to treat a child with a"
DP rank 0, Generated text: ' the head of state and head of government of the United States, indirectly elected to a four-year term'
DP rank 0, Generated text: ' Paris, a city that is renowned for its rich history, culture, and influence on art, fashion'
DP rank 0, Generated text: ' here, and it’s called ChatGPT. This revolutionary technology is changing the way we interact with machines'
Processed prompts: 100%|██████████| 4/4 [00:02<00:00, 1.97it/s, est. speed input: 12.79 toks/s, output: 39.36 toks/s]
DP rank 1, Generated text: " Dr. David Hill and today we're going to be talking about how to treat a child's fever"
DP rank 3, Generated text: ' [Your Name] and I’m here to talk to you about the importance of a healthy diet'
DP rank 1, Generated text: ' the head of state and head of government of the United States, indirectly elected to a four-year term'
DP rank 1, Generated text: ' Paris, a city that is renowned for its rich history, culture, and influence on art, fashion'
DP rank 1, Generated text: ' a topic of much speculation and debate. Some experts believe that AI will eventually surpass human intelligence, leading'
DP rank 3, Generated text: ' the head of state and head of government of the United States, indirectly elected to a four-year term'
DP rank 3, Generated text: " Paris. It is the largest city in France and serves as the country's political, cultural, and"
DP rank 3, Generated text: ' here, and it’s called ChatGPT. This revolutionary technology is changing the way we interact with machines'

---------

Signed-off-by: fems14 <[email protected]>
1 parent 6a2f792 commit 6d9e5f6
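
The smoke-test output in the commit message drives the same four prompts that appear in `tests/multicard/test_torchair_graph_mode.py` below. A minimal single-process sketch of that kind of check is shown here, assuming vLLM's offline `LLM` API; the model path, sampling settings, and the omission of the data-parallel launcher are assumptions, not details taken from this PR.

```python
# Hedged reproduction sketch: single process, no data parallelism.
# Model path and sampling parameters are placeholders.
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

sampling_params = SamplingParams(temperature=0.0, max_tokens=20)
llm = LLM(model="/models/deepseek_r1_w8a8")  # assumed local DeepSeek weights

for output in llm.generate(prompts, sampling_params):
    print(f"Generated text: {output.outputs[0].text!r}")
```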

File tree

8 files changed: +40, -244 lines


docs/source/user_guide/configuration/additional_config.md

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@ The following table lists the additional configuration options available in vLLM
 | `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler |
 | `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. |
 | `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
-| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
 | `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
 
 The details of each config option are as follows:
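
For orientation, the remaining options in this table are still supplied through `additional_config`; the sketch below shows one way to pass them with the offline API (assuming the `additional_config` keyword described in these docs). The option values are illustrative placeholders, and `chunked_prefill_for_mla` is no longer read after this change.

```python
# Illustrative only: surviving additional_config options after this PR.
from vllm import LLM

llm = LLM(
    model="/models/deepseek_r1_w8a8",  # placeholder model path
    additional_config={
        "ascend_scheduler_config": {},  # see table above
        "kv_cache_dtype": "int8",       # only int8 is currently supported
        # "chunked_prefill_for_mla": True  <- removed; no longer read by AscendConfig
    },
)
```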

examples/disaggregate_prefill_v1/README.md

Lines changed: 0 additions & 4 deletions
@@ -71,8 +71,6 @@ vllm serve /models/deepseek_r1_w8a8 \
     "engine_id": "0",
     "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
   }' \
-  --additional-config \
-  '{"chunked_prefill_for_mla":true}'
 ```
 
 Run prefill server P2 on second node:
@@ -115,8 +113,6 @@ vllm serve /models/deepseek_r1_w8a8 \
     "engine_id": "0",
     "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
   }' \
-  --additional-config \
-  '{"chunked_prefill_for_mla":true}'
 ```
 
 Run decode server d1 on third node:

tests/multicard/test_torchair_graph_mode.py

Lines changed: 3 additions & 3 deletions
@@ -71,9 +71,9 @@ def test_e2e_deepseekv3_with_torchair(monkeypatch: pytest.MonkeyPatch,
     # inaccurate. This will only change if accuracy improves with the
     # official weights of DeepSeek-V3.
     golden_results = [
-        'Hello, my name is下载早点向前很有่อง',
-        'The president of the United States isSender)## physiological Albany',
-        'The capital of France is Rocky转角 hospitalizedinterval sparked',
+        'Hello, my name is bioavailability裹格外 struct',
+        'The president of the United States isStr Fiona tratamientoPant narciss',
+        'The capital of France is Rocky转角){\\-Hill偷袭',
         'The future of AI is её asegο BIOS一扫',
     ]

vllm_ascend/ascend_config.py

Lines changed: 0 additions & 3 deletions
@@ -54,9 +54,6 @@ def __init__(self, vllm_config):
         self.num_wait_worker_iterations = additional_config.get(
             "num_wait_worker_iterations", 30
         )  # Number of iterations to wait before applying a redistribution plan
-        self.chunked_prefill_for_mla = additional_config.get(
-            "chunked_prefill_for_mla",
-            False)  # Whether to enable the fused operator-like chunked_prefill
         self.enable_weight_nz_layout = additional_config.get(
             "enable_weight_nz_layout", False
         )  # Whether to convert quantized weights to NZ format to accelerate matrix multiplication
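
The surrounding options keep the same `additional_config.get(key, default)` lookup; a minimal standalone illustration of that pattern (the dict and values here are placeholders, not the full `AscendConfig`):

```python
# Placeholder dict standing in for the user-supplied additional_config.
additional_config = {"num_wait_worker_iterations": 10}

num_wait_worker_iterations = additional_config.get("num_wait_worker_iterations", 30)
enable_weight_nz_layout = additional_config.get("enable_weight_nz_layout", False)
# "chunked_prefill_for_mla" is simply no longer looked up after this PR.

print(num_wait_worker_iterations, enable_weight_nz_layout)  # -> 10 False
```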

vllm_ascend/attention/mla_v1.py

Lines changed: 33 additions & 95 deletions
@@ -21,7 +21,6 @@
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
-from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
 from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor
 
 if TYPE_CHECKING:
@@ -211,6 +210,9 @@ def __init__(self,
         self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
         self.cos_cache = None
         self.sin_cache = None
+        self.prefill_attn_mask = torch.triu(
+            torch.ones(512, 512, device=runner.device, dtype=runner.dtype),
+            1)  # 512: mask only support 512
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -479,7 +481,7 @@ def build(
                 prefill_input_positions].unsqueeze(  # type: ignore
                     1).unsqueeze(2)
             prefill_metadata = AscendMLAPrefillMetadata(
-                attn_mask=self.runner.attn_mask,
+                attn_mask=self.prefill_attn_mask,
                 query_lens=query_lens[tokens_start:],
                 seq_lens=seq_lens,
                 context_lens=seq_lens[tokens_start:],
@@ -767,16 +769,13 @@ def _compute_prefill_context(
             k_nope, v = kv_nope\
                 .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
-            mask = torch.triu(
-                torch.ones(512, 512, device=query.device, dtype=query.dtype),
-                1)
             torch_npu.atb.npu_ring_mla(
                 q_nope=q_nope,
                 q_rope=q_pe,
                 k_nope=k_nope,
                 k_rope=k_pe,
                 value=v,
-                mask=mask,
+                mask=prefill_metadata.attn_mask,
                 seqlen=seq_len,
                 head_num=self.num_heads,
                 kv_head_num=self.num_heads,
@@ -808,101 +807,40 @@ def _forward_prefill(
                                   self.v_head_dim,
                                   dtype=query.dtype,
                                   device=query.device)
+        attn_lse = torch.empty(self.num_heads,
+                               num_tokens,
+                               dtype=torch.float32,
+                               device=query.device)
         k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
             -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
                 [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
-        # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
-        ascend_config = get_ascend_config()
-
-        if attn_metadata.attn_state in [
-                AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding,
-                AscendAttentionState.PrefillCacheHit
-        ] and not ascend_config.chunked_prefill_for_mla:
-            attn_output_torch = torch.empty(num_tokens,
-                                            self.num_heads * self.v_head_dim,
-                                            dtype=query.dtype,
-                                            device=query.device)
-            # current requests is chunked in prefill, disable flash attention with chunked prefill
-            vanilla_chunked_prefill_mla(
-                output=attn_output_torch,
-                query=query,
-                kv_cache=kv_c_and_k_pe_cache,
-                block_tables=attn_metadata.prefill.block_table,
-                query_lens=attn_metadata.prefill.query_lens,
-                context_lens=attn_metadata.prefill.context_lens,
-                kv_b_proj=self.kv_b_proj,
-                max_query_len=attn_metadata.prefill.max_query_len,
-                max_context_len=attn_metadata.prefill.max_seq_lens,
-                nope_dim=self.qk_nope_head_dim,
-                rope_dim=self.qk_rope_head_dim,
-                v_head_dim=self.v_head_dim,
-                scale=self.scale,
-                alibi_slopes=None,
-                causal=True)
-        elif attn_metadata.attn_state in [
-                AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding,
-                AscendAttentionState.PrefillCacheHit
-        ]:
-            attn_lse = torch.empty(self.num_heads,
-                                   num_tokens,
-                                   dtype=torch.float32,
-                                   device=query.device)
-            q_pe = query[..., self.qk_nope_head_dim:]
-            q_nope = query[..., :self.qk_nope_head_dim]
-            mask = torch.triu(
-                torch.ones(512, 512, device=query.device, dtype=query.dtype),
-                1)  # 512: mask only support 512
-            if attn_metadata.num_prefills > 1:
-                mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
-                                                1)
-            torch_npu.atb.npu_ring_mla(
-                q_nope=q_nope,
-                q_rope=q_pe,
-                k_nope=k_nope,
-                k_rope=k_pe,
-                value=value,
-                mask=mask,
-                seqlen=torch.tensor(attn_metadata.prefill.query_lens,
-                                    dtype=torch.int32),
-                head_num=self.num_heads,
-                kv_head_num=self.num_heads,
-                pre_out=None,
-                prev_lse=None,
-                qk_scale=self.scale,
-                kernel_type="kernel_type_high_precision",
-                mask_type="mask_type_triu",
-                input_layout="type_bsnd",
-                calc_type="calc_type_first_ring",
-                output=attn_output,
-                softmax_lse=attn_lse)
-            attn_output, attn_lse = self._compute_prefill_context( \
-                query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
-
-        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-            key = torch.cat((k_nope, k_pe), dim=-1)
-            torch_npu._npu_flash_attention(
-                query=query,
-                key=key,
-                value=value,
-                mask=attn_metadata.attn_mask,
-                seq_len=attn_metadata.prefill.context_lens,
-                scale_value=self.scale,
-                num_heads=self.num_heads,
-                num_kv_heads=self.num_heads,
-                out=attn_output)
-            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
+        q_pe = query[..., self.qk_nope_head_dim:]
+        q_nope = query[..., :self.qk_nope_head_dim]
+        torch_npu.atb.npu_ring_mla(q_nope=q_nope,
+                                   q_rope=q_pe,
+                                   k_nope=k_nope,
+                                   k_rope=k_pe,
+                                   value=value,
+                                   mask=attn_metadata.prefill.attn_mask,
+                                   seqlen=torch.tensor(
+                                       attn_metadata.prefill.query_lens,
+                                       dtype=torch.int32),
+                                   head_num=self.num_heads,
+                                   kv_head_num=self.num_heads,
+                                   pre_out=None,
+                                   prev_lse=None,
+                                   qk_scale=self.scale,
+                                   kernel_type="kernel_type_high_precision",
+                                   mask_type="mask_type_triu",
+                                   input_layout="type_bsnd",
+                                   calc_type="calc_type_first_ring",
+                                   output=attn_output,
+                                   softmax_lse=attn_lse)
+        attn_output, attn_lse = self._compute_prefill_context( \
+            query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
-        if attn_metadata.attn_state in [
-                AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding,
-                AscendAttentionState.PrefillCacheHit
-        ] and not ascend_config.chunked_prefill_for_mla:
-            attn_output = attn_output_torch
-
         return attn_output
 
     def exec_kv(
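
The behavioural core of this file's change is that the 512x512 upper-triangular prefill mask is now built once in the metadata builder's `__init__` (as `self.prefill_attn_mask`) and reused by every `npu_ring_mla` call, instead of being recreated inside `_forward_prefill` and `_compute_prefill_context` on each step. A standalone sketch of that mask construction, with a placeholder dtype since the real code uses the runner's device and dtype:

```python
import torch

# Built once and cached, mirroring the new __init__ code; 512 x 512 because,
# per the in-code comment, the mask only supports 512. dtype/device are
# placeholders here.
prefill_attn_mask = torch.triu(
    torch.ones(512, 512, dtype=torch.float16), diagonal=1)

# Entry [i, j] is 1 for j > i (future positions get masked) and 0 elsewhere,
# so every prefill call can reuse the same tensor instead of reallocating it.
assert prefill_attn_mask[0, 1] == 1  # future position is masked
assert prefill_attn_mask[1, 0] == 0  # past position is visible
```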

vllm_ascend/multistream/ms_split.py

Lines changed: 2 additions & 2 deletions
@@ -167,7 +167,7 @@ def model_input_split_v1_mla_attn(
             attn_metadata.prefill.sin,
             token_index - attn_metadata.num_decode_tokens)
         prefill_pre = AscendMLAPrefillMetadata(
-            attn_mask=attn_mask_pre,
+            attn_mask=attn_metadata.prefill.attn_mask,
             query_lens=prefill_query_lens_pre,
             seq_lens=seq_lens_pre,
             query_start_loc=prefill_query_start_loc_pre,
@@ -179,7 +179,7 @@
             cos=cos_pre,
             sin=sin_pre)
         prefill_post = AscendMLAPrefillMetadata(
-            attn_mask=attn_mask_post,
+            attn_mask=attn_metadata.prefill.attn_mask,
             query_lens=prefill_query_lens_post,
             seq_lens=seq_lens_post,
             query_start_loc=prefill_query_start_loc_post,

vllm_ascend/ops/attention.py

Lines changed: 1 addition & 134 deletions
@@ -15,10 +15,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Tuple
+from typing import List, Optional
 
 import torch
-from vllm.model_executor.layers.linear import ColumnParallelLinear
 
 
 # Implementation of vanilla chunked prefill, should be removed after the kernel is ready for
@@ -135,138 +134,6 @@ def vanilla_chunked_prefill(
     return attn_output
 
 
-def vanilla_chunked_prefill_mla(
-        output: torch.Tensor,  # (num_tokens, num_heads, v_head_dim)
-        query: torch.Tensor,  # (num_tokens, num_heads, nope_dim + rope_dim)
-        kv_cache: Tuple[
-            torch.Tensor],  # [nope, rope] (num_blocks, block_size, latent_kv)
-        block_tables: torch.Tensor,  # (batch_size, max_num_blocks_per_seq)
-        query_lens: torch.Tensor,  # (batch_size)
-        context_lens: torch.Tensor,  # (batch_size)
-        kv_b_proj: ColumnParallelLinear,  # ()
-        max_query_len: int,
-        max_context_len: int,
-        nope_dim: int,
-        rope_dim: int,
-        v_head_dim: int,
-        scale: float,
-        alibi_slopes: Optional[torch.Tensor],
-        causal: bool = True) -> None:
-    batch_size = block_tables.size(0)
-    assert len(kv_cache) > 1
-    assert query_lens.size(0) == batch_size
-    num_heads = query.size(1)
-    nope_cache = kv_cache[0]
-    rope_cache = kv_cache[1]
-    block_size = nope_cache.size(1)
-    latent_kv_dim = nope_cache.size(-1)
-    max_num_blocks_per_seq = block_tables.size(1)
-    batch_size = query_lens.size(0)
-    nope_cache = nope_cache.squeeze()
-    # select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim] and get kv_c and k_pe
-    # cached_kv_c: [batch_size, max_context_len, latent_kv]
-    # cached_k_pe: [batch_size, max_context_len, rope_dim]
-    cache_kv_c = nope_cache[block_tables].view(
-        batch_size, max_num_blocks_per_seq * block_size,
-        latent_kv_dim)[:, :max_context_len, :]
-    cache_k_pe = rope_cache[block_tables].view(
-        batch_size, max_num_blocks_per_seq * block_size,
-        rope_dim)[:, :max_context_len, :]
-    # get k_rope and v
-    # k_nope: [batch_size, max_context_len, num_heads, nope_dim]
-    # value: [batch_size, max_context_len, num_heads, v_head_dim]
-    k_nope, value = kv_b_proj(cache_kv_c)[0].view(
-        batch_size, max_context_len, num_heads,
-        nope_dim + v_head_dim).split([nope_dim, v_head_dim], dim=-1)
-    # key: [batch_size, max_context_len, num_hads, rope_dim + nope_dim]
-    key = torch.cat(
-        [k_nope, cache_k_pe.unsqueeze(2).expand(-1, -1, num_heads, -1)],
-        dim=-1)
-
-    context_lens = context_lens.view(-1, 1).to("npu")
-    query_lens = query_lens.view(-1, 1).to("npu")
-    seq_diff = context_lens - query_lens
-
-    q_idx_mask = (torch.arange(0, max_query_len,
-                               device="npu").view(1, -1).repeat(batch_size, 1))
-    kv_c_idx_mask = (torch.arange(0, max_context_len,
-                                  device="npu").view(1,
-                                                     -1).repeat(batch_size, 1))
-    kv_c_mask = kv_c_idx_mask < context_lens
-    q_mask = q_idx_mask < query_lens
-
-    # calculate idx for causal mask of query [batch, max_seqlen_q]
-    causal_mask_idx = (q_idx_mask + seq_diff)[q_mask]
-
-    # generate causal mask [batch, max_seqlen_q, max_seqlen_k]
-    tril_mask = torch.tril(
-        torch.ones(max_context_len, max_context_len, device="npu"))
-    tril_mask[tril_mask == 0] = float("-inf")
-    tril_mask[tril_mask == 1] = 0
-    causal_mask = tril_mask[causal_mask_idx]
-    causal_mask_padding = torch.empty(
-        [batch_size, max_query_len, max_context_len],
-        device="npu").fill_(float("-inf"))
-    causal_mask_padding[q_mask] = causal_mask
-    # to [batch, num_heads, max_seqlen_q, max_seqlen_k]
-    causal_mask_padding = causal_mask_padding.unsqueeze(1)
-
-    pad_q = torch.zeros(
-        [batch_size, max_query_len, num_heads, rope_dim + nope_dim],
-        device="npu",
-        dtype=query.dtype,
-    )
-    pad_k = torch.zeros(
-        [batch_size, max_context_len, num_heads, rope_dim + nope_dim],
-        device="npu",
-        dtype=key.dtype,
-    )
-    pad_v = torch.zeros(
-        [batch_size, max_context_len, num_heads, v_head_dim],
-        device="npu",
-        dtype=value.dtype,
-    )
-    num_query = torch.sum(q_mask).item()
-    num_add_query = num_query - query.size(0)
-    # mtp will come in
-    if num_add_query > 0:
-        add_query_size = query.size()
-        add_query_size = list(add_query_size)
-        add_query_size[0] = num_add_query
-        pad_tensor = torch.zeros(add_query_size,
-                                 dtype=query.dtype,
-                                 device=query.device)
-        query = torch.cat([query, pad_tensor], dim=0)
-    pad_q[q_mask] = query
-    pad_k[kv_c_mask] = key[kv_c_mask]
-    pad_v[kv_c_mask] = value[kv_c_mask]
-
-    pad_q = pad_q.permute(0, 2, 1, 3)
-    pad_k = pad_k.permute(0, 2, 1, 3)
-    pad_v = pad_v.permute(0, 2, 1, 3)
-    attn_mask = torch.empty([batch_size, 1, 1, max_context_len],
-                            device="npu").fill_(float("-inf"))
-    attn_mask[:, :, :, :max_context_len].masked_fill_(
-        kv_c_mask[:, None, None, :], 0)
-    # [b, h, f, t]
-    attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k)
-    attn_weights *= scale
-    attn_mask = attn_mask.float()
-    attn_weights = attn_weights + attn_mask
-    if causal:
-        attn_weights = attn_weights + causal_mask_padding
-
-    attn_weights = torch.softmax(attn_weights, dim=-1)
-    attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, pad_v.float())
-    attn_output = attn_output.permute(0, 2, 1, 3)
-
-    attn_output = (attn_output[q_mask].view([-1, num_heads,
-                                             v_head_dim]).to(output.dtype))
-    attn_output = attn_output.view_as(output)
-    output.copy_(attn_output)
-    return attn_output
-
-
 def vanilla_decode_mla(
     query: torch.Tensor,  # [num_tokens, num_heads, latent_dim + rope_dim]
     key_cache: torch.
