Skip to content

Add TP8 online training support for the Qwen3-MoE model.#74

Merged
sleepcoo merged 6 commits intosgl-project:mainfrom
ZhengHSI:add_qwen3_moe_tp8
Jul 31, 2025
Merged

Add TP8 online training support for the Qwen3-MoE model.#74
sleepcoo merged 6 commits intosgl-project:mainfrom
ZhengHSI:add_qwen3_moe_tp8

Conversation

@ZhengHSI
Copy link
Collaborator

  1. Add TP8 online training support for the Qwen3-MoE model.Since Qwen3-MoE has num_key_value_heads < 8, the attention layer replicates each KV head across ranks instead of sharding them.
  2. Qwen3-235B-A22B online training is also supported
  3. You can use SpecForge/tests/test_qwen3_tp.py to test the correctness of TP parallelism.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @ZhengHSI, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly extends the training capabilities for large language models, specifically the Qwen3-MoE and Qwen3-235B-A22B models, by introducing robust support for TP8 (Tensor Parallelism) online training. The core changes involve adapting the parallel linear layers and the model's weight loading process to intelligently handle key-value head replication when the number of available GPUs exceeds the model's native key-value head count, ensuring efficient and correct distributed training.

Highlights

  • Qwen3-MoE TP8 Training: Implemented support for TP8 (Tensor Parallelism with 8 GPUs) online training for the Qwen3-MoE model. This includes specific handling for scenarios where the number of key-value heads is less than the TP size, by replicating KV heads across ranks rather than sharding them.
  • Qwen3-235B-A22B Support: Added foundational support for online training of the Qwen3-235B-A22B model, including a commented-out example script in run_qwen3_moe_eagle3_online.sh demonstrating TP8 training for this model.
  • Generalized Parallel Linear Layers: Enhanced the RowParallelLinear and ColumnParallelLinear modules to accept a kv_head_replicas flag. This allows these layers to dynamically adjust their sharding behavior, enabling proper handling of replicated key-value heads in attention mechanisms.
  • Attention Weight Sharding Logic: Updated the Qwen3MoeForCausalLM's weight loading mechanism to correctly shard or replicate attention layer weights (q_proj, k_proj, v_proj, o_proj). This includes a new conditional path for k_proj and v_proj to manage KV head replication by selecting the appropriate head slice when tp_size exceeds the total number of KV heads.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds TP8 online training support for the Qwen3-MoE model. A critical bug was identified in the tensor parallelism logic that could lead to runtime errors, and a detailed suggestion for a fix has been provided. There are also a few medium-severity recommendations to improve code maintainability and usability.

# --standalone \
# --nproc_per_node $NUM_GPUS \
# $ROOT_DIR/scripts/train_eagle3_online.py \
# --target-model-path /root/model/Qwen3-235B-A22B \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The model path is hardcoded. This makes the script less portable and harder to use for others who may have the model stored in a different location. It's a good practice to use an environment variable for paths like this, with a sensible default if needed.

Suggested change
# --target-model-path /root/model/Qwen3-235B-A22B \
# --target-model-path ${TARGET_MODEL_PATH:-/root/model/Qwen3-235B-A22B} \

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix

Comment on lines +10 to +18
def __init__(
self,
in_features,
out_features,
bias=True,
device=None,
dtype=None,
kv_head_replicas=False,
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The kv_head_replicas parameter and its associated logic (at lines 28-31) are not used for RowParallelLinear anywhere in this pull request. This layer is used for o_proj and MoE down_proj, neither of which requires this special handling for KV head replication.

To improve clarity and avoid maintaining unused code, I recommend removing this parameter and its conditional logic from RowParallelLinear.

layer_match = key.split(".")
layer_idx = None
for i, part in enumerate(layer_match):
if part.startswith("layers") and i + 1 < len(layer_match):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using part == "layers" is more precise and robust than part.startswith("layers"). While startswith works in this context, using an exact match better conveys the intent to find a specific segment of the model key and prevents potential mismatches if other keys were to begin with "layers" (e.g., "layers_config").

Suggested change
if part.startswith("layers") and i + 1 < len(layer_match):
if part == "layers" and i + 1 < len(layer_match):

@sleepcoo sleepcoo self-requested a review July 28, 2025 08:35
@Xu-Chen
Copy link

Xu-Chen commented Jul 28, 2025

Any plan to add TP8 online training support for the Qwen3-Dense model?

@ZhengHSI
Copy link
Collaborator Author

Any plan to add TP8 online training support for the Qwen3-Dense model?

coming soon

@ZhengHSI
Copy link
Collaborator Author

ZhengHSI commented Jul 31, 2025

I trained for 1 epoch on ShareGPT and obtained the following result.

Train Epoch [1/10], position 0,  Acc: 0.52
Train Epoch [1/10], position 1,  Acc: 0.47
Train Epoch [1/10], position 2,  Acc: 0.45
Train Epoch [1/10], position 3,  Acc: 0.43
Train Epoch [1/10], position 4,  Acc: 0.42
Train Epoch [1/10], position 5,  Acc: 0.40
Train Epoch [1/10], position 6,  Acc: 0.39
Train Epoch [1/10], position 0, pLoss: 1.36
Train Epoch [1/10], position 1, pLoss: 1.47
Train Epoch [1/10], position 2, pLoss: 1.54
Train Epoch [1/10], position 3, pLoss: 1.59
Train Epoch [1/10], position 4, pLoss: 1.63
Train Epoch [1/10], position 5, pLoss: 1.67
Train Epoch [1/10], position 6, pLoss: 1.71

@sleepcoo sleepcoo merged commit 731e0d7 into sgl-project:main Jul 31, 2025
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants