From c2030f932d79396ad37d636aa112636f15e914c8 Mon Sep 17 00:00:00 2001
From: zhaozx-cn
Date: Sun, 28 Sep 2025 11:04:56 +0800
Subject: [PATCH 1/2] shared expert DP for DeepSeek

Signed-off-by: zhaozx-cn
---
 vllm_ascend/ascend_forward_context.py |  3 +++
 vllm_ascend/attention/mla_v1.py       | 15 ++++++---------
 vllm_ascend/models/deepseek_v2.py     |  2 +-
 vllm_ascend/models/layers/mla.py      | 15 +++++++--------
 vllm_ascend/ops/common_fused_moe.py   |  3 ++-
 vllm_ascend/ops/layernorm.py          |  2 +-
 vllm_ascend/ops/linear_op.py          | 11 ++++++++---
 vllm_ascend/platform.py               |  5 +----
 vllm_ascend/utils.py                  |  3 ++-
 9 files changed, 31 insertions(+), 28 deletions(-)

diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index 8888b70186..7886224f77 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -12,6 +12,7 @@
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.utils import enable_sp
+from vllm_ascend.ascend_config import get_ascend_config
 
 
 class FusedMoEState(Enum):
@@ -110,6 +111,8 @@ def set_ascend_forward_context(
             tp_world_size > 1 and \
             num_tokens is not None and num_tokens > 1000
 
+        if get_ascend_config().enable_shared_expert_dp:
+            sp_enabled = True
         if sp_enabled:
             pad_size = (tp_world_size -
                         (num_tokens % tp_world_size)) % tp_world_size
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 73cbae6207..8c9de0ba62 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -887,9 +887,10 @@ def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
         kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
 
         # Process for shared_expert_dp
-        if need_gather_q_kv:
-            q_c = get_tp_group().all_gather(q_c, 0)
-            kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
+        q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            q_c, need_gather_q_kv)
+        kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            kv_no_split, need_gather_q_kv)
         decode_preprocess_res = None
         prefill_preprocess_res = None
         if has_prefill:
@@ -1011,9 +1012,7 @@ def forward(
                          enabled=self.enable_prefetch)
 
             output[...] = self.o_proj(
-                o_proj_input,
-                is_prefill=prefill_preprocess_res is not None,
-                is_force_scatter=self.enable_shared_expert_dp)[0]
+                o_proj_input)[0]
         else:
             with torch.npu.stream(current_ms_metadata.comm_stream):
                 npu_prefetch(self.o_proj.weight,
@@ -1021,9 +1020,7 @@ def forward(
                              max_size=MAX_O_PROJ_PREFETCH_SIZE,
                              enabled=self.enable_prefetch)
                 output[...] = self.o_proj(
-                    o_proj_input,
-                    is_prefill=prefill_preprocess_res is not None,
-                    is_force_scatter=self.enable_shared_expert_dp)[0]
+                    o_proj_input)[0]
             current_ms_metadata.after_comm_event.record()
 
         del o_proj_input
diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index 988de33460..aeef81dba8 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -181,7 +181,7 @@ def __init__(
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.kv_b_proj")
-        self.o_proj = CustomDeepseekV2RowParallelLinear(
+        self.o_proj = RowParallelLinear(
             self.num_heads * self.v_head_dim,
             self.hidden_size,
             bias=False,
diff --git a/vllm_ascend/models/layers/mla.py b/vllm_ascend/models/layers/mla.py
index 57c91bd278..8bb6e098c6 100644
--- a/vllm_ascend/models/layers/mla.py
+++ b/vllm_ascend/models/layers/mla.py
@@ -120,18 +120,17 @@ def forward(
             hidden_states: torch.Tensor,
             kv_cache: Optional[torch.Tensor] = None,
             attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
-        num_tokens = hidden_states.shape[0]
+        forward_context = get_forward_context()
+        sp_enabled = forward_context.sp_enabled
         need_gather_q_kv = False
-        if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
-            # Simulate all gather to calculate output shape
-            num_tokens = num_tokens * self.tp_size
+        if sp_enabled and self.debug_layer_idx > 0:
             need_gather_q_kv = True
-        if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
+        if not self.enable_shared_expert_dp or self.debug_layer_idx > 0:
             output_shape = hidden_states.shape
         else:
-            rows = num_tokens // self.tp_size
-            if num_tokens % self.tp_size:
-                rows += 1
+            num_tokens = hidden_states.shape[0]
+            pad_size = forward_context.pad_size
+            rows = (num_tokens + pad_size) // self.tp_size
             output_shape = (rows, hidden_states.shape[1])
         # FIXME: This does not seem right, should make sure the buffer is fixed
         output = torch.empty(output_shape,
diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index 554b40e002..faf8e80db6 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -346,7 +346,8 @@ def forward(
         forward_context = get_forward_context()
         moe_comm_type = forward_context.moe_comm_type
         if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
-            shared_out = tensor_model_parallel_all_reduce(shared_out)
+            if not get_forward_context().sp_enabled:
+                shared_out = tensor_model_parallel_all_reduce(shared_out)
 
         _, fused_out = AscendFusedMoE.forward(
             self,
diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py
index da48362f46..d05c3dc4ae 100644
--- a/vllm_ascend/ops/layernorm.py
+++ b/vllm_ascend/ops/layernorm.py
@@ -65,7 +65,7 @@ def forward_oot(
 
         if residual is not None:
             residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
-            assert x.size(0) == residual.size(0)
+            #assert x.size(0) == residual.size(0)
             x, residual = _addrmsnorm_forward_oot(
                 self, x, residual, self.next_need_quant_fusion_linear)
             return x, residual
diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index f6feadde60..42be57a9d9 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -46,6 +46,7 @@
 from vllm.distributed import split_tensor_along_last_dim
 from vllm.distributed.parallel_state import get_tp_group
 
+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                     get_otp_group)
 from vllm_ascend.utils import (dense_optim_enable, enable_sp,
@@ -418,11 +419,13 @@ def get_column_parallel_op(
         SequenceMergedColumnParallelOp,
         SequenceQKVParallelOp,
     ]] = None
-    if "gate_up_proj" in prefix and mlp_tp_enable():
+    if "shared_experts.gate_up_proj" in prefix and enable_sp():
+        return None, 0, 1
+    elif "gate_up_proj" in prefix and mlp_tp_enable():
         custom_op = MLPColumnParallelOp(layer)
     elif "gate_up_proj" in prefix and enable_sp():
         custom_op = SequenceMergedColumnParallelOp(layer)
-    elif enable_sp():
+    elif "qkv_proj" in prefix and enable_sp():
         custom_op = SequenceQKVParallelOp(layer, prefix)
 
     if custom_op is not None:
@@ -442,7 +445,9 @@ def get_row_parallel_op(
     custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                               MatmulAllreduceRowParallelOp,
                               SequenceRowParallelOp]] = None
-    if "down_proj" in prefix and mlp_tp_enable():
+    if "shared_experts.down_proj" in prefix and enable_sp():
+        return None, 0, 1
+    elif "down_proj" in prefix and mlp_tp_enable():
         custom_op = MLPRowParallelOp(layer)
     elif "o_proj" in prefix and oproj_tp_enable():
         custom_op = OProjRowParallelOp(layer)
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index dbfe1dc137..cd3ddaad0e 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -260,7 +260,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 compilation_config.level = CompilationLevel.NO_COMPILATION
 
         if parallel_config and parallel_config.worker_cls == "auto":
-            if ascend_config.torchair_graph_config.enabled or ascend_config.enable_shared_expert_dp:
+            if ascend_config.torchair_graph_config.enabled:
                 parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
             else:
                 parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
@@ -303,9 +303,6 @@ def get_attn_backend_cls(cls,
 
         ascend_config = get_ascend_config()
 
-        if use_mla and ascend_config.enable_shared_expert_dp:
-            return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
-
         use_torchair = ascend_config.torchair_graph_config.enabled
         # choose attention backend based on use_mla and use_torchair
         backend_map = {
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 570756fd0c..5febb6bc17 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -595,7 +595,8 @@
 def enable_sp() -> bool:
     return (
         get_cached_compilation_config().pass_config.enable_sequence_parallelism
-        or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM)
+        or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM
+        or get_ascend_config().enable_shared_expert_dp)
 
 
 def is_moe_model(vllm_config: VllmConfig):
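Note on the mechanism patch 1 wires up: when sp_enabled is set, the token dimension is padded to a multiple of the TP world size, each rank keeps one contiguous chunk of rows (maybe_chunk_residual), and the rows are all-gathered again in front of layers that need every token (maybe_all_gather_and_maybe_unpad). The following is a minimal, self-contained sketch of that padding/chunking arithmetic only; it is illustrative, uses hard-coded sizes, and replaces the real TP-group all_gather with a plain concatenation.

import torch

# Illustrative sizes: 10 tokens, hidden size 4, TP world size 4.
num_tokens, hidden_size, tp_size = 10, 4, 4
x = torch.randn(num_tokens, hidden_size)

# Same padding rule as set_ascend_forward_context: round the token count
# up to the next multiple of tp_size so every rank gets an equal chunk.
pad_size = (tp_size - (num_tokens % tp_size)) % tp_size
x_padded = torch.nn.functional.pad(x, (0, 0, 0, pad_size))

# maybe_chunk_residual: each rank keeps one contiguous chunk of rows.
chunks = torch.chunk(x_padded, tp_size, dim=0)
assert all(c.shape[0] == (num_tokens + pad_size) // tp_size for c in chunks)

# maybe_all_gather_and_maybe_unpad: concatenating the per-rank chunks
# (the real op uses the TP group all_gather) and dropping the padded
# rows recovers the original activation.
x_gathered = torch.cat(chunks, dim=0)[:num_tokens]
assert torch.equal(x_gathered, x)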
From 61b3e6a7ecfbfe05bc116a3c9ee0dabba86abb02 Mon Sep 17 00:00:00 2001
From: zhaozx-cn
Date: Tue, 30 Sep 2025 10:43:00 +0800
Subject: [PATCH 2/2] shared expert DP for DeepSeek aclgraph

Signed-off-by: zhaozx-cn
---
 vllm_ascend/ascend_forward_context.py  |  1 +
 vllm_ascend/attention/mla_v1.py        |  6 ++++--
 vllm_ascend/models/layers/mla.py       |  5 +----
 vllm_ascend/ops/register_custom_ops.py | 29 +++++++++++++++++++++++---
 vllm_ascend/worker/model_runner_v1.py  | 27 ++++++++++++++++++++++--
 5 files changed, 57 insertions(+), 11 deletions(-)

diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index 7886224f77..e829dd0517 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -118,6 +118,7 @@ def set_ascend_forward_context(
                         (num_tokens % tp_world_size)) % tp_world_size
         forward_context.pad_size = pad_size
     forward_context.sp_enabled = sp_enabled
+    forward_context.num_tokens = num_tokens
 
     # set this for rope forward_oot using
     forward_context.is_first_layer = True
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 8c9de0ba62..bb94dab18f 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -26,6 +26,7 @@
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.utils import npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch
+from vllm.forward_context import get_forward_context
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -960,7 +961,8 @@ def forward(
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
         output = output[:num_actual_tokens, ...]
-        o_proj_input_shape = (num_actual_tokens,
+        num_padded_tokens = get_forward_context().num_tokens
+        o_proj_input_shape = (num_padded_tokens,
                               self.num_heads * self.v_head_dim)
         o_proj_input = torch.empty(o_proj_input_shape,
                                    dtype=hidden_states.dtype,
@@ -1001,7 +1003,7 @@ def forward(
                     o_proj_input[num_decode_tokens:] = output_prefill
                     current_ms_metadata.after_comm_event.record()
             else:
-                o_proj_input[num_decode_tokens:] = output_prefill
+                o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
         # O proj
         current_ms_metadata = get_multistream_comm_context()
         MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
diff --git a/vllm_ascend/models/layers/mla.py b/vllm_ascend/models/layers/mla.py
index 8bb6e098c6..a5e3c8a705 100644
--- a/vllm_ascend/models/layers/mla.py
+++ b/vllm_ascend/models/layers/mla.py
@@ -128,10 +128,7 @@ def forward(
         if not self.enable_shared_expert_dp or self.debug_layer_idx > 0:
             output_shape = hidden_states.shape
         else:
-            num_tokens = hidden_states.shape[0]
-            pad_size = forward_context.pad_size
-            rows = (num_tokens + pad_size) // self.tp_size
-            output_shape = (rows, hidden_states.shape[1])
+            output_shape = torch.chunk(hidden_states, self.tp_size, dim=0)[0].shape
         # FIXME: This does not seem right, should make sure the buffer is fixed
         output = torch.empty(output_shape,
                              dtype=hidden_states.dtype,
diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py
index a702b3521d..c5d6494405 100644
--- a/vllm_ascend/ops/register_custom_ops.py
+++ b/vllm_ascend/ops/register_custom_ops.py
@@ -9,6 +9,7 @@
 from vllm.forward_context import get_forward_context
 from vllm.utils import direct_register_custom_op
 
+from vllm_ascend.ascend_config import get_ascend_config
 import vllm_ascend.envs as envs_ascend
 
 
@@ -33,6 +34,11 @@ def _maybe_chunk_residual_impl(x: torch.Tensor,
     return residual
 
 
+def fake_maybe_chunk_residual_impl(x: torch.Tensor,
+                                   residual: torch.Tensor) -> torch.Tensor:
+    return torch.empty_like(x)
+
+
 def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
                                            label: bool) -> torch.Tensor:
     try:
@@ -49,6 +55,15 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
     return x
 
 
+def fake_maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
+                                               label: bool) -> torch.Tensor:
+    if get_ascend_config().enable_shared_expert_dp:
+        tp_size = get_tensor_model_parallel_world_size()
+        num_tokens = x.shape[0] * tp_size
+        return torch.empty((num_tokens, x.shape[1]), dtype=x.dtype, device=x.device)
+    return x
+
+
 def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
     try:
@@ -65,6 +80,14 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
     return tensor_model_parallel_all_reduce(x)
 
 
+def fake_maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
+    if get_ascend_config().enable_shared_expert_dp:
+        tp_size = get_tensor_model_parallel_world_size()
+        num_tokens = x.shape[0] // tp_size
+        return torch.empty((num_tokens, x.shape[1]), dtype=x.dtype, device=x.device)
+    return x
+
+
 def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
                                           prefix: str) -> None:
     try:
@@ -149,18 +172,18 @@ def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
 
 direct_register_custom_op(op_name="maybe_chunk_residual",
                           op_func=_maybe_chunk_residual_impl,
-                          fake_impl=lambda x, residual: residual,
+                          fake_impl=fake_maybe_chunk_residual_impl,
                           mutates_args=[],
                           dispatch_key="PrivateUse1")
 
 direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
                           op_func=_maybe_all_gather_and_maybe_unpad_impl,
-                          fake_impl=lambda x, label: x,
+                          fake_impl=fake_maybe_all_gather_and_maybe_unpad_impl,
                           mutates_args=[],
                           dispatch_key="PrivateUse1")
 
 direct_register_custom_op(op_name="maybe_pad_and_reduce",
                           op_func=_maybe_pad_and_reduce_impl,
-                          fake_impl=lambda x: x,
+                          fake_impl=fake_maybe_pad_and_reduce_impl,
                           mutates_args=[],
                           dispatch_key="PrivateUse1")
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index ff055e47b6..ff09be0c7b 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -44,7 +44,8 @@
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
 from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
-from vllm.distributed import tensor_model_parallel_all_gather
+from vllm.distributed import (tensor_model_parallel_all_gather,
+                              get_tensor_model_parallel_world_size)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
@@ -3361,6 +3362,27 @@ def initialize_aclgraph_capture(self) -> None:
         self.aclgraph_dispatcher.initialize_cudagraph_keys(
             self.compilation_config.cudagraph_mode,
             self.uniform_decode_query_len)
+
+    def update_sizes_for_sequence_parallelism(self,
+                                              possible_sizes: list) -> list:
+        # Remove sizes that are not a multiple of tp_size when
+        # sequence parallelism is enabled.
+        tp_size = get_tensor_model_parallel_world_size()
+        removed_sizes = [
+            size for size in possible_sizes
+            if size % tp_size != 0
+        ]
+        if removed_sizes:
+            logger.warning(
+                "Batch sizes %s are removed because they are not "
+                "multiple of tp_size %d when "
+                "sequence parallelism is enabled", removed_sizes,
+                tp_size)
+
+        return [
+            size for size in possible_sizes
+            if size % tp_size == 0
+        ]
 
     def _capture_aclgraphs(self, compilation_cases: list[int],
                            aclgraph_runtime_mode: CUDAGraphMode,
@@ -3368,7 +3390,8 @@ def _capture_aclgraphs(self, compilation_cases: list[int],
         assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \
             aclgraph_runtime_mode in [CUDAGraphMode.FULL,
                                       CUDAGraphMode.PIECEWISE]
-
+        compilation_cases = self.update_sizes_for_sequence_parallelism(
+            compilation_cases)
         # Only rank 0 should print progress bar during capture
         if is_global_first_rank():
             logger.info(
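A note on the fake_impl changes in register_custom_ops.py above: during ACL graph capture the custom ops are traced with fake (meta) tensors, so each registered fake implementation has to report the shape the real op would produce, e.g. tp_size times as many rows after an all-gather. Below is a minimal, standalone sketch of the same idea written against plain torch.library (PyTorch 2.4+); it is illustrative only: vllm's direct_register_custom_op sits on the same torch.library registration machinery, and TP_SIZE is a hard-coded stand-in for get_tensor_model_parallel_world_size().

import torch

TP_SIZE = 4  # stand-in for the real TP world size


@torch.library.custom_op("demo::toy_all_gather", mutates_args=())
def toy_all_gather(x: torch.Tensor) -> torch.Tensor:
    # Eager implementation; rows are simply repeated here, where the
    # real op would call the TP group's all_gather on the NPU.
    return x.repeat(TP_SIZE, 1)


@toy_all_gather.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Shape-only version used while tracing/capturing the graph,
    # mirroring fake_maybe_all_gather_and_maybe_unpad_impl.
    return torch.empty((x.shape[0] * TP_SIZE, x.shape[1]),
                       dtype=x.dtype, device=x.device)


def fn(x: torch.Tensor) -> torch.Tensor:
    return toy_all_gather(x) + 1.0


out = torch.compile(fn)(torch.randn(3, 8))
print(out.shape)  # torch.Size([12, 8])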