Skip to content

Conversation

Levi-JQ
Copy link

@Levi-JQ Levi-JQ commented Sep 28, 2025

What this PR does / why we need it?

Supports generalized FlashComm2 optimization, which reduces communication overhead, decreases RmsNorm computation, and saves one AllGather step by replacing Allreduce operations in Attention/MLP modules with pre-AlltoAll and post-AllGather operations. This feature is enabled during the Prefill phase and delivers broad performance improvements, especially in long sequence scenarios with large tensor parallelism (TP). Benchmark tests show a 10%-20% performance acceleration under TP16DP1 configuration.

Does this PR introduce any user-facing change?

How was this patch tested?

Copy link

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:‌‌

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests ‌to ensure it works and is not broken by other future PRs.
  • Write the commit message by fulfilling the PR description to help reviewer and future developers understand.

If CI fails, you can run linting and testing checks locally according Contributing and Testing.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the FlashComm2 optimization for tensor parallelism on Ascend NPUs, aiming to improve performance by optimizing communication patterns. The changes span configuration, parallel state management, and operator implementations. My review has identified a few issues: a critical bug in the parallel group initialization that can lead to a crash, a related potential resource leak in the group destruction logic, and incorrect formatting of error messages in the configuration validation. These issues should be addressed to ensure correctness and robustness.

_FLASHCOMM2_OTP = None
_FLASHCOMM2_ODP = get_tp_group()

if flashcomm2_otp_size > 1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The process group creation for FlashComm2 is guarded by if flashcomm2_otp_size > 1:. This causes _FLASHCOMM2_OTP to be None when flashcomm2_oproj_tensor_parallel_size is 1. However, Flashcomm2OProjRowParallelOp is still used in this case, and it attempts to access methods on the _FLASHCOMM2_OTP group, which will lead to a crash. The logic within this if block appears to correctly handle the size == 1 case by creating groups of size 1. The conditional guard should be removed, and its content unindented, to fix this critical bug.

Comment on lines 107 to 115
raise AssertionError(
"flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size}) cannot exceed global tensor parallel size ({global_tp_size})"
)
if global_tp_size % self.flashcomm2_oproj_tensor_parallel_size != 0:
raise AssertionError(
"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size})"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The error message strings are not f-strings, so the variables inside the curly braces will not be interpolated. This will result in confusing and unhelpful error messages for users.

Suggested change
raise AssertionError(
"flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size}) cannot exceed global tensor parallel size ({global_tp_size})"
)
if global_tp_size % self.flashcomm2_oproj_tensor_parallel_size != 0:
raise AssertionError(
"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size})"
)
raise AssertionError(
f"flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size}) cannot exceed global tensor parallel size ({global_tp_size})"
)
if global_tp_size % self.flashcomm2_oproj_tensor_parallel_size != 0:
raise AssertionError(
f"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({self.flashcomm2_oproj_tensor_parallel_size})"
)

_OTP = None

global _FLASHCOMM2_OTP
if _FLASHCOMM2_OTP and get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The condition get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1 will prevent the _FLASHCOMM2_OTP group from being destroyed when its size is 1. If the initialization logic is fixed to create a group for size 1 (as suggested in another comment), this will cause a resource leak. The group should be destroyed if it was created, regardless of its size.

Suggested change
if _FLASHCOMM2_OTP and get_ascend_config().flashcomm2_oproj_tensor_parallel_size != 1:
if _FLASHCOMM2_OTP:

@Levi-JQ Levi-JQ force-pushed the official-fc2 branch 4 times, most recently from 8b9a5a2 to 5b6c013 Compare September 30, 2025 02:34
Copy link

This pull request has conflicts, please resolve those before we can evaluate the pull request.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant