From f5862acc4ff21f17aaae93a6bf7494c8402f1786 Mon Sep 17 00:00:00 2001 From: LookAround Date: Wed, 24 Sep 2025 22:15:38 +0800 Subject: [PATCH 01/12] [mla backend] support dcp&cp prefill Signed-off-by: LookAround --- examples/offline_inference_npu_long_seq.py | 59 +++ vllm_ascend/ascend_config.py | 4 +- vllm_ascend/attention/mla_v1.py | 430 ++++++++++++++++++++- vllm_ascend/attention/utils.py | 34 +- vllm_ascend/core/schedule_config.py | 1 + vllm_ascend/distributed/parallel_state.py | 1 + vllm_ascend/platform.py | 32 -- vllm_ascend/worker/worker_v1.py | 5 +- 8 files changed, 515 insertions(+), 51 deletions(-) create mode 100644 examples/offline_inference_npu_long_seq.py diff --git a/examples/offline_inference_npu_long_seq.py b/examples/offline_inference_npu_long_seq.py new file mode 100644 index 0000000000..111928402e --- /dev/null +++ b/examples/offline_inference_npu_long_seq.py @@ -0,0 +1,59 @@ +import os +import time +import argparse + +from vllm import LLM, SamplingParams + +os.environ["VLLM_USE_MODELSCOPE"] = "True" +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument('--input_len', type=int, default=1024) + parser.add_argument('--output_len', type=int, default=128) + parser.add_argument('--bs', type=int, default=1) + parser.add_argument('--model_path', type=str, default="deepseek-ai/DeepSeek-V2-Lite") + parser.add_argument('--tp', type=int, default=2) + parser.add_argument('--cp', type=int, default=2) + parser.add_argument('--dcp', type=int, default=2) + parser.add_argument('--iter_times', type=int, default=1) + + args = parser.parse_args() + + prompts = [ + "The capital of France is", + "Hello, my name is Tom, I am", + "The president of United States is", + "AI future is? What do you think about it? Can you give me some information or any thing you want?" 
+ ] + + sampling_params = SamplingParams(temperature = 0.8, top_p = 0.95, max_tokens=args.output_len) + llm = LLM( + model=args.model_path, + trust_remote_code=True, + enforce_eager=True, + tensor_parallel_size=args.tp, + context_parallel_size=args.cp, + decode_context_parallel_size=args.dcp, + enable_prefix_caching=False, + enable_expert_parallel=True, + enable_chunked_prefill=False, + max_num_batched_tokens=args.input_len + 138, + max_model_len=args.input_len + args.output_len + 138, + additional_config={"ascend_scheduler_config": {"enabled": True}}, + max_num_seqs=1, + block_size=128, + gpu_memory_utilization=0.9 + ) + + t0 = time.time() + for _ in range(args.iter_times): + outputs = llm.generate(prompts, sampling_params) + t1 = time.time() + print(f"TTFT: {(t1 - t0) * 1000 / (args.iter_times * args.bs)} ms") + + for i, output in enumerate(outputs): + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"req_num: {i}\nGenerated text: {generated_text!r}") \ No newline at end of file diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 301e64242d..e84bede7d3 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -16,6 +16,7 @@ from typing import Optional from vllm.logger import logger +from vllm.distributed import get_dcp_group TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2", "qwen"] @@ -60,7 +61,8 @@ def __init__(self, vllm_config): "chunked_prefill_for_mla", False) self.enable_shared_expert_dp = additional_config.get( "enable_shared_expert_dp", False - ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel + ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel \ + and not get_dcp_group().world_size > 1 self.multistream_overlap_shared_expert = additional_config.get( "multistream_overlap_shared_expert", False) self.enable_prefetch = additional_config.get("enable_prefetch", False) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index cb15bd188f..6714fb5b45 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -2,14 +2,25 @@ from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type, TypeVar) +import numpy as np import torch import torch_npu from torch import nn +import torch.distributed as dist +import torch.nn.functional as F from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, MLAAttentionImpl) from vllm.config import VllmConfig, get_current_vllm_config -from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, + get_decode_context_model_parallel_rank, + get_decode_context_model_parallel_world_size, + get_dcp_group, + get_context_model_parallel_rank, + get_context_model_parallel_world_size, + get_cp_group) from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.utils import cdiv, round_down @@ -84,6 +95,18 @@ class ChunkedContextMetadata: chunked_context: Optional[ChunkedContextMetadata] = None sin: torch.Tensor = None cos: torch.Tensor = None + cp_kv_recover_idx: Optional[list[int]] = None + q_head_idx: torch.Tensor = None + q_tail_idx: torch.Tensor = None + kv_with_q_head_nomask_idx: torch.Tensor = None + kv_with_q_head_mask_idx: torch.Tensor = None + kv_with_q_tail_nomask_idx: torch.Tensor = None + kv_with_q_tail_mask_idx: 
torch.Tensor = None + attn_mask_seqlens: torch.Tensor = None + head_attn_nomask_seqlens: torch.Tensor = None + tail_attn_nomask_seqlens: torch.Tensor = None + q_full_idx: torch.Tensor = None + cp_prefill_mask: torch.Tensor = None @dataclass @@ -288,6 +311,23 @@ def build( num_actual_tokens = common_attn_metadata.num_actual_tokens query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + long_seq_metadata = common_attn_metadata.common_long_seq_metadata + + cp_kv_recover_idx = long_seq_metadata.cp_kv_recover_idx if long_seq_metadata else None + num_actual_tokens_cp_full = long_seq_metadata.num_actual_tokens_cp_full if long_seq_metadata else None + num_computed_tokens_of_cp_sp = long_seq_metadata.num_computed_tokens_of_cp_sp if long_seq_metadata else None + q_head_idx_tensor = long_seq_metadata.q_head_idx_tensor if long_seq_metadata else None + q_tail_idx_tensor = long_seq_metadata.q_tail_idx_tensor if long_seq_metadata else None + kv_with_q_head_nomask_idx_tensor = long_seq_metadata.kv_with_q_head_nomask_idx_tensor if long_seq_metadata else None + kv_with_q_head_mask_idx_tensor = long_seq_metadata.kv_with_q_head_mask_idx_tensor if long_seq_metadata else None + kv_with_q_tail_nomask_idx_tensor = long_seq_metadata.kv_with_q_tail_nomask_idx_tensor if long_seq_metadata else None + kv_with_q_tail_mask_idx_tensor = long_seq_metadata.kv_with_q_tail_mask_idx_tensor if long_seq_metadata else None + attn_mask_seqlens = long_seq_metadata.attn_mask_seqlens if long_seq_metadata else None + head_attn_nomask_seqlens = long_seq_metadata.head_attn_nomask_seqlens if long_seq_metadata else None + tail_attn_nomask_seqlens = long_seq_metadata.tail_attn_nomask_seqlens if long_seq_metadata else None + q_full_idx = long_seq_metadata.q_full_idx if long_seq_metadata else None + cp_prefill_mask = long_seq_metadata.cp_prefill_mask if long_seq_metadata else None + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) assert num_decodes + num_prefills == num_reqs @@ -299,6 +339,9 @@ def build( device = self.device block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) + if num_actual_tokens_cp_full is not None: + num_actual_tokens = num_actual_tokens_cp_full + slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] input_positions = common_attn_metadata.positions[: num_actual_tokens].long( @@ -383,6 +426,18 @@ def build( chunked_context=chunked_context_metadata, sin=sin, cos=cos, + cp_kv_recover_idx=cp_kv_recover_idx, + q_head_idx=q_head_idx_tensor, + q_tail_idx=q_tail_idx_tensor, + kv_with_q_head_nomask_idx=kv_with_q_head_nomask_idx_tensor, + kv_with_q_head_mask_idx=kv_with_q_head_mask_idx_tensor, + kv_with_q_tail_nomask_idx=kv_with_q_tail_nomask_idx_tensor, + kv_with_q_tail_mask_idx=kv_with_q_tail_mask_idx_tensor, + attn_mask_seqlens=attn_mask_seqlens, + head_attn_nomask_seqlens=head_attn_nomask_seqlens, + tail_attn_nomask_seqlens=tail_attn_nomask_seqlens, + q_full_idx=q_full_idx, + cp_prefill_mask=cp_prefill_mask ) decode_metadata = None @@ -435,6 +490,7 @@ class DecodeMLAPreprocessResult(NamedTuple): q_pe: Optional[torch.Tensor] = None k_nope: Optional[torch.Tensor] = None k_pe: Optional[torch.Tensor] = None + decode_q_wo_k_up: Optional[torch.Tensor] = None class PrefillMLAPreprocessResult(NamedTuple): @@ -500,6 +556,23 @@ def __init__( self.speculative_config = vllm_config.speculative_config + self.cp_size = 
get_context_model_parallel_world_size() + self.cp_rank = get_context_model_parallel_rank( + ) if self.cp_size > 1 else 0 + self.cp_group = get_cp_group( + ).device_group if self.cp_size > 1 else None + + self.dcp_size = get_decode_context_model_parallel_world_size() + self.dcp_rank = get_decode_context_model_parallel_rank( + ) if self.dcp_size > 1 else 0 + self.dcp_group = get_dcp_group( + ).device_group if self.dcp_size > 1 else None + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_group = get_tp_group( + ).device_group if self.tp_size > 1 else None + def _v_up_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) @@ -899,6 +972,13 @@ def _mla_preprocess(self, layer_name, hidden_states, kv_cache, sin = attn_metadata.decode.sin decode_ql_nope, decode_q_pe = \ self._q_proj_and_k_up_proj(decode_q_c) + if self.dcp_size > 1: + decode_q_no_split = torch.cat([decode_ql_nope, decode_q_pe], + dim=-1) + decode_q_no_split = get_dcp_group().all_gather( + decode_q_no_split, 1) + decode_ql_nope, decode_q_pe = decode_q_no_split.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) decode_q_pe = self.rope_single(decode_q_pe, cos, sin) decode_slots = attn_metadata.slot_mapping[:num_decode_tokens] decode_kv_no_split = kv_no_split[:num_decode_tokens] @@ -920,15 +1000,48 @@ def _mla_preprocess(self, layer_name, hidden_states, kv_cache, prefill_slots = attn_metadata.slot_mapping[ num_decode_tokens:num_actual_tokens] prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) - prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill( - prefill_kv_no_split, cos, sin, kv_cache, prefill_slots) - prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0], - self.num_kv_heads, -1) + if self.cp_size > 1: + kv_c, k_pe = prefill_kv_no_split.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + assert len( + kv_cache + ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)" + kv_c_normed = kv_c_normed.view( + [num_actual_tokens, self.num_kv_heads, -1]) + k_pe = k_pe.unsqueeze(1) + prefill_k_pe = k_pe[num_decode_tokens:] + prefill_k_pe = self.rope_single(prefill_k_pe, cos, sin) + prefill_k_c_normed = kv_c_normed[num_decode_tokens:] + + prefill_kv_c_k_pe = torch.cat( + [prefill_k_c_normed, prefill_k_pe], dim=-1) + prefill_kv_c_k_pe = get_cp_group().all_gather( + prefill_kv_c_k_pe, 0) + prefill_kv_c_k_pe = torch.index_select( + prefill_kv_c_k_pe, 0, + attn_metadata.prefill.cp_kv_recover_idx) + prefill_k_c_normed, prefill_k_pe = prefill_kv_c_k_pe.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed, k_pe = prefill_k_c_normed, prefill_k_pe + prefill_k_c_normed = prefill_k_c_normed.squeeze() + torch_npu._npu_reshape_and_cache( + key=kv_c_normed, + value=k_pe, + key_cache=kv_cache[0], + value_cache=kv_cache[1], + slot_indices=attn_metadata.slot_mapping) + else: + prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill( + prefill_kv_no_split, cos, sin, kv_cache, prefill_slots) prefill_k_nope, prefill_value = self.kv_b_proj( prefill_k_c_normed)[0].view( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split( [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + if not self.cp_size > 1: + prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0], + self.num_kv_heads, -1) prefill_k_pe = prefill_k_pe.expand( (*prefill_k_nope.shape[:-1], -1)) prefill_preprocess_res = 
PrefillMLAPreprocessResult( @@ -970,12 +1083,20 @@ def forward( if decode_preprocess_res is not None: # MLA Preprocess for decoding - output_decode = self._forward_decode(decode_preprocess_res.ql_nope, - decode_preprocess_res.q_pe, - decode_preprocess_res.k_nope, - decode_preprocess_res.k_pe, - kv_cache[0].shape[1], - attn_metadata) + if self.cp_size * self.dcp_size > 1: + output_decode = self._forward_decode_sp( + decode_preprocess_res.ql_nope, + decode_preprocess_res.q_pe, + decode_preprocess_res.k_nope, + decode_preprocess_res.k_pe, + kv_cache[0].shape[1], + attn_metadata, + ) + else: + output_decode = self._forward_decode( + decode_preprocess_res.ql_nope, decode_preprocess_res.q_pe, + decode_preprocess_res.k_nope, decode_preprocess_res.k_pe, + kv_cache[0].shape[1], attn_metadata) current_ms_metadata = get_multistream_comm_context() if current_ms_metadata is not None: with torch.npu.stream(current_ms_metadata.comm_stream): @@ -988,10 +1109,16 @@ def forward( # FIX: aicore move should be also placed on the comm stream in dbo, # otherwise it may affect the accuracy # TODO: use an elegant way to overlap - output_prefill = self._forward_prefill( - prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe, - prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe, - prefill_preprocess_res.value, kv_cache, attn_metadata) + if self.cp_size > 1: + output_prefill = self._forward_prefill_cp( + prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe, + prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe, + prefill_preprocess_res.value, kv_cache, attn_metadata) + else: + output_prefill = self._forward_prefill( + prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe, + prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe, + prefill_preprocess_res.value, kv_cache, attn_metadata) current_ms_metadata = get_multistream_comm_context() if current_ms_metadata is not None: with torch.npu.stream(current_ms_metadata.comm_stream): @@ -1029,3 +1156,276 @@ def forward( if has_prefill: maybe_save_kv_layer_to_connector(layer_name, list(kv_cache)) return output_padded + + def _forward_prefill_cp( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + k_nope: torch.Tensor, + k_pe: torch.Tensor, + value: torch.Tensor, + kv_c_and_k_pe_cache: Tuple[torch.Tensor], + attn_metadata: AscendMLAMetadata, + ) -> torch.Tensor: + assert attn_metadata.prefill is not None + num_tokens = q_nope.size(0) + # Use precomputed indices from the metadata (already converted to tensors and on device) + q_head_idx = attn_metadata.prefill.q_head_idx + q_tail_idx = attn_metadata.prefill.q_tail_idx + kv_with_q_head_nomask_idx = attn_metadata.prefill.kv_with_q_head_nomask_idx + kv_with_q_head_mask_idx = attn_metadata.prefill.kv_with_q_head_mask_idx + kv_with_q_tail_nomask_idx = attn_metadata.prefill.kv_with_q_tail_nomask_idx + kv_with_q_tail_mask_idx = attn_metadata.prefill.kv_with_q_tail_mask_idx + attn_mask_seqlens = attn_metadata.prefill.attn_mask_seqlens + head_attn_nomask_seqlens = attn_metadata.prefill.head_attn_nomask_seqlens + tail_attn_nomask_seqlens = attn_metadata.prefill.tail_attn_nomask_seqlens + mask = attn_metadata.prefill.cp_prefill_mask + + output_head = self._attention_with_mask_and_nomask( + q_nope=torch.index_select(q_nope, 0, q_head_idx), + q_pe=torch.index_select(q_pe, 0, q_head_idx), + k_nope=k_nope, + k_pe=k_pe, + value=value, + kv_mask_idx=kv_with_q_head_mask_idx, + kv_nomask_idx=kv_with_q_head_nomask_idx, + attn_mask_seqlens=attn_mask_seqlens, + attn_nomask_seqlens=head_attn_nomask_seqlens, + 
mask=mask) + + output_tail = self._attention_with_mask_and_nomask( + q_nope=torch.index_select(q_nope, 0, q_tail_idx), + q_pe=torch.index_select(q_pe, 0, q_tail_idx), + k_nope=k_nope, + k_pe=k_pe, + value=value, + kv_mask_idx=kv_with_q_tail_mask_idx, + kv_nomask_idx=kv_with_q_tail_nomask_idx, + attn_mask_seqlens=attn_mask_seqlens, + attn_nomask_seqlens=tail_attn_nomask_seqlens, + mask=mask) + + q_full_idx = attn_metadata.prefill.q_full_idx + output = torch.index_select( + torch.cat([output_head, output_tail], dim=0), 0, q_full_idx) + + output = output.reshape([num_tokens, self.num_heads * self.v_head_dim]) + + return output + + def _attention_with_mask_and_nomask( + self, q_nope: torch.Tensor, q_pe: torch.Tensor, + k_nope: torch.Tensor, k_pe: torch.Tensor, value: torch.Tensor, + kv_mask_idx: torch.Tensor, kv_nomask_idx: torch.Tensor, + attn_mask_seqlens: torch.Tensor, attn_nomask_seqlens: torch.Tensor, + mask: torch.Tensor): + attn_output = torch.empty( + q_nope.shape[0], + self.num_heads, + self.v_head_dim, + dtype=k_pe.dtype, + device=k_pe.device) + attn_lse = torch.empty(self.num_heads, + q_pe.shape[0], + dtype=torch.float32, + device=k_pe.device) + # mask + k_nope_mask = torch.index_select(k_nope, 0, kv_mask_idx) + value_mask = torch.index_select(value, 0, kv_mask_idx) + k_pe_mask = torch.index_select(k_pe, 0, kv_mask_idx) + torch_npu.atb.npu_ring_mla(q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope_mask, + k_rope=k_pe_mask, + value=value_mask, + mask=mask, + seqlen=attn_mask_seqlens, + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=None, + prev_lse=None, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="mask_type_triu", + input_layout="type_bsnd", + calc_type="calc_type_first_ring", + output=attn_output, + softmax_lse=attn_lse) + + # nomask + if kv_nomask_idx.shape[0] == 0: + return attn_output + + k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx) + value_nomask = torch.index_select(value, 0, kv_nomask_idx) + k_pe_nomask = torch.index_select(k_pe, 0, kv_nomask_idx) + torch_npu.atb.npu_ring_mla(q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope_nomask, + k_rope=k_pe_nomask, + value=value_nomask, + mask=mask, + seqlen=attn_nomask_seqlens, + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=attn_output, + prev_lse=attn_lse, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="no_mask", + input_layout="type_bsnd", + calc_type="calc_type_default", + output=attn_output, + softmax_lse=attn_lse) + return attn_output + + def _forward_decode_sp( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + k_nope: torch.Tensor, + k_pe: torch.Tensor, + block_size: int, + attn_metadata: AscendMLAMetadata, + ) -> torch.Tensor: + decode_meta = attn_metadata.decode + assert decode_meta is not None + num_tokens = q_nope.size(0) + # shape of knope/k_pe for npu graph mode should be: + # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] + if self.dcp_size > 1: + num_heads = self.num_heads * self.dcp_size + else: + num_heads = self.num_heads + + k_nope = k_nope.view(-1, block_size, self.num_kv_heads, + self.kv_lora_rank) + k_pe = k_pe.view(-1, block_size, self.num_kv_heads, + self.qk_rope_head_dim) + q_nope = q_nope.view(num_tokens, num_heads, -1) + q_pe = q_pe.view(num_tokens, num_heads, -1) + + # use cp & sp split computed token nums from scheduler to compute actual seq_len and seq_mask + num_computed_tokens_of_cp_sp = np.array( + decode_meta.num_computed_tokens_of_cp_sp) # [bs, cp_size, 
sp_size] + seq_mask_cp = torch.where( + torch.tensor(num_computed_tokens_of_cp_sp.sum(2)) == 0, 0, + 1).to(torch.uint8).to(q_pe.device) + seq_mask_sp = torch.where( + torch.tensor(num_computed_tokens_of_cp_sp[:, + self.cp_rank, :]) == 0, + 0, 1).to(torch.uint8).to(q_pe.device) + seq_len = num_computed_tokens_of_cp_sp[:, self.cp_rank, self.dcp_rank] + seq_len = torch.tensor(seq_len, dtype=torch.int32) + + if torch.sum(seq_len).item() == 0: + # Case that no kv_cache has been stored on this rank, no need to do following computation. + attn_output = torch.zeros( + [num_tokens, num_heads, self.kv_lora_rank], + dtype=q_nope.dtype, + device=q_nope.device) + softmax_lse = torch.full((num_tokens, num_heads, 1), + float('-inf'), + dtype=q_nope.dtype, + device=q_nope.device) + else: + attn_output, softmax_lse = torch_npu.atb.npu_multi_head_latent_attention( + q_nope, + q_pe, + k_nope, + k_pe, + decode_meta.block_table, + seq_len, + num_heads, + self.scale, + self.num_kv_heads, + return_lse=True, + calc_type="calc_type_ring") + + if self.dcp_size > 1: + # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1] + attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1) + # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs] + attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous() + attn_out_lse_all2all = torch.empty_like(attn_out_lse) + dist.all_to_all_single(attn_out_lse_all2all, + attn_out_lse, + group=self.dcp_group) + # permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1] + attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1]) + attn_out_lse_split_on_seq = list( + torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1)) + # Update out&lse + attn_out_g = None + attn_lse_g = None + for i, attn_out_lse_l in enumerate(attn_out_lse_split_on_seq): + attn_out_l, attn_lse_l = torch.split(attn_out_lse_l, + [self.kv_lora_rank, 1], + dim=-1) + attn_out_g, attn_lse_g = self._update_out_and_lse( + attn_out_g, attn_lse_g, attn_out_l, attn_lse_l, + seq_mask_sp[:, i]) + attn_output = attn_out_g + softmax_lse = attn_lse_g + + if self.cp_size > 1: + # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1] + attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1) + # AllGather out&lse within CP group + attn_out_lse_list = [ + torch.empty_like(attn_out_lse) for _ in range(self.cp_size) + ] + dist.all_gather(attn_out_lse_list, + attn_out_lse, + group=self.cp_group) + # Update out&lse + attn_out_g = None + attn_lse_g = None + for i, attn_out_lse_l in enumerate(attn_out_lse_list): + attn_out_l, attn_lse_l = torch.split(attn_out_lse_l, + [self.kv_lora_rank, 1], + dim=-1) + attn_out_g, attn_lse_g = self._update_out_and_lse( + attn_out_g, attn_lse_g, attn_out_l, attn_lse_l, + seq_mask_cp[:, i]) + attn_output = attn_out_g + attn_output = attn_output.reshape( + [num_tokens, self.num_heads * self.kv_lora_rank]) + # out = self.o_proj(attn_output, is_prefill=False)[0] + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is None: + return self._v_up_proj(attn_output) + else: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + return self._v_up_proj(attn_output) + +# TODO use update op to replace this + def _update_out_and_lse( + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, + mask: torch.Tensor = None, + ): + if out is 
None: + out = block_out.to(torch.float32) + lse = block_lse + else: + if mask is None: + mask = torch.ones([block_out.size(0)], + dtype=torch.uint8, + device=block_out.device) + out_mask = mask[:, None, None].expand_as(block_out) + lse_mask = mask[:, None, None].expand_as(block_lse) + block_out = block_out.to(torch.float32) + out_without_update = out.clone() + lse_without_update = lse.clone() + + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + # mask + out = torch.where(out_mask, out, out_without_update) + lse = torch.where(lse_mask, lse, lse_without_update) + return out, lse \ No newline at end of file diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 519cde0c5a..3cf1e9a65e 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -1,6 +1,5 @@ from dataclasses import dataclass -from typing import Any, List - +from typing import Any, List, Optional import torch from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group, @@ -8,6 +7,35 @@ from vllm.forward_context import ForwardContext, get_forward_context +@dataclass +class AscendCommonLongSequenceMetadata: + cp_kv_recover_idx: torch.Tensor = None + + num_actual_tokens_cp_full: Optional[int] = None + + q_head_idx_tensor: torch.Tensor = None + + q_tail_idx_tensor: torch.Tensor = None + + kv_with_q_head_nomask_idx_tensor: torch.Tensor = None + + kv_with_q_head_mask_idx_tensor: torch.Tensor = None + + kv_with_q_tail_nomask_idx_tensor: torch.Tensor = None + + kv_with_q_tail_mask_idx_tensor: torch.Tensor = None + + attn_mask_seqlens: torch.Tensor = None + + head_attn_nomask_seqlens: torch.Tensor = None + + tail_attn_nomask_seqlens: torch.Tensor = None + + q_full_idx: torch.Tensor = None + + cp_prefill_mask: torch.Tensor = None + + @dataclass class AscendCommonAttentionMetadata: """ @@ -63,6 +91,8 @@ class AscendCommonAttentionMetadata: graph_pad_size: int = -1 + common_long_seq_metadata: Optional[AscendCommonLongSequenceMetadata] = None + def split_decodes_and_prefills( common_attn_metadata: AscendCommonAttentionMetadata, diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py index dcd5d05562..9e180e00ab 100644 --- a/vllm_ascend/core/schedule_config.py +++ b/vllm_ascend/core/schedule_config.py @@ -19,6 +19,7 @@ from typing import Type, Union from vllm.config import SchedulerConfig +from vllm.distributed import get_dcp_group @dataclass diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index 07c707e3f5..57f4eeda03 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -53,6 +53,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): # every dp rank can generate independently (in verl integration). 
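# For reference, the out/lse update in _update_out_and_lse (mla_v1.py above) is the
# usual log-sum-exp merge of two partial attention results, rewritten with
# sigmoid/logsigmoid for numerical stability:
#   lse_new = log(exp(lse) + exp(block_lse)) = lse - logsigmoid(lse - block_lse)
#   out_new = (exp(lse) * out + exp(block_lse) * block_out) / exp(lse_new)
#           = out - sigmoid(block_lse - lse) * (out - block_out)
# A minimal numerical check of that equivalence; the tensor shapes below are
# illustrative only, not the ones produced by the kernel:
import torch
import torch.nn.functional as F

lse, block_lse = torch.randn(4, 8, 1), torch.randn(4, 8, 1)
out, block_out = torch.randn(4, 8, 16), torch.randn(4, 8, 16)

ref_lse = torch.logaddexp(lse, block_lse)
ref_out = out * (lse - ref_lse).exp() + block_out * (block_lse - ref_lse).exp()

new_lse = lse - F.logsigmoid(lse - block_lse)
new_out = out - torch.sigmoid(block_lse - lse) * (out - block_out)

assert torch.allclose(ref_lse, new_lse, atol=1e-5)
assert torch.allclose(ref_out, new_out, atol=1e-5)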
all_ranks = torch.arange(world_size).reshape( -1, parallel_config.data_parallel_size * + parallel_config.context_parallel_size * parallel_config.tensor_parallel_size) global _MC2 group_ranks = all_ranks.unbind(0) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index f00abcab7c..33c3515adf 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -128,38 +128,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config cache_config = vllm_config.cache_config - scheduler_config = vllm_config.scheduler_config - ascend_scheduler_config = ascend_config.ascend_scheduler_config - if vllm_version_is("0.10.2"): - structured_outputs_config = vllm_config.decoding_config - else: - structured_outputs_config = vllm_config.structured_outputs_config - - if model_config is not None and not model_config.use_mla: - logger.info( - "Non-MLA LLMs forcibly disable the chunked prefill feature," - "as the performance of operators supporting this feature " - "functionality is currently suboptimal.") - if not model_config.is_multimodal_model and \ - structured_outputs_config.backend == "auto" and \ - not getattr(scheduler_config, "scheduler_delay_factor", 0) > 0 and \ - not scheduler_config.send_delta_data and \ - scheduler_config.policy == "fcfs": - ascend_scheduler_config.enabled = True - chunked_prefill_enabled_in_ascend_scheduler = getattr( - ascend_scheduler_config, "enable_chunked_prefill", False) - if chunked_prefill_enabled_in_ascend_scheduler: - logger.warning( - "Chunked prefill feature is enabled in ascend_scheduler," - "but note that the operator supporting this feature " - "would lead to performance degradation.") - # In this situation, max_num_batched_tokens would have been rewritten. - # So we must make sure max_num_batched_tokens is not smaller than max_model_len. 
- if (scheduler_config.max_num_batched_tokens - < scheduler_config.max_model_len - and not chunked_prefill_enabled_in_ascend_scheduler): - scheduler_config.max_num_batched_tokens = scheduler_config.max_model_len - kv_cache_dtype = vllm_config.additional_config.get( "kv_cache_dtype", None) if kv_cache_dtype is not None: diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 820ec63c12..e851cafd57 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -342,7 +342,10 @@ def _init_worker_distributed_environment(self) -> None: self.local_rank, "hccl") ensure_model_parallel_initialized( self.parallel_config.tensor_parallel_size, - self.parallel_config.pipeline_parallel_size) + self.parallel_config.pipeline_parallel_size, + self.parallel_config.context_parallel_size, + self.parallel_config.decode_context_parallel_size + ) init_ascend_model_parallel(self.parallel_config) ensure_kv_transfer_initialized(self.vllm_config) From d1ad588101ce60fbcb8da2911a7b491809f9d4ab Mon Sep 17 00:00:00 2001 From: chenjie Date: Wed, 24 Sep 2025 22:19:50 +0800 Subject: [PATCH 02/12] model runner support cp: input ids, position ids and slot mapping Signed-off-by: chenjie --- vllm_ascend/worker/model_runner_v1.py | 122 ++++++++++++++++++++++++-- 1 file changed, 114 insertions(+), 8 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index f4656dddff..0f9091838d 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -123,6 +123,11 @@ lmhead_tp_enable, vllm_version_is) from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch +if context_parallel_enable: + from vllm.distributed import get_cp_group + from vllm.distributed.parallel_state import ( + get_context_model_parallel_rank, get_context_model_parallel_world_size) + if TYPE_CHECKING: import xgrammar as xgr # type: ignore[import-untyped] from vllm.v1.core.sched.output import SchedulerOutput @@ -256,6 +261,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): decode_max_num_seqs) self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_rank = vllm_config.parallel_config.data_parallel_rank + self.cp_size = get_context_model_parallel_world_size( + ) if context_parallel_enable else 1 + self.cp_rank = get_context_model_parallel_rank( + ) if self.cp_size > 1 else 0 self.device = device if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP: self.prefetch_stream = torch.npu.Stream(device=device) @@ -315,7 +324,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): use_mla=self.model_config.use_mla, ) - if torch.version.cann.startswith("8.3"): + if self.cp_size > 1: + # todo: 避免长序列卡死 + self.attn_mask_builder = None + elif torch.version.cann.startswith("8.3"): self.attn_mask_builder = AttentionMaskBuilder( self.scheduler_config.max_num_batched_tokens, self.dtype, self.device) @@ -1115,6 +1127,76 @@ def _iter_mm_features(req_state: CachedRequestState): mm_embeds.append(mm_embeds_item) return mm_embeds + def _num_scheduled_tokens_prefill_cp(self, num_tokens, + num_computed_tokens, + cp_kv_recover_idx): + num_scheduled_tokens = num_tokens - num_computed_tokens + num_cp_padded_scheduled_tokens = cdiv( + num_scheduled_tokens, 2 * self.cp_size) * (2 * self.cp_size + ) # pad to 2*cp_size + cp_pad = num_cp_padded_scheduled_tokens - num_scheduled_tokens # 给sample用 + full_indices = list( + range(self.max_num_tokens * self.cp_size * self.sp_size + + self.cp_size * 
self.sp_size * self.max_num_reqs)) + chunk_size = num_cp_padded_scheduled_tokens // (2 * self.cp_size) + + # split position_ids (and use split position_ids to split input_ids afterwards) + req_position_cp = [] + req_position_cp.extend( + full_indices[self.cp_rank * chunk_size:(self.cp_rank + 1) * + chunk_size]) + req_position_cp.extend( + full_indices[num_cp_padded_scheduled_tokens - (self.cp_rank + 1) * + chunk_size:num_cp_padded_scheduled_tokens - + self.cp_rank * chunk_size]) + + # used to recover kv order in cp prefill (after all-gather kv and before storing kv_cache) + num_added_recover_tokens = len(cp_kv_recover_idx[0]) * self.cp_size + for rank in range(self.cp_size): + cp_kv_recover_idx[rank].extend( + full_indices[rank * chunk_size + + num_added_recover_tokens:(rank + 1) * chunk_size + + num_added_recover_tokens]) + cp_kv_recover_idx[rank].extend(full_indices[ + num_cp_padded_scheduled_tokens - (rank + 1) * chunk_size + + num_added_recover_tokens:num_cp_padded_scheduled_tokens - + rank * chunk_size + num_added_recover_tokens]) + + return req_position_cp, num_cp_padded_scheduled_tokens, cp_pad + + def _update_tokens_for_cp(self, tokens, scheduler_output: "SchedulerOutput"): + if not self.cp_size > 1: + return tokens + num_reqs = self.input_batch.num_reqs + self.num_cp_pads = np.empty(num_reqs, dtype=np.int32) + self.cp_kv_recover_idx: List[List[int]] = [[] + for _ in range(self.cp_size) + ] + self.position_cp = np.zeros(self.max_num_tokens, dtype=np.int32) + start_index = 0 + + for i, req_id in enumerate(self.input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + is_prefill = num_tokens > 1 # todo: compare num prompt tokens and num tokens + if is_prefill: + # when cp > 1 & prefill, need to pad & split sequence here + req_position_cp, num_cp_padded_scheduled_tokens, self.num_cp_pads[ + i] = self._num_scheduled_tokens_prefill_cp( + num_tokens, + self.input_batch.num_computed_tokens_cpu[i], + self.cp_kv_recover_idx) + num_tokens = len(req_position_cp) + self.position_cp[start_index:start_index + + num_tokens] = req_position_cp + start_index += num_tokens + tokens[i] = num_tokens + else: + self.num_cp_pads[i] = 0 + self.position_cp[start_index:start_index + + num_tokens] = [idx for idx in range(num_tokens)] + start_index += num_tokens + return tokens + def _get_cumsum_and_arange( self, num_tokens: np.ndarray, @@ -1248,7 +1330,12 @@ def _prepare_inputs( # Get the number of scheduled tokens for each request. 
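# _num_scheduled_tokens_prefill_cp above pads each prefill request to a multiple of
# 2 * cp_size and assigns every CP rank one chunk from the front plus the mirrored
# chunk from the back, so causal-attention work stays balanced across ranks.
# A standalone sketch of that position layout (helper name and sizes are
# illustrative only):
def _zigzag_positions(seq_len: int, cp_size: int, cp_rank: int) -> list[int]:
    padded = -(-seq_len // (2 * cp_size)) * (2 * cp_size)  # round up, like cdiv
    chunk = padded // (2 * cp_size)
    head = list(range(cp_rank * chunk, (cp_rank + 1) * chunk))
    tail = list(range(padded - (cp_rank + 1) * chunk, padded - cp_rank * chunk))
    return head + tail

# e.g. seq_len=8, cp_size=2: rank 0 takes [0, 1, 6, 7], rank 1 takes [2, 3, 4, 5]
assert _zigzag_positions(8, 2, 0) == [0, 1, 6, 7]
assert _zigzag_positions(8, 2, 1) == [2, 3, 4, 5]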
req_ids = self.input_batch.req_ids tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + original_num_scheduled_tokens = np.array(tokens, dtype=np.int32) + original_total_num_scheduled_tokens = total_num_scheduled_tokens + tokens = self._update_tokens_for_cp(tokens, scheduler_output) num_scheduled_tokens = np.array(tokens, dtype=np.int32) + # update total_num_scheduled_tokens + total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs]) max_num_scheduled_tokens = max(tokens) num_valid_tokens = np.array([ num_tokens - @@ -1307,9 +1394,24 @@ def _prepare_inputs( num_scheduled_tokens) positions_np = self.positions_np[:total_num_scheduled_tokens] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) + if self.cp_size > 1: + original_req_indices = np.repeat(self.arange_np[:num_reqs], + original_num_scheduled_tokens) + _, original_arange = self._get_cumsum_and_arange( + original_num_scheduled_tokens) + original_positions_np = self.positions_np[:original_total_num_scheduled_tokens] + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + self.position_cp[:total_num_scheduled_tokens], + out=positions_np) + np.add(self.input_batch.num_computed_tokens_cpu[original_req_indices], + original_arange, + out=original_positions_np) + else: + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) + original_req_indices = req_indices + original_positions_np = positions_np # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) @@ -1340,7 +1442,7 @@ def _prepare_inputs( # Prepare some information for building Attention-Metadata # Compute and commit slot mapping self.input_batch.block_table.compute_slot_mapping( - req_indices, positions_np) + original_req_indices, original_positions_np) self.input_batch.block_table.commit_slot_mapping( total_num_scheduled_tokens) @@ -1375,9 +1477,13 @@ def _prepare_inputs( seq_lens_cpu = self.seq_lens_cpu[:num_reqs] attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens) - self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu, - position=positions_cpu, - attn_state=attn_state) + # todo: 避免额外显存占用 + if self.cp_size > 1: + self.attn_mask = None + else: + self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu, + position=positions_cpu, + attn_state=attn_state) self.attn_state = attn_state # type: ignore self.with_prefill = with_prefill From b3016599282b60e42e267c315b5f401c98fec418 Mon Sep 17 00:00:00 2001 From: chenjie Date: Thu, 25 Sep 2025 11:25:49 +0800 Subject: [PATCH 03/12] model runner support cp: metadata, logits indices Signed-off-by: chenjie --- vllm_ascend/worker/model_runner_v1.py | 175 +++++++++++++++++++++++++- 1 file changed, 173 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 0f9091838d..49581e1d71 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -98,7 +98,8 @@ set_ascend_forward_context) from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata +from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, + AscendCommonLongSequenceMetadata) from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, set_graph_params, update_attn_params) @@ -1197,6 +1198,166 @@ def 
_update_tokens_for_cp(self, tokens, scheduler_output: "SchedulerOutput"): start_index += num_tokens return tokens + def _update_logits_indices_for_cp(self, cu_num_tokens, scheduler_output: "SchedulerOutput"): + # todo: find a better way to get is_prefill + is_prefill = list( + scheduler_output.num_scheduled_tokens.values())[0] > 1 + num_reqs = self.input_batch.num_reqs + if self.cp_size > 1 and is_prefill: + # logits_indices = cu_num_tokens - num_cp_pads[:num_reqs] - 1 # if without all-gather and only sample on cp0 + logits_indices = cu_num_tokens * self.cp_size - self.num_cp_pads[: + num_reqs] - 1 + else: + logits_indices = cu_num_tokens - 1 + return logits_indices + + def _generate_cp_metadata(self, total_num_scheduled_tokens, seq_lens, scheduler_output: "SchedulerOutput"): + # todo: find a better way to get is_prefill + is_prefill = list( + scheduler_output.num_scheduled_tokens.values())[0] > 1 + num_actual_tokens_cp_full = total_num_scheduled_tokens * ( + self.cp_size if is_prefill > 0 else 1) + long_seq_metadata = None + if self.cp_size > 1 and is_prefill > 0: + cp_kv_recover_idx = torch.zeros(num_actual_tokens_cp_full, + dtype=torch.int32, + device=self.device) + cp_kv_recover_idx.copy_(torch.tensor( + np.array(self.cp_kv_recover_idx).flatten().tolist()), + non_blocking=True) + self.cp_kv_recover_idx = cp_kv_recover_idx.to( + torch.float32).argsort().to(torch.int32) + + q_head_idx, q_tail_idx = [], [] + kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], [] + kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], [] + chunk_seqlens = [] + kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], [] + q_req_offset = 0 + kv_req_offset = 0 + q_head_chunk_id = self.cp_rank + q_tail_chunk_id = self.cp_size * 2 - 1 - self.cp_rank + for seq_len in seq_lens: + chunk_len = seq_len // 2 + chunk_seqlens.append(chunk_len) + q_head_idx.extend( + list(range(q_req_offset, q_req_offset + chunk_len))) + kv_with_q_head_nomask_idx.extend( + list( + range(kv_req_offset, + kv_req_offset + chunk_len * q_head_chunk_id))) + kv_with_q_head_mask_idx.extend( + list( + range( + kv_req_offset + chunk_len * q_head_chunk_id, + kv_req_offset + chunk_len * + (q_head_chunk_id + 1)))) + kv_with_q_head_nomask_seqlens.append(chunk_len * + q_head_chunk_id) + + q_tail_idx.extend( + list( + range(q_req_offset + chunk_len, + q_req_offset + chunk_len * 2))) + kv_with_q_tail_nomask_idx.extend( + list( + range(kv_req_offset, + kv_req_offset + chunk_len * q_tail_chunk_id))) + kv_with_q_tail_mask_idx.extend( + list( + range( + kv_req_offset + chunk_len * q_tail_chunk_id, + kv_req_offset + chunk_len * + (q_tail_chunk_id + 1)))) + kv_with_q_tail_nomask_seqlens.append(chunk_len * + q_tail_chunk_id) + + q_req_offset += seq_len + kv_req_offset += seq_len * self.cp_size + + # Convert lists to tensors and move to device + def _list_to_tensor(lst, device, dtype=torch.int32): + tensor_npu = torch.zeros(len(lst), dtype=dtype, device=device) + tensor_npu.copy_(torch.tensor(lst, dtype=dtype), + non_blocking=True) + return tensor_npu + + q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device) + q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device) + self.q_head_idx_tensor = q_head_idx_tensor + self.q_tail_idx_tensor = q_tail_idx_tensor + + q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor]) + q_full_idx = q_full_idx.to(torch.float32).argsort().to(torch.int32) + self.q_full_idx = q_full_idx + + self.kv_idx_names = { + 'kv_with_q_head_nomask_idx_tensor': kv_with_q_head_nomask_idx, + 
'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx, + 'kv_with_q_tail_nomask_idx_tensor': kv_with_q_tail_nomask_idx, + 'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx + } + for key, value in self.kv_idx_names.items(): + tensor_npu = _list_to_tensor(value, self.device) + self.kv_idx_names[key] = tensor_npu + + attn_mask_seqlens = torch.tensor([chunk_seqlens, chunk_seqlens], + dtype=torch.int32) + head_attn_nomask_seqlens = torch.tensor( + [chunk_seqlens, kv_with_q_head_nomask_seqlens], + dtype=torch.int32) + tail_attn_nomask_seqlens = torch.tensor( + [chunk_seqlens, kv_with_q_tail_nomask_seqlens], + dtype=torch.int32) + if self.vllm_config.model_config.use_mla: + cp_prefill_mask = torch.triu( + torch.ones(512, 512, device=self.device, dtype=self.dtype), + 1) + else: + max_seq_len = max(seq_lens, default=0) + cp_prefill_mask = torch.triu( + torch.full((seq_lens.shape[0], max_seq_len, max_seq_len), + True, + device=self.device, + dtype=torch.bool), 1) + + self.extra_long_seq_kwargs = { + 'attn_mask_seqlens': attn_mask_seqlens, + 'head_attn_nomask_seqlens': head_attn_nomask_seqlens, + 'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens, + 'cp_prefill_mask': cp_prefill_mask + } + long_seq_metadata = AscendCommonLongSequenceMetadata( + cp_kv_recover_idx=self.cp_kv_recover_idx, + num_actual_tokens_cp_full=num_actual_tokens_cp_full, + num_computed_tokens_of_cp_sp=self.input_batch.block_table.get_split_computed_tokens( + self.input_batch.num_computed_tokens_cpu[:self.input_batch.num_reqs]), + q_head_idx_tensor=self.q_head_idx_tensor, + q_tail_idx_tensor=self.q_tail_idx_tensor, + q_full_idx=self.q_full_idx, + kv_with_q_head_nomask_idx_tensor=self. + kv_idx_names['kv_with_q_head_nomask_idx_tensor'], + kv_with_q_head_mask_idx_tensor=self. + kv_idx_names['kv_with_q_head_mask_idx_tensor'], + kv_with_q_tail_nomask_idx_tensor=self. + kv_idx_names['kv_with_q_tail_nomask_idx_tensor'], + kv_with_q_tail_mask_idx_tensor=self. + kv_idx_names['kv_with_q_tail_mask_idx_tensor'], + attn_mask_seqlens=self. + extra_long_seq_kwargs['attn_mask_seqlens'], + head_attn_nomask_seqlens=self. + extra_long_seq_kwargs['head_attn_nomask_seqlens'], + tail_attn_nomask_seqlens=self. + extra_long_seq_kwargs['tail_attn_nomask_seqlens'], + cp_prefill_mask=self.extra_long_seq_kwargs['cp_prefill_mask']) + else: + long_seq_metadata = AscendCommonLongSequenceMetadata( + num_actual_tokens_cp_full=num_actual_tokens_cp_full, + num_computed_tokens_of_cp_sp=self.input_batch.block_table.split_computed_tokens( + self.input_batch.num_computed_tokens_cpu[:self.input_batch.num_reqs]),) + + return long_seq_metadata + def _get_cumsum_and_arange( self, num_tokens: np.ndarray, @@ -1556,9 +1717,12 @@ def _prepare_inputs( # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. spec_decode_metadata = None - logits_indices = torch.from_numpy(cu_num_tokens - 1).to( + logits_indices = self._update_logits_indices_for_cp(cu_num_tokens, scheduler_output) + logits_indices = torch.from_numpy(logits_indices).to( self.device, non_blocking=True) else: + # cp not supported now + assert self.cp_size == 1 # Get the number of draft tokens for each request. # Iterate over the dictionary rather than all requests since not all # requests have draft tokens. 
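# In the CP prefill metadata assembled above, each rank holds two query chunks:
# q_head (global chunk id cp_rank) and q_tail (global chunk id
# 2 * cp_size - 1 - cp_rank). Against the gathered KV, a query chunk with global
# id c attends to c earlier chunks without a mask plus its own chunk under a
# causal (triu) mask, which is what the *_nomask_idx / *_mask_idx lists encode.
# A small sketch of those chunk counts (the helper is illustrative, not part of
# the patch):
def _cp_kv_chunk_counts(cp_size: int, cp_rank: int) -> dict[str, int]:
    head_chunk_id = cp_rank
    tail_chunk_id = 2 * cp_size - 1 - cp_rank
    return {
        "head_nomask": head_chunk_id,  # fully visible earlier chunks
        "head_mask": 1,                # diagonal chunk, causal mask applied
        "tail_nomask": tail_chunk_id,
        "tail_mask": 1,
    }

# e.g. cp_size=2: rank 0 -> head sees 0 no-mask chunks, tail sees 3;
#                 rank 1 -> head sees 1 no-mask chunk,  tail sees 2.
assert _cp_kv_chunk_counts(2, 0) == {"head_nomask": 0, "head_mask": 1,
                                     "tail_nomask": 3, "tail_mask": 1}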
@@ -1586,6 +1750,8 @@ def _prepare_inputs( self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.copy_to_gpu() + # prepare cp meta data + long_seq_metadata = self._generate_cp_metadata(total_num_scheduled_tokens, seq_lens_cpu, scheduler_output) # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -1621,6 +1787,7 @@ def _prepare_inputs( max_query_len=max_num_scheduled_tokens, graph_pad_size=self.graph_pad_size, decode_token_per_req=self.decode_token_per_req, + common_long_seq_metadata=long_seq_metadata, ) if self.speculative_config and \ @@ -1693,6 +1860,10 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, pad_size = get_forward_context().pad_size if pad_size > 0: hidden_states = hidden_states[:-pad_size, :] + if self.cp_size > 1 and with_prefill: + hidden_states = get_cp_group().all_gather(hidden_states, 0) + hidden_states = torch.index_select( + hidden_states, 0, attn_metadata.prefill.cp_kv_recover_idx) return hidden_states def _build_attn_state(self, num_reqs, num_scheduled_tokens, From 2f3619778e3b667a9c16fd181aea123b36f83d87 Mon Sep 17 00:00:00 2001 From: LookAround Date: Thu, 25 Sep 2025 13:03:24 +0800 Subject: [PATCH 04/12] [mla backend] add num_computed_tokens_of_dcp_sp Signed-off-by: LookAround --- vllm_ascend/attention/mla_v1.py | 5 ++++- vllm_ascend/attention/utils.py | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 6714fb5b45..da4961fadf 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -122,6 +122,8 @@ class AscendMLADecodeMetadata: attn_mask: Optional[torch.Tensor] = None sin: torch.Tensor = None cos: torch.Tensor = None + num_computed_tokens_of_cp_sp: Optional[list[Optional[list[Optional[ + list[int]]]]]] = None @dataclass @@ -464,7 +466,8 @@ def build( attn_mask=common_attn_metadata.spec_attn_mask, actual_seq_lengths_q=actual_seq_lengths_q, sin=sin, - cos=cos) + cos=cos, + num_computed_tokens_of_cp_sp=num_computed_tokens_of_cp_sp) return self.metadata_cls( # type: ignore num_actual_tokens=num_actual_tokens, diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 3cf1e9a65e..86cf016654 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -13,6 +13,9 @@ class AscendCommonLongSequenceMetadata: num_actual_tokens_cp_full: Optional[int] = None + num_computed_tokens_of_cp_sp: Optional[list[Optional[list[Optional[ + list[int]]]]]] = None + q_head_idx_tensor: torch.Tensor = None q_tail_idx_tensor: torch.Tensor = None From f887deb8f0690284a5cb9eb2210314e782f51005 Mon Sep 17 00:00:00 2001 From: LookAround Date: Thu, 25 Sep 2025 20:48:58 +0800 Subject: [PATCH 05/12] [bug] fix config & block_table bug Signed-off-by: LookAround --- vllm_ascend/ascend_config.py | 4 +--- vllm_ascend/worker/block_table.py | 37 ++++++++++++++++++++++++++----- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index e84bede7d3..301e64242d 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -16,7 +16,6 @@ from typing import Optional from vllm.logger import logger -from vllm.distributed import get_dcp_group TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2", "qwen"] @@ -61,8 +60,7 @@ def __init__(self, vllm_config): "chunked_prefill_for_mla", False) 
self.enable_shared_expert_dp = additional_config.get( "enable_shared_expert_dp", False - ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel \ - and not get_dcp_group().world_size > 1 + ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel self.multistream_overlap_shared_expert = additional_config.get( "multistream_overlap_shared_expert", False) self.enable_prefetch = additional_config.get("enable_prefetch", False) diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py index 307eb831aa..a05227e2bd 100644 --- a/vllm_ascend/worker/block_table.py +++ b/vllm_ascend/worker/block_table.py @@ -2,7 +2,7 @@ import numpy as np import torch -from vllm.distributed import get_dcp_group +from vllm.distributed import get_dcp_group, get_cp_group from vllm.utils import cdiv @@ -80,12 +80,16 @@ def __init__(self, dtype=torch.int64, device=self.device) try: + self.cp_world_size = get_cp_group().world_size + self.cp_rank = get_cp_group().rank_in_group self.dcp_world_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group except AssertionError: # DCP might not be initialized in testing self.dcp_world_size = 1 self.dcp_rank = 0 + self.cp_world_size = 1 + self.cp_rank = 0 self.kernel_sizes = kernel_sizes def append_row( @@ -132,14 +136,14 @@ def compute_slot_mapping(self, req_indices: np.ndarray, # here because M (max_model_len) is not necessarily divisible by # block_size. - if self.dcp_world_size > 1: + if self.dcp_world_size * self.cp_world_size > 1: # Note(hc): The DCP implement store kvcache with an interleave # style, the kvcache for the token whose token_idx is i is # always stored on the GPU whose dcp_rank equals i % cp_world_size: # Use a "virtual block" which equals to world_size * block_size # for block_table_indices calculation. - virtual_block_size = self.block_size * self.dcp_world_size + virtual_block_size = self.block_size * self.dcp_world_size * self.cp_world_size # IMPORTANT: In hybrid mode, positions are in logical block space, # but we need to map them to the correct logical block table indices @@ -157,9 +161,11 @@ def compute_slot_mapping(self, req_indices: np.ndarray, # Use virtual_block_size for mask calculation, which marks local # tokens. virtual_block_offsets = positions % virtual_block_size - mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank + self.current_rank = self.dcp_world_size * self.cp_rank + self.dcp_rank + mask = (virtual_block_offsets % + (self.dcp_world_size * self.cp_world_size) == self.current_rank) # Calculate local block_offsets - block_offsets = virtual_block_offsets // self.dcp_world_size + block_offsets = virtual_block_offsets // (self.dcp_world_size * self.cp_world_size) # Calculate slot_mapping slot_mapping = block_numbers * self.block_size + block_offsets # Write final slots, use -1 for not-local @@ -303,6 +309,27 @@ def commit_slot_mapping(self, num_tokens: int) -> None: for block_table in self.block_tables: block_table.commit_slot_mapping(num_tokens) + def get_split_computed_tokens(self, num_computed_tokens: np.ndarray) -> list[list[list[int]]]: + "Splits computed token counts across dcp and sp dimensions for distributed allocation." 
+ num_requests = len(num_computed_tokens) + num_computed_tokens_of_dcp_sp = [[ + [0] * self.dcp_world_size for _ in range(self.cp_world_size) + ] for _ in range(num_requests)] + total_ranks = self.cp_world_size * self.dcp_world_size + for req_idx in range(num_requests): + total_tokens = num_computed_tokens[req_idx] + if total_tokens <= 0: + continue + base = int(total_tokens) // total_ranks + remainder = int(total_tokens) % total_ranks + for rank_idx in range(total_ranks): + cp_idx = rank_idx // self.dcp_world_size + sp_idx = rank_idx % self.dcp_world_size + num_computed_tokens_of_dcp_sp[req_idx][cp_idx][sp_idx] = base + if rank_idx < remainder: + num_computed_tokens_of_dcp_sp[req_idx][cp_idx][sp_idx] += 1 + return num_computed_tokens_of_dcp_sp + def clear(self) -> None: for block_table in self.block_tables: block_table.clear() From 1bc86bc0dd508b4fb4a410c42489d91ebdc9618d Mon Sep 17 00:00:00 2001 From: LookAround Date: Thu, 25 Sep 2025 21:55:37 +0800 Subject: [PATCH 06/12] [optim] support not enable cp and add env Signed-off-by: LookAround --- vllm_ascend/attention/mla_v1.py | 14 +- vllm_ascend/distributed/parallel_state.py | 14 +- vllm_ascend/envs.py | 3 + vllm_ascend/utils.py | 4 + vllm_ascend/worker/block_table.py | 11 +- vllm_ascend/worker/model_runner_v1.py | 283 +++++++++++----------- vllm_ascend/worker/worker_v1.py | 22 +- 7 files changed, 191 insertions(+), 160 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index da4961fadf..bc5656d192 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -17,10 +17,7 @@ get_tp_group, get_decode_context_model_parallel_rank, get_decode_context_model_parallel_world_size, - get_dcp_group, - get_context_model_parallel_rank, - get_context_model_parallel_world_size, - get_cp_group) + get_dcp_group) from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.utils import cdiv, round_down @@ -35,9 +32,13 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig from vllm_ascend.multistream.context import get_multistream_comm_context from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn -from vllm_ascend.utils import npu_prefetch +from vllm_ascend.utils import npu_prefetch, context_parallel_enable from vllm_ascend.worker.npu_input_batch import InputBatch +if context_parallel_enable(): + from vllm.distributed import (get_context_model_parallel_rank, + get_context_model_parallel_world_size, + get_cp_group) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -559,7 +560,8 @@ def __init__( self.speculative_config = vllm_config.speculative_config - self.cp_size = get_context_model_parallel_world_size() + self.cp_size = get_context_model_parallel_world_size( + ) if context_parallel_enable() else 1 self.cp_rank = get_context_model_parallel_rank( ) if self.cp_size > 1 else 0 self.cp_group = get_cp_group( diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index 57f4eeda03..632a5bb8e9 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -7,6 +7,7 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.utils import context_parallel_enable # Currently, mc2 op need their own group coordinator. 
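# get_split_computed_tokens (added to block_table.py in the previous patch) spreads
# each request's computed-token count round-robin over cp_size * dcp_size ranks,
# with lower flat ranks absorbing the remainder, matching the interleaved slot
# mapping in compute_slot_mapping. A standalone sketch with made-up sizes (the
# helper name is illustrative):
import numpy as np

def _split_computed_tokens(num_computed: np.ndarray, cp: int, dcp: int):
    result = []
    for total in num_computed.tolist():
        base, rem = divmod(int(total), cp * dcp)
        flat = [base + (1 if r < rem else 0) for r in range(cp * dcp)]
        result.append([flat[c * dcp:(c + 1) * dcp] for c in range(cp)])
    return result

# e.g. 10 computed tokens, cp_size=2, dcp_size=2 -> [[3, 3], [2, 2]]
assert _split_computed_tokens(np.array([10]), 2, 2) == [[[3, 3], [2, 2]]]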
_MC2: Optional[GroupCoordinator] = None @@ -51,10 +52,15 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): # The layout of all ranks: ExternalDP * EP # ExternalDP is the data parallel group that is not part of the model, # every dp rank can generate independently (in verl integration). - all_ranks = torch.arange(world_size).reshape( - -1, parallel_config.data_parallel_size * - parallel_config.context_parallel_size * - parallel_config.tensor_parallel_size) + if context_parallel_enable(): + all_ranks = torch.arange(world_size).reshape( + -1, parallel_config.data_parallel_size * + parallel_config.context_parallel_size * + parallel_config.tensor_parallel_size) + else: + all_ranks = torch.arange(world_size).reshape( + -1, parallel_config.data_parallel_size * + parallel_config.tensor_parallel_size) global _MC2 group_ranks = all_ranks.unbind(0) group_ranks = [x.tolist() for x in group_ranks] diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 61df5e19e9..4768e18f9e 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -159,6 +159,9 @@ # caused by the initialization of the Mooncake connector. "PHYSICAL_DEVICES": lambda: os.getenv("PHYSICAL_DEVICES", None), + # Decide whether we should enable CP parallelism. + "VLLM_ASCEND_ENABLE_CP": + lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_CP", '0'))) } # end-env-vars-definition diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 570756fd0c..58de158827 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -598,6 +598,10 @@ def enable_sp() -> bool: or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM) +def context_parallel_enable() -> bool: + return envs_ascend.VLLM_ASCEND_ENABLE_CP + + def is_moe_model(vllm_config: VllmConfig): config = vllm_config.model_config.hf_config return any('experts' in key.lower() for key in config.to_dict()) diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py index a05227e2bd..393c62f9b5 100644 --- a/vllm_ascend/worker/block_table.py +++ b/vllm_ascend/worker/block_table.py @@ -2,9 +2,14 @@ import numpy as np import torch -from vllm.distributed import get_dcp_group, get_cp_group +from vllm.distributed import get_dcp_group from vllm.utils import cdiv +from vllm_ascend.utils import context_parallel_enable + +if context_parallel_enable(): + from vllm.distributed import get_cp_group + class BlockTable: @@ -80,8 +85,8 @@ def __init__(self, dtype=torch.int64, device=self.device) try: - self.cp_world_size = get_cp_group().world_size - self.cp_rank = get_cp_group().rank_in_group + self.cp_world_size = get_cp_group().world_size if context_parallel_enable() else 1 + self.cp_rank = get_cp_group().rank_in_group if self.cp_world_size > 1 else 0 self.dcp_world_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group except AssertionError: diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 49581e1d71..bb69706f22 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -121,10 +121,11 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, ProfileExecuteDuration, get_ascend_soc_version, is_310p, - lmhead_tp_enable, vllm_version_is) + lmhead_tp_enable, vllm_version_is, + context_parallel_enable) from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch -if context_parallel_enable: +if context_parallel_enable(): from vllm.distributed import get_cp_group from vllm.distributed.parallel_state import ( 
get_context_model_parallel_rank, get_context_model_parallel_world_size) @@ -263,7 +264,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.cp_size = get_context_model_parallel_world_size( - ) if context_parallel_enable else 1 + ) if context_parallel_enable() else 1 self.cp_rank = get_context_model_parallel_rank( ) if self.cp_size > 1 else 0 self.device = device @@ -1217,145 +1218,147 @@ def _generate_cp_metadata(self, total_num_scheduled_tokens, seq_lens, scheduler_ scheduler_output.num_scheduled_tokens.values())[0] > 1 num_actual_tokens_cp_full = total_num_scheduled_tokens * ( self.cp_size if is_prefill > 0 else 1) - long_seq_metadata = None - if self.cp_size > 1 and is_prefill > 0: - cp_kv_recover_idx = torch.zeros(num_actual_tokens_cp_full, - dtype=torch.int32, - device=self.device) - cp_kv_recover_idx.copy_(torch.tensor( - np.array(self.cp_kv_recover_idx).flatten().tolist()), - non_blocking=True) - self.cp_kv_recover_idx = cp_kv_recover_idx.to( - torch.float32).argsort().to(torch.int32) - - q_head_idx, q_tail_idx = [], [] - kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], [] - kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], [] - chunk_seqlens = [] - kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], [] - q_req_offset = 0 - kv_req_offset = 0 - q_head_chunk_id = self.cp_rank - q_tail_chunk_id = self.cp_size * 2 - 1 - self.cp_rank - for seq_len in seq_lens: - chunk_len = seq_len // 2 - chunk_seqlens.append(chunk_len) - q_head_idx.extend( - list(range(q_req_offset, q_req_offset + chunk_len))) - kv_with_q_head_nomask_idx.extend( - list( - range(kv_req_offset, - kv_req_offset + chunk_len * q_head_chunk_id))) - kv_with_q_head_mask_idx.extend( - list( - range( - kv_req_offset + chunk_len * q_head_chunk_id, - kv_req_offset + chunk_len * - (q_head_chunk_id + 1)))) - kv_with_q_head_nomask_seqlens.append(chunk_len * - q_head_chunk_id) - - q_tail_idx.extend( - list( - range(q_req_offset + chunk_len, - q_req_offset + chunk_len * 2))) - kv_with_q_tail_nomask_idx.extend( - list( - range(kv_req_offset, - kv_req_offset + chunk_len * q_tail_chunk_id))) - kv_with_q_tail_mask_idx.extend( - list( - range( - kv_req_offset + chunk_len * q_tail_chunk_id, - kv_req_offset + chunk_len * - (q_tail_chunk_id + 1)))) - kv_with_q_tail_nomask_seqlens.append(chunk_len * - q_tail_chunk_id) - - q_req_offset += seq_len - kv_req_offset += seq_len * self.cp_size - - # Convert lists to tensors and move to device - def _list_to_tensor(lst, device, dtype=torch.int32): - tensor_npu = torch.zeros(len(lst), dtype=dtype, device=device) - tensor_npu.copy_(torch.tensor(lst, dtype=dtype), - non_blocking=True) - return tensor_npu - - q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device) - q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device) - self.q_head_idx_tensor = q_head_idx_tensor - self.q_tail_idx_tensor = q_tail_idx_tensor - - q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor]) - q_full_idx = q_full_idx.to(torch.float32).argsort().to(torch.int32) - self.q_full_idx = q_full_idx - - self.kv_idx_names = { - 'kv_with_q_head_nomask_idx_tensor': kv_with_q_head_nomask_idx, - 'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx, - 'kv_with_q_tail_nomask_idx_tensor': kv_with_q_tail_nomask_idx, - 'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx - } - for key, value in self.kv_idx_names.items(): - tensor_npu = 
_list_to_tensor(value, self.device) - self.kv_idx_names[key] = tensor_npu - - attn_mask_seqlens = torch.tensor([chunk_seqlens, chunk_seqlens], - dtype=torch.int32) - head_attn_nomask_seqlens = torch.tensor( - [chunk_seqlens, kv_with_q_head_nomask_seqlens], - dtype=torch.int32) - tail_attn_nomask_seqlens = torch.tensor( - [chunk_seqlens, kv_with_q_tail_nomask_seqlens], - dtype=torch.int32) - if self.vllm_config.model_config.use_mla: - cp_prefill_mask = torch.triu( - torch.ones(512, 512, device=self.device, dtype=self.dtype), - 1) + if self.cp_size > 1: + if is_prefill > 0: + cp_kv_recover_idx = torch.zeros(num_actual_tokens_cp_full, + dtype=torch.int32, + device=self.device) + cp_kv_recover_idx.copy_(torch.tensor( + np.array(self.cp_kv_recover_idx).flatten().tolist()), + non_blocking=True) + self.cp_kv_recover_idx = cp_kv_recover_idx.to( + torch.float32).argsort().to(torch.int32) + + q_head_idx, q_tail_idx = [], [] + kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], [] + kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], [] + chunk_seqlens = [] + kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], [] + q_req_offset = 0 + kv_req_offset = 0 + q_head_chunk_id = self.cp_rank + q_tail_chunk_id = self.cp_size * 2 - 1 - self.cp_rank + for seq_len in seq_lens: + chunk_len = seq_len // 2 + chunk_seqlens.append(chunk_len) + q_head_idx.extend( + list(range(q_req_offset, q_req_offset + chunk_len))) + kv_with_q_head_nomask_idx.extend( + list( + range(kv_req_offset, + kv_req_offset + chunk_len * q_head_chunk_id))) + kv_with_q_head_mask_idx.extend( + list( + range( + kv_req_offset + chunk_len * q_head_chunk_id, + kv_req_offset + chunk_len * + (q_head_chunk_id + 1)))) + kv_with_q_head_nomask_seqlens.append(chunk_len * + q_head_chunk_id) + + q_tail_idx.extend( + list( + range(q_req_offset + chunk_len, + q_req_offset + chunk_len * 2))) + kv_with_q_tail_nomask_idx.extend( + list( + range(kv_req_offset, + kv_req_offset + chunk_len * q_tail_chunk_id))) + kv_with_q_tail_mask_idx.extend( + list( + range( + kv_req_offset + chunk_len * q_tail_chunk_id, + kv_req_offset + chunk_len * + (q_tail_chunk_id + 1)))) + kv_with_q_tail_nomask_seqlens.append(chunk_len * + q_tail_chunk_id) + + q_req_offset += seq_len + kv_req_offset += seq_len * self.cp_size + + # Convert lists to tensors and move to device + def _list_to_tensor(lst, device, dtype=torch.int32): + tensor_npu = torch.zeros(len(lst), dtype=dtype, device=device) + tensor_npu.copy_(torch.tensor(lst, dtype=dtype), + non_blocking=True) + return tensor_npu + + q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device) + q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device) + self.q_head_idx_tensor = q_head_idx_tensor + self.q_tail_idx_tensor = q_tail_idx_tensor + + q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor]) + q_full_idx = q_full_idx.to(torch.float32).argsort().to(torch.int32) + self.q_full_idx = q_full_idx + + self.kv_idx_names = { + 'kv_with_q_head_nomask_idx_tensor': kv_with_q_head_nomask_idx, + 'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx, + 'kv_with_q_tail_nomask_idx_tensor': kv_with_q_tail_nomask_idx, + 'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx + } + for key, value in self.kv_idx_names.items(): + tensor_npu = _list_to_tensor(value, self.device) + self.kv_idx_names[key] = tensor_npu + + attn_mask_seqlens = torch.tensor([chunk_seqlens, chunk_seqlens], + dtype=torch.int32) + head_attn_nomask_seqlens = torch.tensor( + [chunk_seqlens, kv_with_q_head_nomask_seqlens], + 
dtype=torch.int32) + tail_attn_nomask_seqlens = torch.tensor( + [chunk_seqlens, kv_with_q_tail_nomask_seqlens], + dtype=torch.int32) + if self.vllm_config.model_config.use_mla: + cp_prefill_mask = torch.triu( + torch.ones(512, 512, device=self.device, dtype=self.dtype), + 1) + else: + max_seq_len = max(seq_lens, default=0) + cp_prefill_mask = torch.triu( + torch.full((seq_lens.shape[0], max_seq_len, max_seq_len), + True, + device=self.device, + dtype=torch.bool), 1) + + self.extra_long_seq_kwargs = { + 'attn_mask_seqlens': attn_mask_seqlens, + 'head_attn_nomask_seqlens': head_attn_nomask_seqlens, + 'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens, + 'cp_prefill_mask': cp_prefill_mask + } + long_seq_metadata = AscendCommonLongSequenceMetadata( + cp_kv_recover_idx=self.cp_kv_recover_idx, + num_actual_tokens_cp_full=num_actual_tokens_cp_full, + num_computed_tokens_of_cp_sp=self.input_batch.block_table.get_split_computed_tokens( + self.input_batch.num_computed_tokens_cpu[:self.input_batch.num_reqs]), + q_head_idx_tensor=self.q_head_idx_tensor, + q_tail_idx_tensor=self.q_tail_idx_tensor, + q_full_idx=self.q_full_idx, + kv_with_q_head_nomask_idx_tensor=self. + kv_idx_names['kv_with_q_head_nomask_idx_tensor'], + kv_with_q_head_mask_idx_tensor=self. + kv_idx_names['kv_with_q_head_mask_idx_tensor'], + kv_with_q_tail_nomask_idx_tensor=self. + kv_idx_names['kv_with_q_tail_nomask_idx_tensor'], + kv_with_q_tail_mask_idx_tensor=self. + kv_idx_names['kv_with_q_tail_mask_idx_tensor'], + attn_mask_seqlens=self. + extra_long_seq_kwargs['attn_mask_seqlens'], + head_attn_nomask_seqlens=self. + extra_long_seq_kwargs['head_attn_nomask_seqlens'], + tail_attn_nomask_seqlens=self. + extra_long_seq_kwargs['tail_attn_nomask_seqlens'], + cp_prefill_mask=self.extra_long_seq_kwargs['cp_prefill_mask']) else: - max_seq_len = max(seq_lens, default=0) - cp_prefill_mask = torch.triu( - torch.full((seq_lens.shape[0], max_seq_len, max_seq_len), - True, - device=self.device, - dtype=torch.bool), 1) - - self.extra_long_seq_kwargs = { - 'attn_mask_seqlens': attn_mask_seqlens, - 'head_attn_nomask_seqlens': head_attn_nomask_seqlens, - 'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens, - 'cp_prefill_mask': cp_prefill_mask - } - long_seq_metadata = AscendCommonLongSequenceMetadata( - cp_kv_recover_idx=self.cp_kv_recover_idx, - num_actual_tokens_cp_full=num_actual_tokens_cp_full, - num_computed_tokens_of_cp_sp=self.input_batch.block_table.get_split_computed_tokens( - self.input_batch.num_computed_tokens_cpu[:self.input_batch.num_reqs]), - q_head_idx_tensor=self.q_head_idx_tensor, - q_tail_idx_tensor=self.q_tail_idx_tensor, - q_full_idx=self.q_full_idx, - kv_with_q_head_nomask_idx_tensor=self. - kv_idx_names['kv_with_q_head_nomask_idx_tensor'], - kv_with_q_head_mask_idx_tensor=self. - kv_idx_names['kv_with_q_head_mask_idx_tensor'], - kv_with_q_tail_nomask_idx_tensor=self. - kv_idx_names['kv_with_q_tail_nomask_idx_tensor'], - kv_with_q_tail_mask_idx_tensor=self. - kv_idx_names['kv_with_q_tail_mask_idx_tensor'], - attn_mask_seqlens=self. - extra_long_seq_kwargs['attn_mask_seqlens'], - head_attn_nomask_seqlens=self. - extra_long_seq_kwargs['head_attn_nomask_seqlens'], - tail_attn_nomask_seqlens=self. 
- extra_long_seq_kwargs['tail_attn_nomask_seqlens'], - cp_prefill_mask=self.extra_long_seq_kwargs['cp_prefill_mask']) + long_seq_metadata = AscendCommonLongSequenceMetadata( + num_actual_tokens_cp_full=num_actual_tokens_cp_full, + num_computed_tokens_of_cp_sp=self.input_batch.block_table.split_computed_tokens( + self.input_batch.num_computed_tokens_cpu[:self.input_batch.num_reqs]),) else: - long_seq_metadata = AscendCommonLongSequenceMetadata( - num_actual_tokens_cp_full=num_actual_tokens_cp_full, - num_computed_tokens_of_cp_sp=self.input_batch.block_table.split_computed_tokens( - self.input_batch.num_computed_tokens_cpu[:self.input_batch.num_reqs]),) - + long_seq_metadata = None + return long_seq_metadata def _get_cumsum_and_arange( diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index e851cafd57..4eeaaa8dc9 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -45,7 +45,7 @@ from vllm_ascend.device_allocator.camem import CaMemAllocator from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.platform import NPUPlatform -from vllm_ascend.utils import (init_ascend_soc_version, +from vllm_ascend.utils import (init_ascend_soc_version, context_parallel_enable, register_ascend_customop, sleep_mode_enabled, try_register_lib) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner @@ -340,12 +340,20 @@ def _init_worker_distributed_environment(self) -> None: init_distributed_environment(self.parallel_config.world_size, self.rank, self.distributed_init_method, self.local_rank, "hccl") - ensure_model_parallel_initialized( - self.parallel_config.tensor_parallel_size, - self.parallel_config.pipeline_parallel_size, - self.parallel_config.context_parallel_size, - self.parallel_config.decode_context_parallel_size - ) + print(f"context_parallel_enable:{context_parallel_enable}") + if context_parallel_enable(): + ensure_model_parallel_initialized( + self.parallel_config.tensor_parallel_size, + self.parallel_config.pipeline_parallel_size, + self.parallel_config.context_parallel_size, + self.parallel_config.decode_context_parallel_size + ) + else: + ensure_model_parallel_initialized( + self.parallel_config.tensor_parallel_size, + self.parallel_config.pipeline_parallel_size, + self.parallel_config.decode_context_parallel_size + ) init_ascend_model_parallel(self.parallel_config) ensure_kv_transfer_initialized(self.vllm_config) From b69f45af17d4223cfe407992d8360177d3f5678f Mon Sep 17 00:00:00 2001 From: LookAround Date: Fri, 26 Sep 2025 14:28:00 +0800 Subject: [PATCH 07/12] [bug] fix prefill bug Signed-off-by: LookAround --- vllm_ascend/attention/mla_v1.py | 10 +++++----- vllm_ascend/worker/block_table.py | 2 ++ vllm_ascend/worker/model_runner_v1.py | 21 +++++++++++---------- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index bc5656d192..8c1a3a05da 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -342,13 +342,13 @@ def build( device = self.device block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) - if num_actual_tokens_cp_full is not None: - num_actual_tokens = num_actual_tokens_cp_full - - slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] input_positions = common_attn_metadata.positions[: num_actual_tokens].long( - ) + ) + if num_actual_tokens_cp_full is None: + num_actual_tokens_cp_full = num_actual_tokens + + slot_mapping = 
common_attn_metadata.slot_mapping[:num_actual_tokens_cp_full] if self.cos_cache is None: self.cos_cache = model.model.layers[ diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py index 393c62f9b5..9de98df882 100644 --- a/vllm_ascend/worker/block_table.py +++ b/vllm_ascend/worker/block_table.py @@ -316,6 +316,8 @@ def commit_slot_mapping(self, num_tokens: int) -> None: def get_split_computed_tokens(self, num_computed_tokens: np.ndarray) -> list[list[list[int]]]: "Splits computed token counts across dcp and sp dimensions for distributed allocation." + self.cp_world_size = get_cp_group().world_size if context_parallel_enable() else 1 + self.dcp_world_size = get_dcp_group().world_size num_requests = len(num_computed_tokens) num_computed_tokens_of_dcp_sp = [[ [0] * self.dcp_world_size for _ in range(self.cp_world_size) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index bb69706f22..9e74a42cca 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -49,7 +49,7 @@ has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import (get_dp_group, get_pp_group, - get_tp_group, + get_tp_group, get_dcp_group, is_global_first_rank) from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import logger @@ -267,6 +267,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): ) if context_parallel_enable() else 1 self.cp_rank = get_context_model_parallel_rank( ) if self.cp_size > 1 else 0 + self.dcp_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group self.device = device if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP: self.prefetch_stream = torch.npu.Stream(device=device) @@ -1138,8 +1140,8 @@ def _num_scheduled_tokens_prefill_cp(self, num_tokens, ) # pad to 2*cp_size cp_pad = num_cp_padded_scheduled_tokens - num_scheduled_tokens # 给sample用 full_indices = list( - range(self.max_num_tokens * self.cp_size * self.sp_size + - self.cp_size * self.sp_size * self.max_num_reqs)) + range(self.max_num_tokens * self.cp_size * self.dcp_size + + self.cp_size * self.dcp_size * self.max_num_reqs)) chunk_size = num_cp_padded_scheduled_tokens // (2 * self.cp_size) # split position_ids (and use split position_ids to split input_ids afterwards) @@ -1563,7 +1565,7 @@ def _prepare_inputs( original_num_scheduled_tokens) _, original_arange = self._get_cumsum_and_arange( original_num_scheduled_tokens) - original_positions_np = self.positions_np[:original_total_num_scheduled_tokens] + original_positions_np = self.positions_np[:original_total_num_scheduled_tokens].copy() np.add(self.input_batch.num_computed_tokens_cpu[req_indices], self.position_cp[:total_num_scheduled_tokens], out=positions_np) @@ -1858,15 +1860,14 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, update_attn_params(self.update_stream, forward_context, positions.shape[0]) - if get_forward_context().sp_enabled: - hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) - pad_size = get_forward_context().pad_size - if pad_size > 0: - hidden_states = hidden_states[:-pad_size, :] if self.cp_size > 1 and with_prefill: + if isinstance(attn_metadata, dict): + cp_kv_recover_idx = list(attn_metadata.values())[0].prefill.cp_kv_recover_idx + else: + cp_kv_recover_idx = attn_metadata.prefill.cp_kv_recover_idx hidden_states = get_cp_group().all_gather(hidden_states, 0) 
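# --- illustrative aside (assumed shapes, not part of this hunk) ---------------
# cp_kv_recover_idx restores the original token order after the CP all_gather:
# with cp_size=2 and one 8-token prefill, rank 0 holds head/tail chunks at
# positions [0, 1, 6, 7] and rank 1 holds [2, 3, 4, 5]. all_gather concatenates
# rank by rank, so the gathered order is [0, 1, 6, 7, 2, 3, 4, 5]; argsort of
# those positions gives the recover index consumed by the index_select below.
import torch

gathered_positions = torch.tensor([0, 1, 6, 7, 2, 3, 4, 5])
recover_idx = gathered_positions.argsort()
gathered = torch.arange(8.0).unsqueeze(-1)[gathered_positions]  # stand-in hidden states, in gathered order
restored = torch.index_select(gathered, 0, recover_idx)
assert torch.equal(restored.squeeze(-1), torch.arange(8.0))
# ------------------------------------------------------------------------------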
hidden_states = torch.index_select( - hidden_states, 0, attn_metadata.prefill.cp_kv_recover_idx) + hidden_states, 0, cp_kv_recover_idx) return hidden_states def _build_attn_state(self, num_reqs, num_scheduled_tokens, From 8b333b92fdcd81a04bcefd0cfa5191b4a78069ec Mon Sep 17 00:00:00 2001 From: LookAround Date: Fri, 26 Sep 2025 23:36:02 +0800 Subject: [PATCH 08/12] [bug] fix decode bug (single batch) Signed-off-by: LookAround --- vllm_ascend/attention/mla_v1.py | 1 + vllm_ascend/worker/model_runner_v1.py | 30 +++++++-------------------- 2 files changed, 9 insertions(+), 22 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 8c1a3a05da..af4226b609 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -1408,6 +1408,7 @@ def _forward_decode_sp( # TODO use update op to replace this def _update_out_and_lse( + self, out: torch.Tensor, lse: torch.Tensor, block_out: torch.Tensor, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 9e74a42cca..658ef9ac0e 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1356,8 +1356,8 @@ def _list_to_tensor(lst, device, dtype=torch.int32): else: long_seq_metadata = AscendCommonLongSequenceMetadata( num_actual_tokens_cp_full=num_actual_tokens_cp_full, - num_computed_tokens_of_cp_sp=self.input_batch.block_table.split_computed_tokens( - self.input_batch.num_computed_tokens_cpu[:self.input_batch.num_reqs]),) + num_computed_tokens_of_cp_sp=self.input_batch.block_table.get_split_computed_tokens( + self.input_batch.num_tokens[:self.input_batch.num_reqs]), ) else: long_seq_metadata = None @@ -1757,6 +1757,7 @@ def _prepare_inputs( # prepare cp meta data long_seq_metadata = self._generate_cp_metadata(total_num_scheduled_tokens, seq_lens_cpu, scheduler_output) + original_total_num_scheduled_tokens = sum(original_num_scheduled_tokens[:num_reqs]) # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -1764,9 +1765,9 @@ def _prepare_inputs( blk_table = self.input_batch.block_table[kv_cache_group_id] blk_table_tensor = blk_table.get_device_tensor() slot_mapping = blk_table.slot_mapping_cpu[: - total_num_scheduled_tokens] - self.slot_mapping[:total_num_scheduled_tokens].copy_( - slot_mapping[:total_num_scheduled_tokens], + original_total_num_scheduled_tokens] + self.slot_mapping[:original_total_num_scheduled_tokens].copy_( + slot_mapping[:original_total_num_scheduled_tokens], non_blocking=True, ) @@ -2369,7 +2370,7 @@ def execute_model( max_gen_len = sampled_token_ids.shape[-1] if max_gen_len == 1: # No spec decode tokens. - valid_sampled_token_ids = self._to_list(sampled_token_ids) + valid_sampled_token_ids = sampled_token_ids.tolist() else: # Includes spec decode tokens. valid_sampled_token_ids = self.rejection_sampler.parse_output( @@ -3820,19 +3821,4 @@ def get_supported_pooling_tasks(self): return list(model.pooler.get_supported_tasks()) def _build_drafter_prepare_inputs_torchair_param(self): - return False - - def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: - # This is a short term mitigation for issue mentioned in - # https://github.com/vllm-project/vllm/issues/22754. - # `tolist` would trigger a npu wise stream sync, which - # would block other copy ops from other npu streams. - # A npu event sync would avoid such a situation. 
Since - # this is in the critical path of every single model - # forward loop, this has caused perf issue for a disagg - # setup. - pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]] - pinned.copy_(sampled_token_ids, non_blocking=True) - self.transfer_event.record() - self.transfer_event.synchronize() - return pinned.tolist() \ No newline at end of file + return False \ No newline at end of file From 24708949ddd05245b9ce9af5a5aee82c934e8816 Mon Sep 17 00:00:00 2001 From: LookAround Date: Sat, 27 Sep 2025 01:18:31 +0800 Subject: [PATCH 09/12] [bug] fix dcp bug Signed-off-by: LookAround --- vllm_ascend/worker/model_runner_v1.py | 278 +++++++++++++------------- 1 file changed, 138 insertions(+), 140 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 658ef9ac0e..d3baef8bd1 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1215,152 +1215,150 @@ def _update_logits_indices_for_cp(self, cu_num_tokens, scheduler_output: "Schedu return logits_indices def _generate_cp_metadata(self, total_num_scheduled_tokens, seq_lens, scheduler_output: "SchedulerOutput"): - # todo: find a better way to get is_prefill + # TODO: find a better way to get is_prefill is_prefill = list( scheduler_output.num_scheduled_tokens.values())[0] > 1 num_actual_tokens_cp_full = total_num_scheduled_tokens * ( self.cp_size if is_prefill > 0 else 1) - if self.cp_size > 1: - if is_prefill > 0: - cp_kv_recover_idx = torch.zeros(num_actual_tokens_cp_full, - dtype=torch.int32, - device=self.device) - cp_kv_recover_idx.copy_(torch.tensor( - np.array(self.cp_kv_recover_idx).flatten().tolist()), - non_blocking=True) - self.cp_kv_recover_idx = cp_kv_recover_idx.to( - torch.float32).argsort().to(torch.int32) - - q_head_idx, q_tail_idx = [], [] - kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], [] - kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], [] - chunk_seqlens = [] - kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], [] - q_req_offset = 0 - kv_req_offset = 0 - q_head_chunk_id = self.cp_rank - q_tail_chunk_id = self.cp_size * 2 - 1 - self.cp_rank - for seq_len in seq_lens: - chunk_len = seq_len // 2 - chunk_seqlens.append(chunk_len) - q_head_idx.extend( - list(range(q_req_offset, q_req_offset + chunk_len))) - kv_with_q_head_nomask_idx.extend( - list( - range(kv_req_offset, - kv_req_offset + chunk_len * q_head_chunk_id))) - kv_with_q_head_mask_idx.extend( - list( - range( - kv_req_offset + chunk_len * q_head_chunk_id, - kv_req_offset + chunk_len * - (q_head_chunk_id + 1)))) - kv_with_q_head_nomask_seqlens.append(chunk_len * - q_head_chunk_id) - - q_tail_idx.extend( - list( - range(q_req_offset + chunk_len, - q_req_offset + chunk_len * 2))) - kv_with_q_tail_nomask_idx.extend( - list( - range(kv_req_offset, - kv_req_offset + chunk_len * q_tail_chunk_id))) - kv_with_q_tail_mask_idx.extend( - list( - range( - kv_req_offset + chunk_len * q_tail_chunk_id, - kv_req_offset + chunk_len * - (q_tail_chunk_id + 1)))) - kv_with_q_tail_nomask_seqlens.append(chunk_len * - q_tail_chunk_id) - - q_req_offset += seq_len - kv_req_offset += seq_len * self.cp_size - - # Convert lists to tensors and move to device - def _list_to_tensor(lst, device, dtype=torch.int32): - tensor_npu = torch.zeros(len(lst), dtype=dtype, device=device) - tensor_npu.copy_(torch.tensor(lst, dtype=dtype), - non_blocking=True) - return tensor_npu - - q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device) - 
q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device) - self.q_head_idx_tensor = q_head_idx_tensor - self.q_tail_idx_tensor = q_tail_idx_tensor - - q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor]) - q_full_idx = q_full_idx.to(torch.float32).argsort().to(torch.int32) - self.q_full_idx = q_full_idx - - self.kv_idx_names = { - 'kv_with_q_head_nomask_idx_tensor': kv_with_q_head_nomask_idx, - 'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx, - 'kv_with_q_tail_nomask_idx_tensor': kv_with_q_tail_nomask_idx, - 'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx - } - for key, value in self.kv_idx_names.items(): - tensor_npu = _list_to_tensor(value, self.device) - self.kv_idx_names[key] = tensor_npu - - attn_mask_seqlens = torch.tensor([chunk_seqlens, chunk_seqlens], - dtype=torch.int32) - head_attn_nomask_seqlens = torch.tensor( - [chunk_seqlens, kv_with_q_head_nomask_seqlens], - dtype=torch.int32) - tail_attn_nomask_seqlens = torch.tensor( - [chunk_seqlens, kv_with_q_tail_nomask_seqlens], - dtype=torch.int32) - if self.vllm_config.model_config.use_mla: - cp_prefill_mask = torch.triu( - torch.ones(512, 512, device=self.device, dtype=self.dtype), - 1) - else: - max_seq_len = max(seq_lens, default=0) - cp_prefill_mask = torch.triu( - torch.full((seq_lens.shape[0], max_seq_len, max_seq_len), - True, - device=self.device, - dtype=torch.bool), 1) - - self.extra_long_seq_kwargs = { - 'attn_mask_seqlens': attn_mask_seqlens, - 'head_attn_nomask_seqlens': head_attn_nomask_seqlens, - 'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens, - 'cp_prefill_mask': cp_prefill_mask - } - long_seq_metadata = AscendCommonLongSequenceMetadata( - cp_kv_recover_idx=self.cp_kv_recover_idx, - num_actual_tokens_cp_full=num_actual_tokens_cp_full, - num_computed_tokens_of_cp_sp=self.input_batch.block_table.get_split_computed_tokens( - self.input_batch.num_computed_tokens_cpu[:self.input_batch.num_reqs]), - q_head_idx_tensor=self.q_head_idx_tensor, - q_tail_idx_tensor=self.q_tail_idx_tensor, - q_full_idx=self.q_full_idx, - kv_with_q_head_nomask_idx_tensor=self. - kv_idx_names['kv_with_q_head_nomask_idx_tensor'], - kv_with_q_head_mask_idx_tensor=self. - kv_idx_names['kv_with_q_head_mask_idx_tensor'], - kv_with_q_tail_nomask_idx_tensor=self. - kv_idx_names['kv_with_q_tail_nomask_idx_tensor'], - kv_with_q_tail_mask_idx_tensor=self. - kv_idx_names['kv_with_q_tail_mask_idx_tensor'], - attn_mask_seqlens=self. - extra_long_seq_kwargs['attn_mask_seqlens'], - head_attn_nomask_seqlens=self. - extra_long_seq_kwargs['head_attn_nomask_seqlens'], - tail_attn_nomask_seqlens=self. 
- extra_long_seq_kwargs['tail_attn_nomask_seqlens'], - cp_prefill_mask=self.extra_long_seq_kwargs['cp_prefill_mask']) + long_seq_metadata = None + if self.cp_size > 1 and is_prefill > 0: + cp_kv_recover_idx = torch.zeros(num_actual_tokens_cp_full, + dtype=torch.int32, + device=self.device) + cp_kv_recover_idx.copy_(torch.tensor( + np.array(self.cp_kv_recover_idx).flatten().tolist()), + non_blocking=True) + self.cp_kv_recover_idx = cp_kv_recover_idx.to( + torch.float32).argsort().to(torch.int32) + + q_head_idx, q_tail_idx = [], [] + kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], [] + kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], [] + chunk_seqlens = [] + kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], [] + q_req_offset = 0 + kv_req_offset = 0 + q_head_chunk_id = self.cp_rank + q_tail_chunk_id = self.cp_size * 2 - 1 - self.cp_rank + for seq_len in seq_lens: + chunk_len = seq_len // 2 + chunk_seqlens.append(chunk_len) + q_head_idx.extend( + list(range(q_req_offset, q_req_offset + chunk_len))) + kv_with_q_head_nomask_idx.extend( + list( + range(kv_req_offset, + kv_req_offset + chunk_len * q_head_chunk_id))) + kv_with_q_head_mask_idx.extend( + list( + range( + kv_req_offset + chunk_len * q_head_chunk_id, + kv_req_offset + chunk_len * + (q_head_chunk_id + 1)))) + kv_with_q_head_nomask_seqlens.append(chunk_len * + q_head_chunk_id) + + q_tail_idx.extend( + list( + range(q_req_offset + chunk_len, + q_req_offset + chunk_len * 2))) + kv_with_q_tail_nomask_idx.extend( + list( + range(kv_req_offset, + kv_req_offset + chunk_len * q_tail_chunk_id))) + kv_with_q_tail_mask_idx.extend( + list( + range( + kv_req_offset + chunk_len * q_tail_chunk_id, + kv_req_offset + chunk_len * + (q_tail_chunk_id + 1)))) + kv_with_q_tail_nomask_seqlens.append(chunk_len * + q_tail_chunk_id) + + q_req_offset += seq_len + kv_req_offset += seq_len * self.cp_size + + # Convert lists to tensors and move to device + def _list_to_tensor(lst, device, dtype=torch.int32): + tensor_npu = torch.zeros(len(lst), dtype=dtype, device=device) + tensor_npu.copy_(torch.tensor(lst, dtype=dtype), + non_blocking=True) + return tensor_npu + + q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device) + q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device) + self.q_head_idx_tensor = q_head_idx_tensor + self.q_tail_idx_tensor = q_tail_idx_tensor + + q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor]) + q_full_idx = q_full_idx.to(torch.float32).argsort().to(torch.int32) + self.q_full_idx = q_full_idx + + self.kv_idx_names = { + 'kv_with_q_head_nomask_idx_tensor': kv_with_q_head_nomask_idx, + 'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx, + 'kv_with_q_tail_nomask_idx_tensor': kv_with_q_tail_nomask_idx, + 'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx + } + for key, value in self.kv_idx_names.items(): + tensor_npu = _list_to_tensor(value, self.device) + self.kv_idx_names[key] = tensor_npu + + attn_mask_seqlens = torch.tensor([chunk_seqlens, chunk_seqlens], + dtype=torch.int32) + head_attn_nomask_seqlens = torch.tensor( + [chunk_seqlens, kv_with_q_head_nomask_seqlens], + dtype=torch.int32) + tail_attn_nomask_seqlens = torch.tensor( + [chunk_seqlens, kv_with_q_tail_nomask_seqlens], + dtype=torch.int32) + if self.vllm_config.model_config.use_mla: + cp_prefill_mask = torch.triu( + torch.ones(512, 512, device=self.device, dtype=self.dtype), + 1) else: - long_seq_metadata = AscendCommonLongSequenceMetadata( - num_actual_tokens_cp_full=num_actual_tokens_cp_full, - 
num_computed_tokens_of_cp_sp=self.input_batch.block_table.get_split_computed_tokens( - self.input_batch.num_tokens[:self.input_batch.num_reqs]), ) + max_seq_len = max(seq_lens, default=0) + cp_prefill_mask = torch.triu( + torch.full((seq_lens.shape[0], max_seq_len, max_seq_len), + True, + device=self.device, + dtype=torch.bool), 1) + + self.extra_long_seq_kwargs = { + 'attn_mask_seqlens': attn_mask_seqlens, + 'head_attn_nomask_seqlens': head_attn_nomask_seqlens, + 'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens, + 'cp_prefill_mask': cp_prefill_mask + } + long_seq_metadata = AscendCommonLongSequenceMetadata( + cp_kv_recover_idx=self.cp_kv_recover_idx, + num_actual_tokens_cp_full=num_actual_tokens_cp_full, + num_computed_tokens_of_cp_sp=self.input_batch.block_table.get_split_computed_tokens( + self.input_batch.num_computed_tokens_cpu[:self.input_batch.num_reqs]), + q_head_idx_tensor=self.q_head_idx_tensor, + q_tail_idx_tensor=self.q_tail_idx_tensor, + q_full_idx=self.q_full_idx, + kv_with_q_head_nomask_idx_tensor=self. + kv_idx_names['kv_with_q_head_nomask_idx_tensor'], + kv_with_q_head_mask_idx_tensor=self. + kv_idx_names['kv_with_q_head_mask_idx_tensor'], + kv_with_q_tail_nomask_idx_tensor=self. + kv_idx_names['kv_with_q_tail_nomask_idx_tensor'], + kv_with_q_tail_mask_idx_tensor=self. + kv_idx_names['kv_with_q_tail_mask_idx_tensor'], + attn_mask_seqlens=self. + extra_long_seq_kwargs['attn_mask_seqlens'], + head_attn_nomask_seqlens=self. + extra_long_seq_kwargs['head_attn_nomask_seqlens'], + tail_attn_nomask_seqlens=self. + extra_long_seq_kwargs['tail_attn_nomask_seqlens'], + cp_prefill_mask=self.extra_long_seq_kwargs['cp_prefill_mask']) else: - long_seq_metadata = None - + long_seq_metadata = AscendCommonLongSequenceMetadata( + num_actual_tokens_cp_full=num_actual_tokens_cp_full, + num_computed_tokens_of_cp_sp=self.input_batch.block_table.get_split_computed_tokens( + self.input_batch.num_tokens[:self.input_batch.num_reqs]), ) + return long_seq_metadata def _get_cumsum_and_arange( From 8442fb87e934cf95b3e0f5ad0499fa3860891109 Mon Sep 17 00:00:00 2001 From: LookAround Date: Sun, 28 Sep 2025 20:08:33 +0800 Subject: [PATCH 10/12] [bug] fix block size bug Signed-off-by: LookAround --- vllm_ascend/worker/block_table.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py index 9de98df882..5d774f5134 100644 --- a/vllm_ascend/worker/block_table.py +++ b/vllm_ascend/worker/block_table.py @@ -260,9 +260,11 @@ def __init__(self, # must be multiplied by dcp_world_size. 
try: dcp_world_size = get_dcp_group().world_size + cp_world_size = get_cp_group().world_size except AssertionError: # DCP might not be initialized in testing dcp_world_size = 1 + cp_world_size = 1 if kernel_sizes is None: kernel_sizes = [[0]] * len(block_sizes) @@ -278,7 +280,7 @@ def __init__(self, self.block_tables = [ BlockTable( block_size, max_num_reqs, - max(cdiv(max_model_len, block_size * dcp_world_size), + max(cdiv(max_model_len, block_size * dcp_world_size * cp_world_size), 1 + num_speculative_tokens), max_num_batched_tokens, pin_memory, device, kernel_size_list) for block_size, kernel_size_list in zip(block_sizes, kernel_sizes) From 9022138b416108158ea57b04ac9a363563094071 Mon Sep 17 00:00:00 2001 From: LookAround Date: Mon, 29 Sep 2025 14:00:26 +0800 Subject: [PATCH 11/12] [optim] clean code Signed-off-by: LookAround --- examples/offline_inference_npu_long_seq.py | 11 +- vllm_ascend/core/schedule_config.py | 1 - vllm_ascend/worker/block_table.py | 8 +- vllm_ascend/worker/model_runner_v1.py | 460 ++++++++++----------- 4 files changed, 239 insertions(+), 241 deletions(-) diff --git a/examples/offline_inference_npu_long_seq.py b/examples/offline_inference_npu_long_seq.py index 111928402e..9d3b55da8a 100644 --- a/examples/offline_inference_npu_long_seq.py +++ b/examples/offline_inference_npu_long_seq.py @@ -5,6 +5,7 @@ from vllm import LLM, SamplingParams os.environ["VLLM_USE_MODELSCOPE"] = "True" +os.environ["VLLM_ASCEND_ENABLE_CP"] = "1" os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" if __name__ == "__main__": @@ -16,7 +17,7 @@ parser.add_argument('--model_path', type=str, default="deepseek-ai/DeepSeek-V2-Lite") parser.add_argument('--tp', type=int, default=2) parser.add_argument('--cp', type=int, default=2) - parser.add_argument('--dcp', type=int, default=2) + parser.add_argument('--dcp', type=int, default=1) parser.add_argument('--iter_times', type=int, default=1) args = parser.parse_args() @@ -25,7 +26,7 @@ "The capital of France is", "Hello, my name is Tom, I am", "The president of United States is", - "AI future is? What do you think about it? Can you give me some information or any thing you want?" 
+ "AI future is" ] sampling_params = SamplingParams(temperature = 0.8, top_p = 0.95, max_tokens=args.output_len) @@ -39,9 +40,9 @@ enable_prefix_caching=False, enable_expert_parallel=True, enable_chunked_prefill=False, - max_num_batched_tokens=args.input_len + 138, - max_model_len=args.input_len + args.output_len + 138, - additional_config={"ascend_scheduler_config": {"enabled": True}}, + max_num_batched_tokens=2048, + max_model_len=1024, + additional_config={"ascend_scheduler_config": {"enabled": False}}, max_num_seqs=1, block_size=128, gpu_memory_utilization=0.9 diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py index 9e180e00ab..dcd5d05562 100644 --- a/vllm_ascend/core/schedule_config.py +++ b/vllm_ascend/core/schedule_config.py @@ -19,7 +19,6 @@ from typing import Type, Union from vllm.config import SchedulerConfig -from vllm.distributed import get_dcp_group @dataclass diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py index 5d774f5134..2dd21ec67b 100644 --- a/vllm_ascend/worker/block_table.py +++ b/vllm_ascend/worker/block_table.py @@ -321,7 +321,7 @@ def get_split_computed_tokens(self, num_computed_tokens: np.ndarray) -> list[lis self.cp_world_size = get_cp_group().world_size if context_parallel_enable() else 1 self.dcp_world_size = get_dcp_group().world_size num_requests = len(num_computed_tokens) - num_computed_tokens_of_dcp_sp = [[ + num_computed_tokens_of_cp_dcp = [[ [0] * self.dcp_world_size for _ in range(self.cp_world_size) ] for _ in range(num_requests)] total_ranks = self.cp_world_size * self.dcp_world_size @@ -334,10 +334,10 @@ def get_split_computed_tokens(self, num_computed_tokens: np.ndarray) -> list[lis for rank_idx in range(total_ranks): cp_idx = rank_idx // self.dcp_world_size sp_idx = rank_idx % self.dcp_world_size - num_computed_tokens_of_dcp_sp[req_idx][cp_idx][sp_idx] = base + num_computed_tokens_of_cp_dcp[req_idx][cp_idx][sp_idx] = base if rank_idx < remainder: - num_computed_tokens_of_dcp_sp[req_idx][cp_idx][sp_idx] += 1 - return num_computed_tokens_of_dcp_sp + num_computed_tokens_of_cp_dcp[req_idx][cp_idx][sp_idx] += 1 + return num_computed_tokens_of_cp_dcp def clear(self) -> None: for block_table in self.block_tables: diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index d3baef8bd1..5b5af31590 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1131,236 +1131,6 @@ def _iter_mm_features(req_state: CachedRequestState): mm_embeds.append(mm_embeds_item) return mm_embeds - def _num_scheduled_tokens_prefill_cp(self, num_tokens, - num_computed_tokens, - cp_kv_recover_idx): - num_scheduled_tokens = num_tokens - num_computed_tokens - num_cp_padded_scheduled_tokens = cdiv( - num_scheduled_tokens, 2 * self.cp_size) * (2 * self.cp_size - ) # pad to 2*cp_size - cp_pad = num_cp_padded_scheduled_tokens - num_scheduled_tokens # 给sample用 - full_indices = list( - range(self.max_num_tokens * self.cp_size * self.dcp_size + - self.cp_size * self.dcp_size * self.max_num_reqs)) - chunk_size = num_cp_padded_scheduled_tokens // (2 * self.cp_size) - - # split position_ids (and use split position_ids to split input_ids afterwards) - req_position_cp = [] - req_position_cp.extend( - full_indices[self.cp_rank * chunk_size:(self.cp_rank + 1) * - chunk_size]) - req_position_cp.extend( - full_indices[num_cp_padded_scheduled_tokens - (self.cp_rank + 1) * - chunk_size:num_cp_padded_scheduled_tokens - - self.cp_rank * chunk_size]) - - # 
used to recover kv order in cp prefill (after all-gather kv and before storing kv_cache) - num_added_recover_tokens = len(cp_kv_recover_idx[0]) * self.cp_size - for rank in range(self.cp_size): - cp_kv_recover_idx[rank].extend( - full_indices[rank * chunk_size + - num_added_recover_tokens:(rank + 1) * chunk_size + - num_added_recover_tokens]) - cp_kv_recover_idx[rank].extend(full_indices[ - num_cp_padded_scheduled_tokens - (rank + 1) * chunk_size + - num_added_recover_tokens:num_cp_padded_scheduled_tokens - - rank * chunk_size + num_added_recover_tokens]) - - return req_position_cp, num_cp_padded_scheduled_tokens, cp_pad - - def _update_tokens_for_cp(self, tokens, scheduler_output: "SchedulerOutput"): - if not self.cp_size > 1: - return tokens - num_reqs = self.input_batch.num_reqs - self.num_cp_pads = np.empty(num_reqs, dtype=np.int32) - self.cp_kv_recover_idx: List[List[int]] = [[] - for _ in range(self.cp_size) - ] - self.position_cp = np.zeros(self.max_num_tokens, dtype=np.int32) - start_index = 0 - - for i, req_id in enumerate(self.input_batch.req_ids): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - is_prefill = num_tokens > 1 # todo: compare num prompt tokens and num tokens - if is_prefill: - # when cp > 1 & prefill, need to pad & split sequence here - req_position_cp, num_cp_padded_scheduled_tokens, self.num_cp_pads[ - i] = self._num_scheduled_tokens_prefill_cp( - num_tokens, - self.input_batch.num_computed_tokens_cpu[i], - self.cp_kv_recover_idx) - num_tokens = len(req_position_cp) - self.position_cp[start_index:start_index + - num_tokens] = req_position_cp - start_index += num_tokens - tokens[i] = num_tokens - else: - self.num_cp_pads[i] = 0 - self.position_cp[start_index:start_index + - num_tokens] = [idx for idx in range(num_tokens)] - start_index += num_tokens - return tokens - - def _update_logits_indices_for_cp(self, cu_num_tokens, scheduler_output: "SchedulerOutput"): - # todo: find a better way to get is_prefill - is_prefill = list( - scheduler_output.num_scheduled_tokens.values())[0] > 1 - num_reqs = self.input_batch.num_reqs - if self.cp_size > 1 and is_prefill: - # logits_indices = cu_num_tokens - num_cp_pads[:num_reqs] - 1 # if without all-gather and only sample on cp0 - logits_indices = cu_num_tokens * self.cp_size - self.num_cp_pads[: - num_reqs] - 1 - else: - logits_indices = cu_num_tokens - 1 - return logits_indices - - def _generate_cp_metadata(self, total_num_scheduled_tokens, seq_lens, scheduler_output: "SchedulerOutput"): - # TODO: find a better way to get is_prefill - is_prefill = list( - scheduler_output.num_scheduled_tokens.values())[0] > 1 - num_actual_tokens_cp_full = total_num_scheduled_tokens * ( - self.cp_size if is_prefill > 0 else 1) - long_seq_metadata = None - if self.cp_size > 1 and is_prefill > 0: - cp_kv_recover_idx = torch.zeros(num_actual_tokens_cp_full, - dtype=torch.int32, - device=self.device) - cp_kv_recover_idx.copy_(torch.tensor( - np.array(self.cp_kv_recover_idx).flatten().tolist()), - non_blocking=True) - self.cp_kv_recover_idx = cp_kv_recover_idx.to( - torch.float32).argsort().to(torch.int32) - - q_head_idx, q_tail_idx = [], [] - kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], [] - kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], [] - chunk_seqlens = [] - kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], [] - q_req_offset = 0 - kv_req_offset = 0 - q_head_chunk_id = self.cp_rank - q_tail_chunk_id = self.cp_size * 2 - 1 - self.cp_rank - for seq_len in seq_lens: - chunk_len = seq_len // 2 - 
chunk_seqlens.append(chunk_len) - q_head_idx.extend( - list(range(q_req_offset, q_req_offset + chunk_len))) - kv_with_q_head_nomask_idx.extend( - list( - range(kv_req_offset, - kv_req_offset + chunk_len * q_head_chunk_id))) - kv_with_q_head_mask_idx.extend( - list( - range( - kv_req_offset + chunk_len * q_head_chunk_id, - kv_req_offset + chunk_len * - (q_head_chunk_id + 1)))) - kv_with_q_head_nomask_seqlens.append(chunk_len * - q_head_chunk_id) - - q_tail_idx.extend( - list( - range(q_req_offset + chunk_len, - q_req_offset + chunk_len * 2))) - kv_with_q_tail_nomask_idx.extend( - list( - range(kv_req_offset, - kv_req_offset + chunk_len * q_tail_chunk_id))) - kv_with_q_tail_mask_idx.extend( - list( - range( - kv_req_offset + chunk_len * q_tail_chunk_id, - kv_req_offset + chunk_len * - (q_tail_chunk_id + 1)))) - kv_with_q_tail_nomask_seqlens.append(chunk_len * - q_tail_chunk_id) - - q_req_offset += seq_len - kv_req_offset += seq_len * self.cp_size - - # Convert lists to tensors and move to device - def _list_to_tensor(lst, device, dtype=torch.int32): - tensor_npu = torch.zeros(len(lst), dtype=dtype, device=device) - tensor_npu.copy_(torch.tensor(lst, dtype=dtype), - non_blocking=True) - return tensor_npu - - q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device) - q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device) - self.q_head_idx_tensor = q_head_idx_tensor - self.q_tail_idx_tensor = q_tail_idx_tensor - - q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor]) - q_full_idx = q_full_idx.to(torch.float32).argsort().to(torch.int32) - self.q_full_idx = q_full_idx - - self.kv_idx_names = { - 'kv_with_q_head_nomask_idx_tensor': kv_with_q_head_nomask_idx, - 'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx, - 'kv_with_q_tail_nomask_idx_tensor': kv_with_q_tail_nomask_idx, - 'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx - } - for key, value in self.kv_idx_names.items(): - tensor_npu = _list_to_tensor(value, self.device) - self.kv_idx_names[key] = tensor_npu - - attn_mask_seqlens = torch.tensor([chunk_seqlens, chunk_seqlens], - dtype=torch.int32) - head_attn_nomask_seqlens = torch.tensor( - [chunk_seqlens, kv_with_q_head_nomask_seqlens], - dtype=torch.int32) - tail_attn_nomask_seqlens = torch.tensor( - [chunk_seqlens, kv_with_q_tail_nomask_seqlens], - dtype=torch.int32) - if self.vllm_config.model_config.use_mla: - cp_prefill_mask = torch.triu( - torch.ones(512, 512, device=self.device, dtype=self.dtype), - 1) - else: - max_seq_len = max(seq_lens, default=0) - cp_prefill_mask = torch.triu( - torch.full((seq_lens.shape[0], max_seq_len, max_seq_len), - True, - device=self.device, - dtype=torch.bool), 1) - - self.extra_long_seq_kwargs = { - 'attn_mask_seqlens': attn_mask_seqlens, - 'head_attn_nomask_seqlens': head_attn_nomask_seqlens, - 'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens, - 'cp_prefill_mask': cp_prefill_mask - } - long_seq_metadata = AscendCommonLongSequenceMetadata( - cp_kv_recover_idx=self.cp_kv_recover_idx, - num_actual_tokens_cp_full=num_actual_tokens_cp_full, - num_computed_tokens_of_cp_sp=self.input_batch.block_table.get_split_computed_tokens( - self.input_batch.num_computed_tokens_cpu[:self.input_batch.num_reqs]), - q_head_idx_tensor=self.q_head_idx_tensor, - q_tail_idx_tensor=self.q_tail_idx_tensor, - q_full_idx=self.q_full_idx, - kv_with_q_head_nomask_idx_tensor=self. - kv_idx_names['kv_with_q_head_nomask_idx_tensor'], - kv_with_q_head_mask_idx_tensor=self. 
- kv_idx_names['kv_with_q_head_mask_idx_tensor'], - kv_with_q_tail_nomask_idx_tensor=self. - kv_idx_names['kv_with_q_tail_nomask_idx_tensor'], - kv_with_q_tail_mask_idx_tensor=self. - kv_idx_names['kv_with_q_tail_mask_idx_tensor'], - attn_mask_seqlens=self. - extra_long_seq_kwargs['attn_mask_seqlens'], - head_attn_nomask_seqlens=self. - extra_long_seq_kwargs['head_attn_nomask_seqlens'], - tail_attn_nomask_seqlens=self. - extra_long_seq_kwargs['tail_attn_nomask_seqlens'], - cp_prefill_mask=self.extra_long_seq_kwargs['cp_prefill_mask']) - else: - long_seq_metadata = AscendCommonLongSequenceMetadata( - num_actual_tokens_cp_full=num_actual_tokens_cp_full, - num_computed_tokens_of_cp_sp=self.input_batch.block_table.get_split_computed_tokens( - self.input_batch.num_tokens[:self.input_batch.num_reqs]), ) - - return long_seq_metadata - def _get_cumsum_and_arange( self, num_tokens: np.ndarray, @@ -3819,4 +3589,232 @@ def get_supported_pooling_tasks(self): return list(model.pooler.get_supported_tasks()) def _build_drafter_prepare_inputs_torchair_param(self): - return False \ No newline at end of file + return False + + def _num_scheduled_tokens_prefill_cp(self, num_tokens, + num_computed_tokens, + cp_kv_recover_idx): + num_scheduled_tokens = num_tokens - num_computed_tokens + num_cp_padded_scheduled_tokens = cdiv( + num_scheduled_tokens, 2 * self.cp_size) * (2 * self.cp_size + ) # pad to 2*cp_size + cp_pad = num_cp_padded_scheduled_tokens - num_scheduled_tokens # 给sample用 + full_indices = list( + range(self.max_num_tokens * self.cp_size * self.dcp_size + + self.cp_size * self.dcp_size * self.max_num_reqs)) + chunk_size = num_cp_padded_scheduled_tokens // (2 * self.cp_size) + + # split position_ids (and use split position_ids to split input_ids afterwards) + req_position_cp = [] + req_position_cp.extend( + full_indices[self.cp_rank * chunk_size:(self.cp_rank + 1) * + chunk_size]) + req_position_cp.extend( + full_indices[num_cp_padded_scheduled_tokens - (self.cp_rank + 1) * + chunk_size:num_cp_padded_scheduled_tokens - + self.cp_rank * chunk_size]) + + # used to recover kv order in cp prefill (after all-gather kv and before storing kv_cache) + num_added_recover_tokens = len(cp_kv_recover_idx[0]) * self.cp_size + for rank in range(self.cp_size): + cp_kv_recover_idx[rank].extend( + full_indices[rank * chunk_size + + num_added_recover_tokens:(rank + 1) * chunk_size + + num_added_recover_tokens]) + cp_kv_recover_idx[rank].extend(full_indices[ + num_cp_padded_scheduled_tokens - (rank + 1) * chunk_size + + num_added_recover_tokens:num_cp_padded_scheduled_tokens - + rank * chunk_size + num_added_recover_tokens]) + + return req_position_cp, num_cp_padded_scheduled_tokens, cp_pad + + def _update_tokens_for_cp(self, tokens, scheduler_output: "SchedulerOutput"): + if not self.cp_size > 1: + return tokens + num_reqs = self.input_batch.num_reqs + self.num_cp_pads = np.empty(num_reqs, dtype=np.int32) + self.cp_kv_recover_idx: List[List[int]] = [[] + for _ in range(self.cp_size) + ] + self.position_cp = np.zeros(self.max_num_tokens, dtype=np.int32) + start_index = 0 + + for i, req_id in enumerate(self.input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + is_prefill = num_tokens > 1 + if is_prefill: + # when cp > 1 & prefill, need to pad & split sequence here + req_position_cp, num_cp_padded_scheduled_tokens, self.num_cp_pads[ + i] = self._num_scheduled_tokens_prefill_cp( + num_tokens, + self.input_batch.num_computed_tokens_cpu[i], + self.cp_kv_recover_idx) + num_tokens = 
len(req_position_cp) + self.position_cp[start_index:start_index + + num_tokens] = req_position_cp + start_index += num_tokens + tokens[i] = num_tokens + else: + self.num_cp_pads[i] = 0 + self.position_cp[start_index:start_index + + num_tokens] = [idx for idx in range(num_tokens)] + start_index += num_tokens + return tokens + + def _update_logits_indices_for_cp(self, cu_num_tokens, scheduler_output: "SchedulerOutput"): + is_prefill = list( + scheduler_output.num_scheduled_tokens.values())[0] > 1 + num_reqs = self.input_batch.num_reqs + if self.cp_size > 1 and is_prefill: + # logits_indices = cu_num_tokens - num_cp_pads[:num_reqs] - 1 # if without all-gather and only sample on cp0 + logits_indices = cu_num_tokens * self.cp_size - self.num_cp_pads[: + num_reqs] - 1 + else: + logits_indices = cu_num_tokens - 1 + return logits_indices + + def _generate_cp_metadata(self, total_num_scheduled_tokens, seq_lens, scheduler_output: "SchedulerOutput"): + is_prefill = list( + scheduler_output.num_scheduled_tokens.values())[0] > 1 + num_actual_tokens_cp_full = total_num_scheduled_tokens * ( + self.cp_size if is_prefill > 0 else 1) + long_seq_metadata = None + if self.cp_size > 1 and is_prefill > 0: + cp_kv_recover_idx = torch.zeros(num_actual_tokens_cp_full, + dtype=torch.int32, + device=self.device) + cp_kv_recover_idx.copy_(torch.tensor( + np.array(self.cp_kv_recover_idx).flatten().tolist()), + non_blocking=True) + self.cp_kv_recover_idx = cp_kv_recover_idx.to( + torch.float32).argsort().to(torch.int32) + + q_head_idx, q_tail_idx = [], [] + kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], [] + kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], [] + chunk_seqlens = [] + kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], [] + q_req_offset = 0 + kv_req_offset = 0 + q_head_chunk_id = self.cp_rank + q_tail_chunk_id = self.cp_size * 2 - 1 - self.cp_rank + for seq_len in seq_lens: + chunk_len = seq_len // 2 + chunk_seqlens.append(chunk_len) + q_head_idx.extend( + list(range(q_req_offset, q_req_offset + chunk_len))) + kv_with_q_head_nomask_idx.extend( + list( + range(kv_req_offset, + kv_req_offset + chunk_len * q_head_chunk_id))) + kv_with_q_head_mask_idx.extend( + list( + range( + kv_req_offset + chunk_len * q_head_chunk_id, + kv_req_offset + chunk_len * + (q_head_chunk_id + 1)))) + kv_with_q_head_nomask_seqlens.append(chunk_len * + q_head_chunk_id) + + q_tail_idx.extend( + list( + range(q_req_offset + chunk_len, + q_req_offset + chunk_len * 2))) + kv_with_q_tail_nomask_idx.extend( + list( + range(kv_req_offset, + kv_req_offset + chunk_len * q_tail_chunk_id))) + kv_with_q_tail_mask_idx.extend( + list( + range( + kv_req_offset + chunk_len * q_tail_chunk_id, + kv_req_offset + chunk_len * + (q_tail_chunk_id + 1)))) + kv_with_q_tail_nomask_seqlens.append(chunk_len * + q_tail_chunk_id) + + q_req_offset += seq_len + kv_req_offset += seq_len * self.cp_size + + # Convert lists to tensors and move to device + def _list_to_tensor(lst, device, dtype=torch.int32): + tensor_npu = torch.zeros(len(lst), dtype=dtype, device=device) + tensor_npu.copy_(torch.tensor(lst, dtype=dtype), + non_blocking=True) + return tensor_npu + + q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device) + q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device) + self.q_head_idx_tensor = q_head_idx_tensor + self.q_tail_idx_tensor = q_tail_idx_tensor + + q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor]) + q_full_idx = q_full_idx.to(torch.float32).argsort().to(torch.int32) + self.q_full_idx = 
q_full_idx + + self.kv_idx_names = { + 'kv_with_q_head_nomask_idx_tensor': kv_with_q_head_nomask_idx, + 'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx, + 'kv_with_q_tail_nomask_idx_tensor': kv_with_q_tail_nomask_idx, + 'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx + } + for key, value in self.kv_idx_names.items(): + tensor_npu = _list_to_tensor(value, self.device) + self.kv_idx_names[key] = tensor_npu + + attn_mask_seqlens = torch.tensor([chunk_seqlens, chunk_seqlens], + dtype=torch.int32) + head_attn_nomask_seqlens = torch.tensor( + [chunk_seqlens, kv_with_q_head_nomask_seqlens], + dtype=torch.int32) + tail_attn_nomask_seqlens = torch.tensor( + [chunk_seqlens, kv_with_q_tail_nomask_seqlens], + dtype=torch.int32) + if self.vllm_config.model_config.use_mla: + cp_prefill_mask = torch.triu( + torch.ones(512, 512, device=self.device, dtype=self.dtype), + 1) + else: + max_seq_len = max(seq_lens, default=0) + cp_prefill_mask = torch.triu( + torch.full((seq_lens.shape[0], max_seq_len, max_seq_len), + True, + device=self.device, + dtype=torch.bool), 1) + + self.extra_long_seq_kwargs = { + 'attn_mask_seqlens': attn_mask_seqlens, + 'head_attn_nomask_seqlens': head_attn_nomask_seqlens, + 'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens, + 'cp_prefill_mask': cp_prefill_mask + } + long_seq_metadata = AscendCommonLongSequenceMetadata( + cp_kv_recover_idx=self.cp_kv_recover_idx, + num_actual_tokens_cp_full=num_actual_tokens_cp_full, + num_computed_tokens_of_cp_sp=self.input_batch.block_table.get_split_computed_tokens( + self.input_batch.num_computed_tokens_cpu[:self.input_batch.num_reqs]), + q_head_idx_tensor=self.q_head_idx_tensor, + q_tail_idx_tensor=self.q_tail_idx_tensor, + q_full_idx=self.q_full_idx, + kv_with_q_head_nomask_idx_tensor=self. + kv_idx_names['kv_with_q_head_nomask_idx_tensor'], + kv_with_q_head_mask_idx_tensor=self. + kv_idx_names['kv_with_q_head_mask_idx_tensor'], + kv_with_q_tail_nomask_idx_tensor=self. + kv_idx_names['kv_with_q_tail_nomask_idx_tensor'], + kv_with_q_tail_mask_idx_tensor=self. + kv_idx_names['kv_with_q_tail_mask_idx_tensor'], + attn_mask_seqlens=self. + extra_long_seq_kwargs['attn_mask_seqlens'], + head_attn_nomask_seqlens=self. + extra_long_seq_kwargs['head_attn_nomask_seqlens'], + tail_attn_nomask_seqlens=self. 
+ extra_long_seq_kwargs['tail_attn_nomask_seqlens'], + cp_prefill_mask=self.extra_long_seq_kwargs['cp_prefill_mask']) + else: + long_seq_metadata = AscendCommonLongSequenceMetadata( + num_actual_tokens_cp_full=num_actual_tokens_cp_full, + num_computed_tokens_of_cp_sp=self.input_batch.block_table.get_split_computed_tokens( + self.input_batch.num_tokens[:self.input_batch.num_reqs]), ) + + return long_seq_metadata \ No newline at end of file From 8dda1bafa7795ec63dd03f83b83257960f21cfed Mon Sep 17 00:00:00 2001 From: Delphine-Nic Date: Tue, 30 Sep 2025 10:02:39 +0800 Subject: [PATCH 12/12] GQA support pcp and dcp Signed-off-by: Delphine-Nic --- vllm_ascend/attention/attention_v1.py | 526 +++++++++++++++++++++----- vllm_ascend/ops/fused_moe.py | 5 +- 2 files changed, 443 insertions(+), 88 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 963a947d42..72736f18ed 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -20,12 +20,23 @@ from typing import ClassVar, List, Optional, Tuple, Type import torch +import torch.distributed as dist import torch.nn as nn +import torch.nn.functional as F import torch_npu + +import numpy as np + from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) from vllm.attention.backends.utils import CommonAttentionState from vllm.config import VllmConfig +from vllm.distributed import (get_context_model_parallel_rank, + get_context_model_parallel_world_size, + get_cp_group, + get_decode_context_model_parallel_rank, + get_decode_context_model_parallel_world_size, + get_dcp_group) from vllm.forward_context import ForwardContext, get_forward_context from vllm.utils import cdiv, direct_register_custom_op from vllm.v1.attention.backends.utils import AttentionCGSupport @@ -34,7 +45,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, maybe_save_kv_layer_to_connector, - wait_for_kv_layer_from_connector) + wait_for_kv_layer_from_connector, split_decodes_and_prefills) from vllm_ascend.compilation.acl_graph import get_graph_params from vllm_ascend.ops.attention import vanilla_chunked_prefill from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, @@ -66,10 +77,10 @@ def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]: @staticmethod def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, ) -> Tuple[int, ...]: if is_310p(): return (2, num_blocks, num_kv_heads * head_size // 16, block_size, @@ -78,18 +89,18 @@ def get_kv_cache_shape( @staticmethod def get_bsh_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, ) -> Tuple[int, ...]: return (2, num_blocks, block_size, num_kv_heads * head_size) @staticmethod def swap_blocks( - src_kv_cache: List[torch.Tensor], - dst_kv_cache: List[torch.Tensor], - src_to_dst: torch.Tensor, + src_kv_cache: List[torch.Tensor], + dst_kv_cache: List[torch.Tensor], + src_to_dst: torch.Tensor, ) -> None: src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1] dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1] @@ -103,8 +114,8 @@ def swap_blocks( @staticmethod def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, + kv_caches: List[torch.Tensor], + src_to_dists: 
torch.Tensor, ) -> None: src_indices = src_to_dists[:, 0] dst_indices = src_to_dists[:, 1] @@ -129,8 +140,36 @@ class AscendAttentionState(Enum): @dataclass -class AscendMetadata: +class AscendCpMetadata: + q_head_idx: torch.Tensor = None + q_tail_idx: torch.Tensor = None + kv_with_q_head_nomask_idx: torch.Tensor = None + kv_with_q_head_mask_idx: torch.Tensor = None + kv_with_q_tail_nomask_idx: torch.Tensor = None + kv_with_q_tail_mask_idx: torch.Tensor = None + attn_mask_seqlens: torch.Tensor = None + head_attn_nomask_seqlens: torch.Tensor = None + tail_attn_nomask_seqlens: torch.Tensor = None + q_full_idx: torch.Tensor = None + cp_prefill_mask: torch.Tensor = None + + +@dataclass +class AscendMetadataForPrefill: + """ Prefill Specific Metadata for Ascend""" + cp_metadata: Optional[AscendCpMetadata] = None + cp_kv_recover_idx: Optional[List[int]] = None + +@dataclass +class AscendMetadataForDecode: + """ Decode Specific Metadata for Ascend""" + num_computed_tokens_of_cp_sp: Optional[list[Optional[list[Optional[ + list[int]]]]]] = None + + +@dataclass +class AscendMetadata: # **************************** Basic Properties ************************** # attn_mask: Optional[torch.Tensor] = None # Current state of this attention run. @@ -138,6 +177,7 @@ class AscendMetadata: # Number of tokens excluding padding. num_actual_tokens: int = 0 + num_prefills: int = 0 # The sequence length per sequence. Sequence length means the computed # tokens + new tokens (is None if it is a decoding). @@ -164,6 +204,10 @@ class AscendMetadata: # *************************** Other Properties *************************** # enable_dbo_across_dp: bool = False + prefill: Optional[AscendMetadataForPrefill] = None + + decode: Optional[AscendMetadataForDecode] = None + class AscendAttentionMetadataBuilder: # Does this backend/builder support ACL Graphs for attention (default: no). 
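# Toy illustration (stand-in dataclasses, not the ones defined above) of how
# the nested metadata is meant to be consumed: the builder fills `prefill`
# and/or `decode`, and the attention impl branches on whichever is present.
from dataclasses import dataclass
from typing import Optional

@dataclass
class _PrefillPart:
    cp_kv_recover_idx: Optional[list] = None

@dataclass
class _DecodePart:
    num_computed_tokens_of_cp_sp: Optional[list] = None

@dataclass
class _Metadata:
    num_actual_tokens: int = 0
    num_prefills: int = 0
    prefill: Optional[_PrefillPart] = None
    decode: Optional[_DecodePart] = None

md = _Metadata(num_actual_tokens=8, num_prefills=1,
               prefill=_PrefillPart(cp_kv_recover_idx=[3, 2, 1, 0]))
assert md.prefill is not None and md.num_prefills > 0   # prefill path taken
assert md.decode is None                                # no decode requests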
@@ -175,11 +219,11 @@ class AscendAttentionMetadataBuilder: reorder_batch_threshold: ClassVar[int] = 1 def __init__( - self, - kv_cache_spec: AttentionSpec, - layer_names: list[str], - vllm_config: VllmConfig, - device: torch.device, + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, ): self.vllm_config = vllm_config self.model_config = vllm_config.model_config @@ -193,20 +237,35 @@ def reorder_batch(self, input_batch, return False def build( - self, - common_prefix_len: int, - common_attn_metadata: AscendCommonAttentionMetadata, - model: Optional[nn.Module] = None, + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + model: Optional[nn.Module] = None, ): num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1] + + decode_threshold = 1 + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold) + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_actual_tokens + block_table = common_attn_metadata.block_table_tensor query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] - slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] + + long_seq_metadata = common_attn_metadata.common_long_seq_metadata + num_actual_tokens_cp_full = long_seq_metadata.num_actual_tokens_cp_full if long_seq_metadata else None + if num_actual_tokens_cp_full is None: + num_actual_tokens_cp_full = num_actual_tokens + + slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens_cp_full].to(self.device, + non_blocking=True) + # slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] attn_mask = common_attn_metadata.attn_mask attn_state = common_attn_metadata.attn_state query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: @@ -225,6 +284,43 @@ def build( attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), ACL_FORMAT_FRACTAL_NZ) + prefill_metadata = None + if num_prefills > 0: + cp_metadata = None + common_long_seq_metadata = common_attn_metadata.common_long_seq_metadata + if common_long_seq_metadata is not None: + cp_metadata = AscendCpMetadata( + q_head_idx=common_long_seq_metadata.q_head_idx_tensor, + q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor, + kv_with_q_head_nomask_idx=common_long_seq_metadata. + kv_with_q_head_nomask_idx_tensor, + kv_with_q_head_mask_idx=common_long_seq_metadata. + kv_with_q_head_mask_idx_tensor, + kv_with_q_tail_nomask_idx=common_long_seq_metadata. + kv_with_q_tail_nomask_idx_tensor, + kv_with_q_tail_mask_idx=common_long_seq_metadata. + kv_with_q_tail_mask_idx_tensor, + attn_mask_seqlens=common_long_seq_metadata. + attn_mask_seqlens, + head_attn_nomask_seqlens=common_long_seq_metadata. + head_attn_nomask_seqlens, + tail_attn_nomask_seqlens=common_long_seq_metadata. 
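# Hedged sketch of the contract split_decodes_and_prefills() is relied on for
# in build(), assuming the batch was already reordered so decode requests
# (query length <= decode_threshold) come first, as reorder_batch_threshold = 1
# implies.  This is an illustration of the expected return values, not the
# vLLM helper itself.
def split_decodes_and_prefills_sketch(query_lens, decode_threshold: int = 1):
    num_decodes = 0
    for qlen in query_lens:
        if qlen > decode_threshold:
            break
        num_decodes += 1
    num_prefills = len(query_lens) - num_decodes
    num_decode_tokens = sum(query_lens[:num_decodes])
    num_prefill_tokens = sum(query_lens[num_decodes:])
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

# Two single-token decodes followed by one 512-token prefill.
assert split_decodes_and_prefills_sketch([1, 1, 512]) == (2, 1, 2, 512)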
+ tail_attn_nomask_seqlens, + q_full_idx=common_long_seq_metadata.q_full_idx, + cp_prefill_mask=common_long_seq_metadata.cp_prefill_mask) + prefill_metadata = AscendMetadataForPrefill( + cp_metadata=cp_metadata, + cp_kv_recover_idx=common_long_seq_metadata.cp_kv_recover_idx + if common_long_seq_metadata is not None else None) + # TODO + decode_metadata = None + if num_decodes > 0: + common_long_seq_metadata = common_attn_metadata.common_long_seq_metadata + if common_long_seq_metadata is not None: + num_computed_tokens_of_cp_sp = common_long_seq_metadata.num_computed_tokens_of_cp_sp + num_computed_tokens_of_cp_sp = np.array(num_computed_tokens_of_cp_sp) + decode_metadata = AscendMetadataForDecode(num_computed_tokens_of_cp_sp=num_computed_tokens_of_cp_sp) + attn_metadata = AscendMetadata( num_actual_tokens=num_actual_tokens, block_tables=block_table, @@ -235,13 +331,16 @@ def build( slot_mapping=slot_mapping, attn_mask=attn_mask, attn_state=attn_state, - enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp) + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, + num_prefills=num_prefills, + prefill=prefill_metadata, + decode=decode_metadata) return attn_metadata def build_for_graph_capture( - self, - common_attn_metadata: AscendCommonAttentionMetadata, - attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, + self, + common_attn_metadata: AscendCommonAttentionMetadata, + attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, ): if attn_state == AscendAttentionState.DecodeOnly: attn_metadata = self.build( @@ -260,18 +359,18 @@ def build_for_graph_capture( class AscendAttentionBackendImpl(AttentionImpl): def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - **kwargs, + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + **kwargs, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -292,14 +391,22 @@ def __init__( self.key_cache = None self.value_cache = None + self.cp_size = get_context_model_parallel_world_size() + self.cp_rank = get_context_model_parallel_rank() if self.cp_size > 1 else 0 + self.cp_group = get_cp_group().device_group if self.cp_size > 1 else None + + self.dcp_size = get_decode_context_model_parallel_world_size() + self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0 + self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None + def _forward_prefill_no_cache( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_metadata: AscendMetadata, - output: Optional[torch.Tensor] = None, - num_tokens=0, + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AscendMetadata, + output: Optional[torch.Tensor] = None, + num_tokens=0, ) -> torch.Tensor: assert attn_metadata is not None assert attn_metadata.attn_mask is not None @@ -330,10 +437,10 @@ def _forward_prefill_no_cache( return output[:num_tokens, :, :] def _forward_prefill_cache_hit( - self, - query: torch.Tensor, - attn_metadata: AscendMetadata, - output: Optional[torch.Tensor] = None, + self, + 
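# Illustrative view of the table that num_computed_tokens_of_cp_sp becomes once
# np.array() is applied: shape [num_reqs, cp_size, dcp_size], so indexing
# [:, cp_rank, dcp_rank] yields the KV length this worker actually holds per
# decode request.  The even round-robin split below is only an assumption for
# the sketch; the real split comes from block_table.get_split_computed_tokens().
import numpy as np

def split_computed_tokens(num_computed, cp_size: int, dcp_size: int) -> np.ndarray:
    table = np.zeros((len(num_computed), cp_size, dcp_size), dtype=np.int64)
    for req, total in enumerate(num_computed):
        base, rem = divmod(total, cp_size * dcp_size)
        flat = np.full(cp_size * dcp_size, base, dtype=np.int64)
        flat[:rem] += 1                        # leftovers go to the first shards
        table[req] = flat.reshape(cp_size, dcp_size)
    return table

table = split_computed_tokens([10, 7], cp_size=2, dcp_size=2)
assert table.sum(axis=(1, 2)).tolist() == [10, 7]
lengths_seen_here = table[:, 0, 1]             # cp_rank=0, dcp_rank=1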
query: torch.Tensor, + attn_metadata: AscendMetadata, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert attn_metadata is not None assert attn_metadata.attn_mask is not None @@ -357,17 +464,17 @@ def _forward_prefill_cache_hit( return output def _forward_decode_only( - self, - query: torch.Tensor, - attn_metadata: AscendMetadata, - output: Optional[torch.Tensor] = None, + self, + query: torch.Tensor, + attn_metadata: AscendMetadata, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: if is_310p(): # seq_lens_tensor needs to be transferred to the device for 310P. attn_metadata.seq_lens = \ attn_metadata.seq_lens.to(device=query.device) if self.sliding_window is not None and attn_metadata.seq_lens.shape[ - 0] == query.size(0): + 0] == query.size(0): batch_size = attn_metadata.seq_lens.shape[0] block_size = 128 query = query.view(batch_size, 1, self.num_heads * self.head_size) @@ -444,10 +551,10 @@ def _forward_decode_only( return output def _forward_v1_style( - self, - query: torch.Tensor, - attn_metadata: AscendMetadata, - output: Optional[torch.Tensor] = None, + self, + query: torch.Tensor, + attn_metadata: AscendMetadata, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: # Use chunked prefill for head size 192 scenario, like deepseek # paged_attention_splitfuse maybe crash at such scenario. @@ -520,16 +627,243 @@ def _forward_v1_style( out=output) return output + def _pack_tnd_2_bsnd(self, tensor_tnd: torch.Tensor, + lengths: List[int]) -> torch.Tensor: + max_len = max(lengths) + splits = torch.split(tensor_tnd, lengths, dim=0) + + padded = [] + for s in splits: + pad_len = max_len - s.shape[0] + s_pad = F.pad(s, (0, 0, 0, 0, 0, pad_len)) + padded.append(s_pad) + + tensor_bsnd = torch.stack(padded, dim=0) + return tensor_bsnd + + def _unpack_bsnd_2_tnd(self, tensor_bsnd: torch.Tensor, + lengths: List[int]) -> torch.Tensor: + slices = [] + for i, length in enumerate(lengths): + slices.append(tensor_bsnd[i, :length]) + tensor_tnd = torch.cat(slices, dim=0) + return tensor_tnd + + def _attention_with_nomask_and_mask(self, q: torch.Tensor, + q_seqlens: List[int], + k_nomask: torch.Tensor, + v_nomask: torch.Tensor, + kv_seqlens_nomask: List[int], + k_mask: torch.Tensor, + v_mask: torch.Tensor, + kv_seqlens_mask: List[int], + mask: torch.Tensor) -> torch.Tensor: + q = self._pack_tnd_2_bsnd(q, q_seqlens) + + # nomask Attention + if k_nomask is not None: + attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score( + q, + self._pack_tnd_2_bsnd(k_nomask, kv_seqlens_nomask), + self._pack_tnd_2_bsnd(v_nomask, kv_seqlens_nomask), + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout="BSND", + atten_mask=None, + scale=self.scale, + sparse_mode=0, + antiquant_mode=0, + antiquant_scale=None, + softmax_lse_flag=True, + actual_seq_lengths_kv=kv_seqlens_nomask, + actual_seq_lengths=q_seqlens) + attn_out_nomask = self._unpack_bsnd_2_tnd(attn_out_nomask, + q_seqlens) + # (B, N, Q_S, 1) -> (B, Q_S, N, 1) -> (T, N, 1) + attn_lse_nomask = self._unpack_bsnd_2_tnd( + attn_lse_nomask.permute([0, 2, 1, 3]), q_seqlens) + + # mask Attention + attn_out_mask, attn_lse_mask = torch.ops.npu.npu_fused_infer_attention_score( + q, + self._pack_tnd_2_bsnd(k_mask, kv_seqlens_mask), + self._pack_tnd_2_bsnd(v_mask, kv_seqlens_mask), + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout="BSND", + atten_mask=mask, + scale=self.scale, + sparse_mode=0, + antiquant_mode=0, + antiquant_scale=None, + 
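# Quick CPU round-trip check of the TND <-> BSND packing defined above:
# variable-length sequences are padded to the batch maximum for the BSND
# attention call and the padding is stripped again afterwards.  Standalone
# copies of the two helpers, plain torch, no NPU required.
import torch
import torch.nn.functional as F

def pack_tnd_2_bsnd(x: torch.Tensor, lengths):
    max_len = max(lengths)
    return torch.stack(
        [F.pad(s, (0, 0, 0, 0, 0, max_len - s.shape[0]))
         for s in torch.split(x, lengths, dim=0)], dim=0)

def unpack_bsnd_2_tnd(x: torch.Tensor, lengths):
    return torch.cat([x[i, :length] for i, length in enumerate(lengths)], dim=0)

lengths = [3, 5, 2]
tnd = torch.randn(sum(lengths), 4, 16)       # (T, num_heads, head_size)
bsnd = pack_tnd_2_bsnd(tnd, lengths)         # (B=3, S=5, N=4, D=16)
assert bsnd.shape == (3, 5, 4, 16)
assert torch.equal(unpack_bsnd_2_tnd(bsnd, lengths), tnd)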
softmax_lse_flag=True, + actual_seq_lengths_kv=kv_seqlens_mask, + actual_seq_lengths=q_seqlens) + attn_out_mask = self._unpack_bsnd_2_tnd(attn_out_mask, q_seqlens) + attn_lse_mask = self._unpack_bsnd_2_tnd( + attn_lse_mask.permute([0, 2, 1, 3]), q_seqlens) + + # update + output = attn_out_mask + if k_nomask is not None: + output, _ = self._update_out_and_lse( + torch.stack([attn_out_nomask, attn_out_mask], dim=0), + torch.stack([attn_lse_nomask, attn_lse_mask], dim=0)) + + return output + + def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AscendMetadata) -> torch.Tensor: + assert attn_metadata is not None + assert attn_metadata.prefill is not None + assert attn_metadata.prefill.cp_metadata is not None + # Use precomputed indices from the metadata (already converted to tensors and on device) + q_head_idx = attn_metadata.prefill.cp_metadata.q_head_idx + q_tail_idx = attn_metadata.prefill.cp_metadata.q_tail_idx + kv_with_q_head_nomask_idx = attn_metadata.prefill.cp_metadata.kv_with_q_head_nomask_idx + kv_with_q_head_mask_idx = attn_metadata.prefill.cp_metadata.kv_with_q_head_mask_idx + kv_with_q_tail_nomask_idx = attn_metadata.prefill.cp_metadata.kv_with_q_tail_nomask_idx + kv_with_q_tail_mask_idx = attn_metadata.prefill.cp_metadata.kv_with_q_tail_mask_idx + attn_mask_seqlens = attn_metadata.prefill.cp_metadata.attn_mask_seqlens + head_attn_nomask_seqlens = attn_metadata.prefill.cp_metadata.head_attn_nomask_seqlens + tail_attn_nomask_seqlens = attn_metadata.prefill.cp_metadata.tail_attn_nomask_seqlens + mask = attn_metadata.prefill.cp_metadata.cp_prefill_mask + + # 1. Attention calculation in the first half of Q in load balancing + output_head = self._attention_with_nomask_and_mask( + q=torch.index_select(query, 0, q_head_idx), + q_seqlens=attn_mask_seqlens[0].tolist(), + k_nomask=torch.index_select(key, 0, kv_with_q_head_nomask_idx) + if self.cp_rank > 0 else None, + v_nomask=torch.index_select(value, 0, kv_with_q_head_nomask_idx) + if self.cp_rank > 0 else None, + kv_seqlens_nomask=head_attn_nomask_seqlens[1].tolist(), + k_mask=torch.index_select(key, 0, kv_with_q_head_mask_idx), + v_mask=torch.index_select(value, 0, kv_with_q_head_mask_idx), + kv_seqlens_mask=attn_mask_seqlens[0].tolist(), + mask=mask) + + # 2. the Attention calculation in the latter half of Q in load balancing + # cp_rank0: Q3*KV0~KV2 + Q3*KV3 + # cp_rank1: Q2*KV0~KV1 + Q2*KV2 + output_tail = self._attention_with_nomask_and_mask( + q=torch.index_select(query, 0, q_tail_idx), + q_seqlens=attn_mask_seqlens[0].tolist(), + k_nomask=torch.index_select(key, 0, kv_with_q_tail_nomask_idx), + v_nomask=torch.index_select(value, 0, kv_with_q_tail_nomask_idx), + kv_seqlens_nomask=tail_attn_nomask_seqlens[1].tolist(), + k_mask=torch.index_select(key, 0, kv_with_q_tail_mask_idx), + v_mask=torch.index_select(value, 0, kv_with_q_tail_mask_idx), + kv_seqlens_mask=attn_mask_seqlens[0].tolist(), + mask=mask) + + # 3. Combine the output of the first half and second half. 
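# Numeric sanity check, plain torch on CPU, of the decomposition used by
# _attention_with_nomask_and_mask(): attending separately to the fully visible
# history ("nomask") and to the causal local chunk ("mask"), then merging the
# two partial results with their log-sum-exp, reproduces attention over the
# concatenated KV.  This is a reference check, not the NPU kernel path.
import torch

torch.manual_seed(0)
S, L0, d = 4, 6, 8
scale = d ** -0.5
q = torch.randn(S, d)
k0, v0 = torch.randn(L0, d), torch.randn(L0, d)    # history, no mask
k1, v1 = torch.randn(S, d), torch.randn(S, d)      # local chunk, causal

def attn(q, k, v, mask=None):
    logits = (q @ k.T) * scale
    if mask is not None:
        logits = logits.masked_fill(mask, float("-inf"))
    lse = torch.logsumexp(logits, dim=-1, keepdim=True)
    return torch.exp(logits - lse) @ v, lse

causal = torch.triu(torch.ones(S, S), 1).bool()
out0, lse0 = attn(q, k0, v0)
out1, lse1 = attn(q, k1, v1, causal)
lse = torch.logsumexp(torch.stack([lse0, lse1]), dim=0)
merged = torch.exp(lse0 - lse) * out0 + torch.exp(lse1 - lse) * out1

full_mask = torch.cat([torch.zeros(S, L0, dtype=torch.bool), causal], dim=1)
ref, _ = attn(q, torch.cat([k0, k1]), torch.cat([v0, v1]), full_mask)
assert torch.allclose(merged, ref, atol=1e-5)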
+ q_full_idx = attn_metadata.prefill.cp_metadata.q_full_idx + output = torch.index_select( + torch.cat([output_head, output_tail], dim=0), 0, q_full_idx) + return output + + def _update_out_and_lse(self, out_list: torch.Tensor, + lse_list: torch.Tensor) -> torch.Tensor: + """LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i) + Args: + out_list: shape = [N, batch_size, num_heads, head_size] + lse_list: shape = [N, batch_size, num_heads, 1] + Returns: + out_final: shape = [batch_size, num_heads, head_size] + lse_final: shape = [batch_size, num_heads, 1] + """ + lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False) + out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list, + dim=0) + return out_final, lse_final + + def _forward_decode_dcp_cp(self, query: torch.Tensor, + attn_metadata: AscendMetadata) -> torch.Tensor: + assert self.key_cache is not None + assert self.value_cache is not None + + if self.dcp_size > 1: + query = get_dcp_group().all_gather( + query, 1) + num_heads = self.num_heads * self.dcp_size + else: + num_heads = self.num_heads + + # 1. Compute out&lse by "npu_fused_infer_attention_score" + attn_out, attn_lse = torch.ops.npu.npu_fused_infer_attention_score( + query.view(query.shape[0], 1, query.shape[1], query.shape[2]), + # [b,num_heads,head_size] -> [b,1,num_heads,head_size] + self.key_cache.view(self.key_cache.shape[0], + self.key_cache.shape[1], -1), + self.value_cache.view(self.key_cache.shape[0], + self.key_cache.shape[1], -1), + num_heads=num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout="BSND", + atten_mask=None, + scale=self.scale, + antiquant_mode=0, + antiquant_scale=None, + softmax_lse_flag=True, + block_table=attn_metadata.block_tables, + block_size=self.key_cache.shape[1], + actual_seq_lengths_kv=attn_metadata.decode.num_computed_tokens_of_cp_sp[:, self.cp_rank, self.dcp_rank], + ) + + attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2], attn_out.shape[3]) + attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1) + if self.dcp_size > 1: + # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1] + attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1) + # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs] + attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous() + attn_out_lse_all2all = torch.empty_like(attn_out_lse) + dist.all_to_all_single(attn_out_lse_all2all, + attn_out_lse, + group=self.dcp_group) + # permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1] + attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1]) + attn_out_lse_split_on_seq = list( + torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1)) + + attn_out_lse_split_dcp = torch.stack( + attn_out_lse_split_on_seq, + dim=0) # [dcp, batch_size, num_heads, head_size+1] + # Update out&lse + attn_out_split_dcp, attn_lse_split_dcp = torch.split( + attn_out_lse_split_dcp, [self.head_size, 1], dim=-1) + attn_out, attn_lse = self._update_out_and_lse(attn_out_split_dcp, + attn_lse_split_dcp) + if self.cp_size > 1: + # 2. Concat out&lse: [bs,num_heads,head_size] + [bs,num_heads,1] -> [bs,num_heads,head_size+1] + attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1) + # 3. AllGather out&lse within CP group + attn_out_lse_list = [ + torch.empty_like(attn_out_lse) for _ in range(self.cp_size) + ] + dist.all_gather(attn_out_lse_list, attn_out_lse, group=self.cp_group) + # 4. 
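# Minimal CPU illustration of the out/lse packing used on the decode path:
# the per-shard output (head_size) and its log-sum-exp (1) are concatenated so
# a single all_to_all moves both, then split apart again before the merge.
# No process group here; the collective is simulated by a copy.
import torch

bs, num_heads, head_size, dcp_size = 2, 8, 64, 2
attn_out = torch.randn(bs, num_heads, head_size)
attn_lse = torch.randn(bs, num_heads, 1)

packed = torch.cat([attn_out, attn_lse], dim=-1)    # [bs, heads, head_size+1]
packed = packed.permute(1, 2, 0).contiguous()       # [heads, head_size+1, bs]
received = packed.clone()                           # stand-in for all_to_all_single
received = received.permute(2, 0, 1)                # back to [bs, heads, head_size+1]

shards = torch.stack(torch.chunk(received, dcp_size, dim=1), dim=0)
out_shards, lse_shards = torch.split(shards, [head_size, 1], dim=-1)
assert out_shards.shape == (dcp_size, bs, num_heads // dcp_size, head_size)
assert lse_shards.shape == (dcp_size, bs, num_heads // dcp_size, 1)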
Update out&lse + attn_out_lse_allgather = torch.stack( + attn_out_lse_list, + dim=0) # [cp, batch_size, num_heads, head_size+1] + attn_out_allgather, attn_lse_allgather = torch.split( + attn_out_lse_allgather, [self.head_size, 1], dim=-1) + attn_out, _ = self._update_out_and_lse(attn_out_allgather, + attn_lse_allgather) + return attn_out + def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Tuple[torch.Tensor], - attn_metadata: AscendMetadata, - output: Optional[torch.Tensor] = None, - trace_flag: bool = True, + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[torch.Tensor], + attn_metadata: AscendMetadata, + output: Optional[torch.Tensor] = None, + trace_flag: bool = True, ) -> torch.Tensor: """Forward pass with Ascend attention. Args: @@ -587,28 +921,46 @@ def forward( # TODO: Remove this contiguous in the future. value = value.contiguous() + if self.cp_size > 1 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + kv = torch.cat([key, value], dim=-1) # [] + kv_list = [torch.empty_like(kv) for _ in range(self.cp_size)] + dist.all_gather(kv_list, kv, self.cp_group) + all_kv = torch.cat(kv_list, dim=0) + cp_kv_recover_idx = attn_metadata.prefill.cp_kv_recover_idx if attn_metadata.prefill else None + all_kv = torch.index_select(all_kv, 0, cp_kv_recover_idx) + key, value = all_kv.split([self.head_size, self.head_size], + dim=-1) + if len(kv_cache) > 1: if self.key_cache is None: self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] slots = attn_metadata.slot_mapping torch_npu._npu_reshape_and_cache( - key=key[:num_actual_tokens], - value=value[:num_actual_tokens], + key=key if self.cp_size > 1 else key[:num_actual_tokens], + value=value + if self.cp_size > 1 else value[:num_actual_tokens], key_cache=self.key_cache, value_cache=self.value_cache, slot_indices=slots) # V0-Style scheduler situation. if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: - output = self._forward_prefill_no_cache( - query, key, value, attn_metadata, output, num_tokens) + if self.cp_size > 1: + output = self._forward_prefill_cp(query, key, value, + attn_metadata) + else: + output = self._forward_prefill_no_cache( + query, key, value, attn_metadata, output, num_tokens) elif attn_metadata.attn_state == \ - AscendAttentionState.PrefillCacheHit: + AscendAttentionState.PrefillCacheHit: output = self._forward_prefill_cache_hit( query, attn_metadata, output) elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: - output = self._forward_decode_only(query, attn_metadata, - output) + if self.cp_size * self.dcp_size > 1: + output = self._forward_decode_dcp_cp(query, attn_metadata) + else: + output = self._forward_decode_only(query, attn_metadata, + output) # Normal V1 situation. 
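# CPU sketch of the gather-and-reorder done for CP prefill above: each CP rank
# holds an interleaved token shard, the shards are all-gathered and
# concatenated, and cp_kv_recover_idx puts tokens back into the original
# request order before the cache write.  The 2-chunks-per-rank layout matches
# the load-balanced split; the index construction here is illustrative only.
import torch

cp_size, chunk = 2, 2                         # 2*cp_size chunks of `chunk` tokens
tokens = torch.arange(2 * cp_size * chunk)    # original order: 0..7
shards = [torch.cat([tokens[r * chunk:(r + 1) * chunk],
                     tokens[(2 * cp_size - 1 - r) * chunk:(2 * cp_size - r) * chunk]])
          for r in range(cp_size)]            # rank r holds chunks r and 2*cp_size-1-r
gathered = torch.cat(shards)                  # what all_gather + cat produces

recover_idx = torch.empty_like(tokens)        # gathered position of each original token
offset = 0
for r in range(cp_size):
    for c in (r, 2 * cp_size - 1 - r):
        recover_idx[c * chunk:(c + 1) * chunk] = torch.arange(offset, offset + chunk)
        offset += chunk
assert torch.equal(torch.index_select(gathered, 0, recover_idx), tokens)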
else: if torch.version.cann.startswith("8.3"): @@ -627,11 +979,11 @@ def forward( def unified_ascend_attention_with_output( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - output: torch.Tensor, - layer_name: str, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, ) -> None: wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() @@ -653,11 +1005,11 @@ def unified_ascend_attention_with_output( def unified_attention_with_output_fake( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - output: torch.Tensor, - layer_name: str, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, ) -> None: return diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 533c20ba14..d149819f59 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -21,7 +21,7 @@ import torch import torch_npu from vllm.config import get_current_vllm_config -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import (get_tensor_model_parallel_world_size, get_context_model_parallel_world_size) from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, get_tp_group) from vllm.forward_context import get_forward_context @@ -162,6 +162,7 @@ def __init__( tp_size: Optional[int] = None, ep_size: Optional[int] = None, dp_size: Optional[int] = None, + cp_size: Optional[int] = None, prefix: str = "", custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", @@ -206,6 +207,8 @@ def __init__( get_tensor_model_parallel_world_size()), dp_size_=(dp_size if dp_size is not None else get_dp_group().world_size), + cp_size_=(cp_size if cp_size is not None else + get_context_model_parallel_world_size()), vllm_parallel_config=vllm_config.parallel_config) self.top_k = top_k
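# Hedged sketch of the size-resolution pattern used when threading cp_size into
# the MoE parallel config in the hunk above: an explicit argument wins,
# otherwise the value falls back to the parallel group's world size.
# `get_group_world_size` is a hypothetical stand-in for the vLLM accessor.
from typing import Callable, Optional

def resolve_parallel_size(explicit: Optional[int],
                          get_group_world_size: Callable[[], int]) -> int:
    return explicit if explicit is not None else get_group_world_size()

assert resolve_parallel_size(4, lambda: 8) == 4     # caller override
assert resolve_parallel_size(None, lambda: 8) == 8  # fall back to the CP group size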