[main] flashcomm_v2 optim solution #3232
vllm_ascend/distributed/parallel_state.py

@@ -2,8 +2,11 @@

import torch
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group, get_tp_group,
                                             init_model_parallel_group)
import vllm_ascend.envs as envs_ascend
from vllm.logger import logger
from vllm_ascend.utils import flashcomm2_enable, oproj_tp_enable

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
@@ -13,6 +16,9 @@

_MLP_TP: Optional[GroupCoordinator] = None
_OTP: Optional[GroupCoordinator] = None
_LMTP: Optional[GroupCoordinator] = None
_FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None


def get_mc2_group() -> GroupCoordinator:
@@ -25,12 +31,18 @@

        "output tensor parallel group is not initialized")
    return _OTP


def get_lmhead_tp_group() -> GroupCoordinator:
    assert _LMTP is not None, (
        "lm head tensor parallel group is not initialized")
    return _LMTP


def get_flashcomm2_otp_group() -> GroupCoordinator:
    return _FLASHCOMM2_OTP


def get_flashcomm2_odp_group() -> GroupCoordinator:
    assert _FLASHCOMM2_ODP is not None, (
        "output data parallel group for flashcomm2 is not initialized")
    return _FLASHCOMM2_ODP


def get_mlp_tp_group() -> GroupCoordinator:
    assert _MLP_TP is not None, ("mlp group is not initialized")
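One detail worth noting in the new getters: `get_flashcomm2_odp_group()` asserts that its group exists, while `get_flashcomm2_otp_group()` returns `_FLASHCOMM2_OTP` without any check, so it can hand back `None` when `flashcomm2_oproj_tensor_parallel_size` is 1 (see the review comment at the end of this diff). Below is a minimal caller-side sketch of the defensive handling this implies; the variable names are illustrative and not from the PR.

```python
# Hypothetical caller-side guard (not part of the PR): the OTP getter has no
# assertion, so a caller has to tolerate a None return when the o_proj tensor
# parallel size is 1 and the group was never created.
otp_group = get_flashcomm2_otp_group()
if otp_group is None:
    # Fall back to "no extra o_proj sharding": behave as a single-member group.
    otp_rank, otp_world_size = 0, 1
else:
    otp_rank, otp_world_size = otp_group.rank_in_group, otp_group.world_size
```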
@@ -110,6 +122,39 @@

            get_world_group().local_rank,
            backend,
            group_name="lmheadtp")

    if flashcomm2_enable():
        flashcomm2_otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size
        global_tp_size = get_tp_group().world_size
        num_oproj_tensor_parallel_groups: int = (global_tp_size // flashcomm2_otp_size)

        global _FLASHCOMM2_OTP
        global _FLASHCOMM2_ODP

        _FLASHCOMM2_OTP = None
        _FLASHCOMM2_ODP = get_tp_group()

        if flashcomm2_otp_size > 1:
            otp_group_ranks = []
            odp_group_ranks = [[] for _ in range(flashcomm2_otp_size)]
            dp_group_index = torch.distributed.get_rank() // global_tp_size

            for i in range(num_oproj_tensor_parallel_groups):
                ranks = []
                for j in range(flashcomm2_otp_size):
                    rank_idx = dp_group_index * global_tp_size + i + j * num_oproj_tensor_parallel_groups
                    ranks.append(rank_idx)
                    odp_group_ranks[j].append(rank_idx)
                otp_group_ranks.append(ranks)

            _FLASHCOMM2_OTP = init_model_parallel_group(otp_group_ranks,
                                                        get_world_group().local_rank,
                                                        backend,
                                                        group_name="flashcomm2_otp")
            _FLASHCOMM2_ODP = init_model_parallel_group(odp_group_ranks,
                                                        get_world_group().local_rank,
                                                        backend,
                                                        group_name="flashcomm2_odp")


def get_mlp_tensor_model_parallel_world_size():
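To make the rank arithmetic above concrete, here is a small standalone sketch (not part of the PR) that replays the same loop for a hypothetical single-DP-replica setup with `global_tp_size = 8` and `flashcomm2_otp_size = 2`: each OTP group pairs ranks that sit `num_oproj_tensor_parallel_groups` apart, and each ODP group collects one contiguous half of the TP ranks.

```python
# Standalone illustration of the grouping loop above (hypothetical sizes,
# no torch.distributed involved).
global_tp_size = 8                 # assumed tensor-parallel world size
flashcomm2_otp_size = 2            # assumed flashcomm2_oproj_tensor_parallel_size
num_oproj_tensor_parallel_groups = global_tp_size // flashcomm2_otp_size
dp_group_index = 0                 # pretend this rank belongs to the first DP replica

otp_group_ranks = []
odp_group_ranks = [[] for _ in range(flashcomm2_otp_size)]
for i in range(num_oproj_tensor_parallel_groups):
    ranks = []
    for j in range(flashcomm2_otp_size):
        rank_idx = dp_group_index * global_tp_size + i + j * num_oproj_tensor_parallel_groups
        ranks.append(rank_idx)
        odp_group_ranks[j].append(rank_idx)
    otp_group_ranks.append(ranks)

print(otp_group_ranks)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(odp_group_ranks)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

If the same loop were allowed to run with `flashcomm2_otp_size == 1`, it would produce eight singleton OTP groups and a single ODP group covering all eight ranks, which is the case the review comment below is concerned with.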
@@ -142,3 +187,13 @@

    if _OTP:
        _OTP.destroy()
        _OTP = None

    global _FLASHCOMM2_OTP
    if _FLASHCOMM2_OTP and get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1:
        _FLASHCOMM2_OTP.destroy()
        _FLASHCOMM2_OTP = None

    global _FLASHCOMM2_ODP
    if _FLASHCOMM2_ODP and get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1:
        _FLASHCOMM2_ODP.destroy()
        _FLASHCOMM2_ODP = None
Review comment: The process group creation for FlashComm2 is guarded by `if flashcomm2_otp_size > 1:`. This causes `_FLASHCOMM2_OTP` to be `None` when `flashcomm2_oproj_tensor_parallel_size` is 1. However, `Flashcomm2OProjRowParallelOp` is still used in this case, and it attempts to access methods on the `_FLASHCOMM2_OTP` group, which will lead to a crash. The logic within this `if` block appears to correctly handle the `size == 1` case by creating groups of size 1. The conditional guard should be removed, and its content unindented, to fix this critical bug.