19 changes: 19 additions & 0 deletions vllm_ascend/ascend_config.py
@@ -94,6 +94,25 @@ def __init__(self, vllm_config):
                raise AssertionError(
                    "oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
                )
        self.flashcomm2_oproj_tensor_parallel_size = additional_config.get(
            "flashcomm2_oproj_tensor_parallel_size", None)
        if self.flashcomm2_oproj_tensor_parallel_size is not None:
            global_tp_size = vllm_config.parallel_config.tensor_parallel_size
            logger.info(
                f"Enable Flashcomm2 with flashcomm2_oproj_tensor_parallel_size={self.flashcomm2_oproj_tensor_parallel_size} and global_tp_size={global_tp_size}"
            )
            if self.oproj_tensor_parallel_size is not None:
                raise AssertionError(
                    "flashcomm2_oproj_tensor_parallel_size cannot be enabled simultaneously with oproj_tensor_parallel_size"
                )
            if global_tp_size <= self.flashcomm2_oproj_tensor_parallel_size:
                raise AssertionError(
                    f"flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size}) must be smaller than the global tensor parallel size ({global_tp_size})"
                )
            if global_tp_size % self.flashcomm2_oproj_tensor_parallel_size != 0:
                raise AssertionError(
                    f"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size})"
                )


class TorchairGraphConfig:
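For intuition, here is a minimal standalone sketch (not part of the PR) of the constraint the checks above enforce, assuming a global tensor parallel size of 8: a valid flashcomm2_oproj_tensor_parallel_size is a proper divisor of the global TP size.

```python
# Illustrative only: mirrors the validation performed in AscendConfig.__init__ above.
GLOBAL_TP_SIZE = 8  # assumed example value

def is_valid_flashcomm2_otp_size(otp_size: int) -> bool:
    # Must be strictly smaller than, and evenly divide, the global TP size.
    return 0 < otp_size < GLOBAL_TP_SIZE and GLOBAL_TP_SIZE % otp_size == 0

print([s for s in range(1, GLOBAL_TP_SIZE + 1) if is_valid_flashcomm2_otp_size(s)])
# -> [1, 2, 4]
```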
9 changes: 8 additions & 1 deletion vllm_ascend/ascend_forward_context.py
@@ -12,6 +12,7 @@

import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import enable_sp
from vllm_ascend.utils import flashcomm2_enable


class FusedMoEState(Enum):
Expand Down Expand Up @@ -109,12 +110,18 @@ def set_ascend_forward_context(
        sp_enabled = enable_sp(vllm_config) and \
            tp_world_size > 1 and \
            num_tokens is not None and num_tokens > 1000

        flashcomm_v2_enabled = flashcomm2_enable() and \
            tp_world_size > 1 and \
            num_tokens is not None

        if sp_enabled or flashcomm_v2_enabled:
            pad_size = (tp_world_size -
                        (num_tokens % tp_world_size)) % tp_world_size
            forward_context.pad_size = pad_size

        forward_context.sp_enabled = sp_enabled
        forward_context.flashcomm_v2_enabled = flashcomm_v2_enabled

        # set this for rope forward_oot usage
        forward_context.is_first_layer = True
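As a quick sketch (not part of the diff) of the padding arithmetic above: pad_size rounds num_tokens up to the next multiple of tp_world_size, and is zero when the token count already divides evenly.

```python
# Illustrative only, with an assumed tp_world_size of 4.
tp_world_size = 4
for num_tokens in (7, 8, 9):
    pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
    assert (num_tokens + pad_size) % tp_world_size == 0
    print(num_tokens, "->", pad_size)
# 7 -> 1, 8 -> 0, 9 -> 3
```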
59 changes: 57 additions & 2 deletions vllm_ascend/distributed/parallel_state.py
@@ -2,8 +2,11 @@

import torch
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
                                             get_tp_group,
                                             init_model_parallel_group)
from vllm.logger import logger
from vllm_ascend.utils import flashcomm2_enable, oproj_tp_enable

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
@@ -13,6 +16,9 @@
_MLP_TP: Optional[GroupCoordinator] = None
_OTP: Optional[GroupCoordinator] = None
_LMTP: Optional[GroupCoordinator] = None
_FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None



def get_mc2_group() -> GroupCoordinator:
@@ -25,12 +31,18 @@
"output tensor parallel group is not initialized")
return _OTP


def get_lmhead_tp_group() -> GroupCoordinator:
    assert _LMTP is not None, (
        "lm head tensor parallel group is not initialized")
    return _LMTP

def get_flashcomm2_otp_group() -> GroupCoordinator:
    return _FLASHCOMM2_OTP

def get_flashcomm2_odp_group() -> GroupCoordinator:
    assert _FLASHCOMM2_ODP is not None, (
        "output data parallel group for flashcomm2 is not initialized")
    return _FLASHCOMM2_ODP

def get_mlp_tp_group() -> GroupCoordinator:
    assert _MLP_TP is not None, ("mlp group is not initialized")
@@ -110,6 +122,39 @@
                                          get_world_group().local_rank,
                                          backend,
                                          group_name="lmheadtp")

    if flashcomm2_enable():
        flashcomm2_otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size
        global_tp_size = get_tp_group().world_size
        num_oproj_tensor_parallel_groups: int = (global_tp_size // flashcomm2_otp_size)

        global _FLASHCOMM2_OTP
        global _FLASHCOMM2_ODP

Check failure on line 132 in vllm_ascend/distributed/parallel_state.py (GitHub Actions / lint / pre-commit): Name "num_oproj_tensor_parallel_groups" already defined on line 102 [no-redef]

        _FLASHCOMM2_OTP = None
        _FLASHCOMM2_ODP = get_tp_group()

        if flashcomm2_otp_size > 1:
Contributor comment (critical): The process group creation for FlashComm2 is guarded by `if flashcomm2_otp_size > 1:`. This causes _FLASHCOMM2_OTP to be None when flashcomm2_oproj_tensor_parallel_size is 1. However, Flashcomm2OProjRowParallelOp is still used in this case, and it attempts to access methods on the _FLASHCOMM2_OTP group, which will lead to a crash. The logic within this if block appears to correctly handle the size == 1 case by creating groups of size 1. The conditional guard should be removed, and its content unindented, to fix this critical bug.

            otp_group_ranks = []
            odp_group_ranks = [[] for _ in range(flashcomm2_otp_size)]
            dp_group_index = torch.distributed.get_rank() // global_tp_size

            for i in range(num_oproj_tensor_parallel_groups):
                ranks = []

Check failure on line 143 in vllm_ascend/distributed/parallel_state.py (GitHub Actions / lint / pre-commit): Need type annotation for "odp_group_ranks" [var-annotated]
                for j in range(flashcomm2_otp_size):
                    rank_idx = dp_group_index * global_tp_size + i + j * num_oproj_tensor_parallel_groups
                    ranks.append(rank_idx)
                    odp_group_ranks[j].append(rank_idx)
                otp_group_ranks.append(ranks)

            _FLASHCOMM2_OTP = init_model_parallel_group(otp_group_ranks,
                                                        get_world_group().local_rank,
                                                        backend,
                                                        group_name="flashcomm2_otp")
            _FLASHCOMM2_ODP = init_model_parallel_group(odp_group_ranks,
                                                        get_world_group().local_rank,
                                                        backend,
                                                        group_name="flashcomm2_odp")


def get_mlp_tensor_model_parallel_world_size():
@@ -142,3 +187,13 @@
    if _OTP:
        _OTP.destroy()
        _OTP = None

    global _FLASHCOMM2_OTP
    if _FLASHCOMM2_OTP and get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1:
Contributor comment (high): The condition get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1 will prevent the _FLASHCOMM2_OTP group from being destroyed when its size is 1. If the initialization logic is fixed to create a group for size 1 (as suggested in another comment), this will cause a resource leak. The group should be destroyed if it was created, regardless of its size.

Suggested change:
-    if _FLASHCOMM2_OTP and get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1:
+    if _FLASHCOMM2_OTP:

        _FLASHCOMM2_OTP.destroy()
        _FLASHCOMM2_OTP = None

    global _FLASHCOMM2_ODP
    if _FLASHCOMM2_ODP and get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1:
        _FLASHCOMM2_ODP.destroy()
        _FLASHCOMM2_ODP = None
13 changes: 13 additions & 0 deletions vllm_ascend/models/layers/mla.py
@@ -26,6 +26,7 @@
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -133,6 +134,18 @@ def forward(
        if num_tokens % self.tp_size:
            rows += 1
        output_shape = (rows, hidden_states.shape[1])

        forward_context = get_forward_context()
        is_prefill = forward_context.with_prefill
        if forward_context.flashcomm_v2_enabled and forward_context.flashcomm1_ds_prefill:
            num_padding_tokens = forward_context.pad_size
            if is_prefill and self.debug_layer_idx > 0 and self.debug_layer_idx < self.layers:
                output_shape = hidden_states.shape
            else:
                B = (hidden_states.shape[0] + num_padding_tokens) // get_tensor_model_parallel_world_size()
                H = hidden_states.shape[1]
                output_shape = (B, H)

        # FIXME: This does not seem right, should make sure the buffer is fixed
        output = torch.empty(output_shape,
                             dtype=hidden_states.dtype,
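To make the shape arithmetic concrete, a standalone sketch with assumed values (9 tokens, hidden size 7168, TP world size 4; the sizes are illustrative, not taken from the PR): in the non-prefill flashcomm_v2 branch, the padded token count is split evenly across TP ranks, so each rank allocates a smaller output buffer.

```python
# Illustrative only: per-rank output shape in the flashcomm_v2 branch above.
num_tokens, hidden_size = 9, 7168   # assumed example values
tp_world_size = 4                   # assumed example value

pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size  # 3
B = (num_tokens + pad_size) // tp_world_size                               # 3
output_shape = (B, hidden_size)
assert output_shape == (3, 7168)
```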
1 change: 1 addition & 0 deletions vllm_ascend/ops/linear.py
@@ -275,6 +275,7 @@ def forward(
        self,
        input_,
        is_prefill: bool = True,
        is_force_scatter: bool = False,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        if self.custom_op is not None:
            return self.custom_op.apply(input_)