[main] flashcomm_v2 optim solution #3232
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces the FlashComm2 optimization for tensor parallelism on Ascend NPUs, aiming to improve performance by optimizing communication patterns. The changes span configuration, parallel state management, and operator implementations. My review has identified a few issues: a critical bug in the parallel group initialization that can lead to a crash, a related potential resource leak in the group destruction logic, and incorrect formatting of error messages in the configuration validation. These issues should be addressed to ensure correctness and robustness.
_FLASHCOMM2_OTP = None
_FLASHCOMM2_ODP = get_tp_group()

if flashcomm2_otp_size > 1:
The process group creation for FlashComm2 is guarded by if flashcomm2_otp_size > 1:. This causes _FLASHCOMM2_OTP to be None when flashcomm2_oproj_tensor_parallel_size is 1. However, Flashcomm2OProjRowParallelOp is still used in this case, and it attempts to access methods on the _FLASHCOMM2_OTP group, which will lead to a crash. The logic within this if block appears to correctly handle the size == 1 case by creating groups of size 1. The conditional guard should be removed, and its content unindented, to fix this critical bug.
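A minimal sketch of the suggested fix follows. It is illustrative only: helper names such as init_model_parallel_group, get_world_group, world_size, and backend are assumed to follow vLLM's parallel_state conventions and may differ in vllm-ascend.

# Hypothetical sketch of the fix: the `if flashcomm2_otp_size > 1:` guard is
# removed so size-1 groups are still created and _FLASHCOMM2_OTP is never None.
# `init_model_parallel_group` is assumed to follow vLLM's parallel_state API.
global _FLASHCOMM2_OTP

# Partition the global TP ranks into contiguous sub-groups of size
# flashcomm2_otp_size; a size of 1 simply yields one group per rank.
group_ranks = [
    list(range(start, start + flashcomm2_otp_size))
    for start in range(0, world_size, flashcomm2_otp_size)
]
_FLASHCOMM2_OTP = init_model_parallel_group(group_ranks,
                                            get_world_group().local_rank,
                                            backend,
                                            group_name="flashcomm2_otp")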
raise AssertionError(
    "flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size}) cannot exceed global tensor parallel size ({global_tp_size})"
)
if global_tp_size % self.flashcomm2_oproj_tensor_parallel_size != 0:
    raise AssertionError(
        "Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size})"
    )
The error message strings are not f-strings, so the variables inside the curly braces will not be interpolated. This will result in confusing and unhelpful error messages for users.
Suggested change:
raise AssertionError(
    f"flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size}) cannot exceed global tensor parallel size ({global_tp_size})"
)
if global_tp_size % self.flashcomm2_oproj_tensor_parallel_size != 0:
    raise AssertionError(
        f"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size})"
    )
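For reference, a plain string leaves the braced expressions untouched, while the f-prefix interpolates them:

size, tp = 4, 16
print("otp size ({size}) exceeds tp ({tp})")    # otp size ({size}) exceeds tp ({tp})
print(f"otp size ({size}) exceeds tp ({tp})")   # otp size (4) exceeds tp (16)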
_OTP = None

global _FLASHCOMM2_OTP
if _FLASHCOMM2_OTP and get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1:
The condition get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1 will prevent the _FLASHCOMM2_OTP group from being destroyed when its size is 1. If the initialization logic is fixed to create a group for size 1 (as suggested in another comment), this will cause a resource leak. The group should be destroyed if it was created, regardless of its size.
Suggested change:
if _FLASHCOMM2_OTP:
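A minimal teardown sketch under that assumption (the wrapper function name here is illustrative; destroy() follows vLLM's GroupCoordinator API):

# Illustrative sketch: destroy the group whenever it exists. With the
# initialization fix above, size-1 groups are created too, so no size check.
def destroy_flashcomm2_otp_group():
    global _FLASHCOMM2_OTP
    if _FLASHCOMM2_OTP:
        _FLASHCOMM2_OTP.destroy()  # GroupCoordinator.destroy() in vLLM
        _FLASHCOMM2_OTP = None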
This pull request has conflicts, please resolve those before we can evaluate the pull request.
What this PR does / why we need it?
Supports a generalized FlashComm2 optimization that reduces communication overhead, decreases RmsNorm computation, and saves one AllGather step by replacing AllReduce operations in the Attention/MLP modules with a pre-AlltoAll and a post-AllGather. The feature is enabled during the prefill phase and delivers broad performance improvements, especially in long-sequence scenarios with large tensor parallelism (TP). Benchmark tests show a 10%-20% speedup under the TP16DP1 configuration.
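The identity behind the communication change can be sketched in a few lines: an AllReduce over a TP group is equivalent to an AlltoAll that redistributes partial results, a local reduction, and an AllGather that reassembles the full tensor. The sketch below is illustrative only (plain torch.distributed on a generic backend, not the PR's Ascend kernels) and assumes the leading dimension divides evenly by the group size:

# Illustrative only: AllReduce == AlltoAll (redistribute partial sums)
# + local reduce + AllGather. Assumes torch.distributed is initialized
# with a backend that supports all_to_all, and x.size(0) % group size == 0.
import torch
import torch.distributed as dist

def allreduce_via_alltoall_allgather(x: torch.Tensor, group=None) -> torch.Tensor:
    world = dist.get_world_size(group)
    # Split this rank's partial result into one chunk per rank.
    send = list(x.chunk(world, dim=0))
    recv = [torch.empty_like(c) for c in send]
    dist.all_to_all(recv, send, group=group)  # rank i now holds chunk i from every rank
    local = torch.stack(recv).sum(dim=0)      # reduce the partial chunks locally
    out = [torch.empty_like(local) for _ in range(world)]
    dist.all_gather(out, local, group=group)  # reassemble the reduced tensor everywhere
    return torch.cat(out, dim=0)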
Does this PR introduce any user-facing change?
How was this patch tested?