4 changes: 4 additions & 0 deletions vllm_ascend/ascend_forward_context.py
@@ -12,6 +12,7 @@

import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import enable_sp
from vllm_ascend.ascend_config import get_ascend_config


class FusedMoEState(Enum):
@@ -110,11 +111,14 @@ def set_ascend_forward_context(
tp_world_size > 1 and \
num_tokens is not None and num_tokens > 1000

if get_ascend_config().enable_shared_expert_dp:
sp_enabled = True
if sp_enabled:
pad_size = (tp_world_size -
(num_tokens % tp_world_size)) % tp_world_size
forward_context.pad_size = pad_size
forward_context.sp_enabled = sp_enabled
forward_context.num_tokens = num_tokens

# set this for rope forward_oot using
forward_context.is_first_layer = True
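For context on the new padding fields: the `pad_size` expression rounds `num_tokens` up to the next multiple of `tp_world_size` so the token dimension splits evenly across TP ranks. A minimal standalone sketch of that arithmetic (not part of the diff):

```python
# Minimal sketch of the pad_size arithmetic used in set_ascend_forward_context.
def compute_pad_size(num_tokens: int, tp_world_size: int) -> int:
    # Number of filler tokens needed so (num_tokens + pad) % tp_world_size == 0.
    return (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size

assert compute_pad_size(1000, 8) == 0   # already a multiple of 8
assert compute_pad_size(1001, 8) == 7   # 1001 + 7 = 1008 = 126 * 8
assert compute_pad_size(5, 4) == 3      # 5 + 3 = 8
```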
21 changes: 10 additions & 11 deletions vllm_ascend/attention/mla_v1.py
@@ -26,6 +26,7 @@
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.utils import npu_prefetch
from vllm_ascend.worker.npu_input_batch import InputBatch
from vllm.forward_context import get_forward_context

if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -887,9 +888,10 @@ def _mla_preprocess(self, layer_name, hidden_states, kv_cache,

kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
# Process for shared_expert_dp
if need_gather_q_kv:
q_c = get_tp_group().all_gather(q_c, 0)
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
q_c, need_gather_q_kv)
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
kv_no_split, need_gather_q_kv)
decode_preprocess_res = None
prefill_preprocess_res = None
if has_prefill:
@@ -959,7 +961,8 @@ def forward(
# Inputs and outputs may be padded for CUDA graphs
output_padded = output
output = output[:num_actual_tokens, ...]
o_proj_input_shape = (num_actual_tokens,
num_padded_tokens = get_forward_context().num_tokens
o_proj_input_shape = (num_padded_tokens,
self.num_heads * self.v_head_dim)
o_proj_input = torch.empty(o_proj_input_shape,
dtype=hidden_states.dtype,
Expand Down Expand Up @@ -1000,7 +1003,7 @@ def forward(
o_proj_input[num_decode_tokens:] = output_prefill
current_ms_metadata.after_comm_event.record()
else:
o_proj_input[num_decode_tokens:] = output_prefill
o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
# O proj
current_ms_metadata = get_multistream_comm_context()
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
@@ -1011,19 +1014,15 @@
enabled=self.enable_prefetch)

output[...] = self.o_proj(
o_proj_input,
is_prefill=prefill_preprocess_res is not None,
is_force_scatter=self.enable_shared_expert_dp)[0]
o_proj_input)[0]
else:
with torch.npu.stream(current_ms_metadata.comm_stream):
npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(
o_proj_input,
is_prefill=prefill_preprocess_res is not None,
is_force_scatter=self.enable_shared_expert_dp)[0]
o_proj_input)[0]
current_ms_metadata.after_comm_event.record()
del o_proj_input

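A shape-only sketch of the buffer sizing above, under the assumption that `get_forward_context().num_tokens` carries the (possibly padded) scheduled token count, which is at least `num_actual_tokens`. Real attention outputs only cover the actual tokens, which is why the prefill write is now bounded by `num_actual_tokens` (the decode write below is illustrative, not copied from the diff):

```python
# Shape-only sketch (illustrative values): o_proj_input is allocated for the
# padded token count, while decode/prefill outputs only fill the actual tokens.
import torch

num_actual_tokens, num_padded_tokens = 13, 16   # padded >= actual (assumption)
num_decode_tokens = 5
num_heads, v_head_dim = 16, 128

o_proj_input = torch.empty(num_padded_tokens, num_heads * v_head_dim)
output_decode = torch.randn(num_decode_tokens, num_heads * v_head_dim)
output_prefill = torch.randn(num_actual_tokens - num_decode_tokens,
                             num_heads * v_head_dim)

o_proj_input[:num_decode_tokens] = output_decode
# Upper bound matters: rows [num_actual_tokens:num_padded_tokens] stay as padding.
o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
```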
2 changes: 1 addition & 1 deletion vllm_ascend/models/deepseek_v2.py
@@ -183,7 +183,7 @@ def __init__(
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj")
self.o_proj = CustomDeepseekV2RowParallelLinear(
self.o_proj = RowParallelLinear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
14 changes: 5 additions & 9 deletions vllm_ascend/models/layers/mla.py
@@ -120,19 +120,15 @@ def forward(
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
num_tokens = hidden_states.shape[0]
forward_context = get_forward_context()
sp_enabled = forward_context.sp_enabled
need_gather_q_kv = False
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
# Simulate all gather to calculate output shape
num_tokens = num_tokens * self.tp_size
if sp_enabled and self.debug_layer_idx > 0:
need_gather_q_kv = True
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
if not self.enable_shared_expert_dp or self.debug_layer_idx > 0:
output_shape = hidden_states.shape
else:
rows = num_tokens // self.tp_size
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
output_shape = torch.chunk(hidden_states,self.tp_size,dim=0)[0].shape
# FIXME: This does not seem right, should make sure the buffer is fixed
output = torch.empty(output_shape,
dtype=hidden_states.dtype,
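The rewritten `output_shape` leans on `torch.chunk`: along dim 0 the first chunk has ceil(num_tokens / tp_size) rows, which is exactly what the removed manual rounding computed. A quick sketch of that equivalence (not part of the diff):

```python
# Sketch: the first chunk from torch.chunk(x, tp_size, dim=0) has
# ceil(num_tokens / tp_size) rows, matching the old floor-division-plus-one logic.
import math
import torch

def first_chunk_rows(num_tokens: int, tp_size: int, hidden: int = 7) -> int:
    x = torch.empty(num_tokens, hidden)
    return torch.chunk(x, tp_size, dim=0)[0].shape[0]

for num_tokens, tp_size in [(8, 4), (10, 4), (13, 8)]:
    assert first_chunk_rows(num_tokens, tp_size) == math.ceil(num_tokens / tp_size)
```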
4 changes: 3 additions & 1 deletion vllm_ascend/ops/common_fused_moe.py
@@ -351,7 +351,9 @@ def forward_impl(self, hidden_states: torch.Tensor,
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
shared_out = tensor_model_parallel_all_reduce(shared_out)
if not get_forward_context().sp_enabled:
shared_out = tensor_model_parallel_all_reduce(shared_out)

fused_output = AscendFusedMoE.forward_impl(
self,
hidden_states=hidden_states,
2 changes: 1 addition & 1 deletion vllm_ascend/ops/layernorm.py
@@ -65,7 +65,7 @@ def forward_oot(

if residual is not None:
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
assert x.size(0) == residual.size(0)
#assert x.size(0) == residual.size(0)
Contributor review comment (severity: critical):

Commenting out assertions is dangerous as it can hide underlying bugs. The assertion x.size(0) == residual.size(0) was likely failing. Instead of disabling it, the root cause should be investigated and fixed. The preceding line residual = torch.ops.vllm.maybe_chunk_residual(x, residual) is supposed to ensure tensor shapes are compatible for the residual connection. If this assertion fails, it indicates a problem with maybe_chunk_residual or the tensors x and residual passed to it, especially in sequence parallelism scenarios where tensor shapes can be tricky. Please either fix the underlying issue and re-enable the assertion, or provide a detailed explanation for why this assertion is no longer valid.

x, residual = _addrmsnorm_forward_oot(
self, x, residual, self.next_need_quant_fusion_linear)
return x, residual
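As a purely illustrative sketch of the reviewer's suggestion (the helper name, message, and use of `get_forward_context` here are assumptions, not taken from this PR), the check could stay in place and fail loudly with shape details instead of being commented out:

```python
# Illustrative only (not from the PR): a helper that keeps the shape check but
# fails with an actionable message instead of a silently removed assert.
import torch
from vllm.forward_context import get_forward_context

def check_residual_shape(x: torch.Tensor, residual: torch.Tensor) -> None:
    if x.size(0) != residual.size(0):
        sp_enabled = getattr(get_forward_context(), "sp_enabled", False)
        raise AssertionError(
            f"x rows ({x.size(0)}) != residual rows ({residual.size(0)}) after "
            f"maybe_chunk_residual (sp_enabled={sp_enabled}); the op should have "
            "aligned these shapes.")
```

In `forward_oot`, this would be called right after `maybe_chunk_residual` in place of the commented-out assertion.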
11 changes: 8 additions & 3 deletions vllm_ascend/ops/linear_op.py
@@ -46,6 +46,7 @@
from vllm.distributed import split_tensor_along_last_dim
from vllm.distributed.parallel_state import get_tp_group

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group)
from vllm_ascend.utils import (dense_optim_enable, enable_sp,
@@ -420,11 +421,13 @@ def get_column_parallel_op(
SequenceMergedColumnParallelOp,
SequenceQKVParallelOp,
]] = None
if "gate_up_proj" in prefix and mlp_tp_enable():
if "shared_experts.gate_up_proj" in prefix and enable_sp():
return None, 0, 1
elif "gate_up_proj" in prefix and mlp_tp_enable():
custom_op = MLPColumnParallelOp(layer)
elif "gate_up_proj" in prefix and enable_sp():
custom_op = SequenceMergedColumnParallelOp(layer)
elif enable_sp():
elif "qkv_proj" in prefix and enable_sp():
custom_op = SequenceQKVParallelOp(layer, prefix)

if custom_op is not None:
@@ -444,7 +447,9 @@ def get_row_parallel_op(
custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
MatmulAllreduceRowParallelOp,
SequenceRowParallelOp]] = None
if "down_proj" in prefix and mlp_tp_enable():
if "shared_experts.down_proj" in prefix and enable_sp():
return None, 0, 1
elif "down_proj" in prefix and mlp_tp_enable():
custom_op = MLPRowParallelOp(layer)
elif "o_proj" in prefix and oproj_tp_enable():
custom_op = OProjRowParallelOp(layer)
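To make the new dispatch order easier to follow: with sequence parallelism enabled, `shared_experts.gate_up_proj` and `shared_experts.down_proj` now return early with `(None, 0, 1)` before the generic `gate_up_proj` / `down_proj` branches can match, and the bare `enable_sp()` fallback is narrowed to `qkv_proj`. A standalone sketch of just that string matching (flag values are made up):

```python
# Standalone sketch of the branch selection in get_column_parallel_op; only the
# prefix matching is reproduced, and the flag values are illustrative.
def pick_column_branch(prefix: str, sp: bool, mlp_tp: bool) -> str:
    if "shared_experts.gate_up_proj" in prefix and sp:
        return "early return (None, 0, 1)"
    elif "gate_up_proj" in prefix and mlp_tp:
        return "MLPColumnParallelOp"
    elif "gate_up_proj" in prefix and sp:
        return "SequenceMergedColumnParallelOp"
    elif "qkv_proj" in prefix and sp:
        return "SequenceQKVParallelOp"
    return "no custom op"

assert pick_column_branch("layers.3.mlp.shared_experts.gate_up_proj",
                          sp=True, mlp_tp=False) == "early return (None, 0, 1)"
assert pick_column_branch("layers.3.mlp.gate_up_proj",
                          sp=True, mlp_tp=False) == "SequenceMergedColumnParallelOp"
assert pick_column_branch("layers.3.self_attn.qkv_proj",
                          sp=True, mlp_tp=False) == "SequenceQKVParallelOp"
```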
29 changes: 26 additions & 3 deletions vllm_ascend/ops/register_custom_ops.py
@@ -9,6 +9,7 @@
from vllm.forward_context import get_forward_context
from vllm.utils import direct_register_custom_op

from vllm_ascend.ascend_config import get_ascend_config
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_forward_context import MoECommType

@@ -34,6 +35,11 @@ def _maybe_chunk_residual_impl(x: torch.Tensor,
return residual


def fake_maybe_chunk_residual_impl(x: torch.Tensor,
residual: torch.Tensor) -> torch.Tensor:
return torch.empty_like(x)


def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
label: bool) -> torch.Tensor:
try:
@@ -50,6 +56,15 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
return x


def fake_maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
label: bool) -> torch.Tensor:
if get_ascend_config().enable_shared_expert_dp:
tp_size = get_tensor_model_parallel_world_size()
num_tokens = x.shape[0]*tp_size
return torch.empty((num_tokens,x.shape[1]),dtype=x.dtype,device=x.device)
return x


def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
try:
forward_context = get_forward_context()
@@ -66,6 +81,14 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
return tensor_model_parallel_all_reduce(x)


def fake_maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
if get_ascend_config().enable_shared_expert_dp:
tp_size = get_tensor_model_parallel_world_size()
num_tokens = x.shape[0] // tp_size
return torch.empty((num_tokens,x.shape[1]),dtype=x.dtype,device=x.device)
return x


def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
prefix: str) -> None:
try:
@@ -160,18 +183,18 @@ def _maybe_all_reduce_tensor_model_parallel_impl(

direct_register_custom_op(op_name="maybe_chunk_residual",
op_func=_maybe_chunk_residual_impl,
fake_impl=lambda x, residual: residual,
fake_impl=fake_maybe_chunk_residual_impl,
mutates_args=[],
dispatch_key="PrivateUse1")

direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
op_func=_maybe_all_gather_and_maybe_unpad_impl,
fake_impl=lambda x, label: x,
fake_impl=fake_maybe_all_gather_and_maybe_unpad_impl,
mutates_args=[],
dispatch_key="PrivateUse1")

direct_register_custom_op(op_name="maybe_pad_and_reduce",
op_func=_maybe_pad_and_reduce_impl,
op_func=fake_maybe_pad_and_reduce_impl,
fake_impl=lambda x: x,
mutates_args=[],
dispatch_key="PrivateUse1")
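The new fake implementations exist so shape inference during graph tracing reflects what the real collectives do to dim 0 when `enable_shared_expert_dp` is on (otherwise they pass `x` through unchanged, as the code above shows): the all-gather-style op grows the token dimension by `tp_size`, while pad-and-reduce shrinks it by `tp_size`. A shape-only sketch of what the fakes advertise, assuming `tp_size = 4`:

```python
# Shape-only sketch of the registered fake impls (assumed tp_size = 4).
import torch

tp_size = 4
x = torch.empty(8, 1024)

# fake_maybe_all_gather_and_maybe_unpad_impl: dim 0 grows by tp_size.
gathered_fake = torch.empty(x.shape[0] * tp_size, x.shape[1], dtype=x.dtype)
assert gathered_fake.shape == (32, 1024)

# fake_maybe_pad_and_reduce_impl: dim 0 shrinks by tp_size.
reduced_fake = torch.empty(x.shape[0] // tp_size, x.shape[1], dtype=x.dtype)
assert reduced_fake.shape == (2, 1024)
```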
4 changes: 1 addition & 3 deletions vllm_ascend/platform.py
@@ -263,7 +263,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
if parallel_config and parallel_config.worker_cls == "auto":
# TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
os.environ["VLLM_ALL2ALL_BACKEND"] = "flashinfer_all2allv"
if ascend_config.torchair_graph_config.enabled or ascend_config.enable_shared_expert_dp:
if ascend_config.torchair_graph_config.enabled:
parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
else:
parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
@@ -308,8 +308,6 @@ def get_attn_backend_cls(cls,
ascend_config = get_ascend_config()

if use_mla and ascend_config.enable_shared_expert_dp:
if use_mla and not use_sfa:
return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
if use_mla and use_sfa:
return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"

5 changes: 3 additions & 2 deletions vllm_ascend/utils.py
@@ -602,8 +602,9 @@
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
return (
vllm_config.compilation_config.pass_config.enable_sequence_parallelism
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM)
get_cached_compilation_config().pass_config.enable_sequence_parallelism

Check failure on line 605 in vllm_ascend/utils.py (GitHub Actions / lint / pre-commit):
Name "get_cached_compilation_config" is not defined [name-defined]
Ruff (F821): vllm_ascend/utils.py:605:9: F821 Undefined name `get_cached_compilation_config`
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM
or get_ascend_config().enable_shared_expert_dp)


def is_moe_model(vllm_config: VllmConfig):
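The CI annotation above flags `get_cached_compilation_config` as undefined in `enable_sp()`. One hypothetical way to clear it (an assumption, not taken from this PR) is to keep using the `get_current_vllm_config()` import the original code already had while still adding the new `enable_shared_expert_dp` condition:

```python
# Hypothetical sketch only: resolves the F821 lint failure by reusing
# get_current_vllm_config() instead of the undefined get_cached_compilation_config.
def enable_sp() -> bool:
    from vllm.config import get_current_vllm_config

    vllm_config = get_current_vllm_config()
    return (vllm_config.compilation_config.pass_config.enable_sequence_parallelism
            or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM
            or get_ascend_config().enable_shared_expert_dp)
```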
27 changes: 25 additions & 2 deletions vllm_ascend/worker/model_runner_v1.py
@@ -44,7 +44,8 @@
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed import (tensor_model_parallel_all_gather,
get_tensor_model_parallel_world_size)
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
@@ -3498,14 +3499,36 @@ def initialize_aclgraph_capture(self) -> None:
self.aclgraph_dispatcher.initialize_cudagraph_keys(
self.compilation_config.cudagraph_mode,
self.uniform_decode_query_len)

def update_sizes_for_sequence_parallelism(self,
possible_sizes: list) -> list:
# remove the sizes that are not multiples of tp_size when
# sequence parallelism is enabled
tp_size = get_tensor_model_parallel_world_size()
removed_sizes = [
size for size in possible_sizes
if size % tp_size != 0
]
if removed_sizes:
logger.warning(
"Batch sizes %s are removed because they are not "
"multiple of tp_size %d when "
"sequence parallelism is enabled", removed_sizes,
tp_size)

return [
size for size in possible_sizes
if size % tp_size == 0
]

def _capture_aclgraphs(self, compilation_cases: list[int],
aclgraph_runtime_mode: CUDAGraphMode,
uniform_decode: bool):
assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \
aclgraph_runtime_mode in [CUDAGraphMode.FULL,
CUDAGraphMode.PIECEWISE]


compilation_cases = self.update_sizes_for_sequence_parallelism(compilation_cases)
# Only rank 0 should print progress bar during capture
if is_global_first_rank():
logger.info(
Expand Down
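As a quick illustration of the new capture-size filter (values are made up), with `tp_size = 4` every size that is not a multiple of 4 is dropped with a warning before ACL graph capture:

```python
# Illustrative only: what update_sizes_for_sequence_parallelism keeps and drops
# for a made-up capture-size list when tp_size == 4.
tp_size = 4
possible_sizes = [1, 2, 4, 8, 12, 18, 32]

removed_sizes = [s for s in possible_sizes if s % tp_size != 0]
kept_sizes = [s for s in possible_sizes if s % tp_size == 0]

assert removed_sizes == [1, 2, 18]
assert kept_sizes == [4, 8, 12, 32]
```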