
Add TP support for Qwen3-Dense model online training.#77

Merged
yubofredwang merged 22 commits into sgl-project:main from Ximingwang-09:main
Aug 12, 2025

Conversation

@Ximingwang-09 (Contributor) commented Jul 28, 2025

Add TP support for Qwen3-Dense model online training, plus support for more Qwen3 model configs.
Related issue: #85

run_qwen3_dense_eagle3_online.sh

#!/bin/bash

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname "$SCRIPT_DIR")

# Supports TP EAGLE3 training (e.g. tp=4 for Qwen3-8B); pass the GPU count as $1
NUM_GPUS=${1:-1}

torchrun \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    "$ROOT_DIR/scripts/train_eagle3_online.py" \
    --target-model-path /mnt/Qwen3-4B \
    --draft-model-config "$ROOT_DIR/configs/qwen3-4b-eagle3.json" \
    --train-data-path /mnt/sharegpt_2000.jsonl \
    --output-dir /mnt4/data/eagle/mxc_slow_think \
    --num-epochs 4 \
    --batch-size 1 \
    --learning-rate 1e-4 \
    --max-length 4096 \
    --chat-template qwen \
    --embedding-key model.embed_tokens.weight \
    --tp-size "$NUM_GPUS"

Command:
bash run_qwen3_dense_eagle3_online.sh 8

@gemini-code-assist bot commented:

Summary of Changes

Hello @Ximingwang-09, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request expands our system's capability by integrating new configuration files for specific Qwen3 model variants. These additions are crucial for enabling the proper loading and utilization of the Qwen3 32B Eagle3 and Qwen3 4B Eagle3 models within our framework, providing their essential architectural parameters.

Highlights

  • New Model Configurations: I've added two new JSON configuration files for Qwen3 models, specifically the 'Eagle3' variants.
  • Qwen3 32B Eagle3 Support: A configuration file (configs/qwen3-32b-eagle3.json) has been introduced to define the parameters for the Qwen3 32B Eagle3 model, including its architecture (LlamaForCausalLMEagle3), hidden size, attention heads, and vocabulary size.
  • Qwen3 4B Eagle3 Support: A separate configuration file (configs/qwen3-4b-eagle3.json) has been added to specify the parameters for the Qwen3 4B Eagle3 model, detailing its architecture, hidden size, attention heads, and vocabulary size.

@gemini-code-assist bot commented:
Code Review

This pull request adds two new configuration files for qwen3-32b-eagle3 and qwen3-4b-eagle3 models. The configurations look good, but I've identified a few minor formatting and consistency issues that should be addressed to improve maintainability and prevent potential issues with tooling. Specifically, both files contain a UTF-8 BOM, are missing a final newline character, and use an integer for rope_theta where a float would be more consistent with other configurations in the project. I've provided suggestions to fix these.

@@ -0,0 +1,31 @@ (first new config file)

medium — This file starts with a UTF-8 Byte Order Mark (BOM, \ufeff). The BOM is unnecessary for UTF-8 encoded files and can cause compatibility issues with some parsers and tools; save the file as UTF-8 without a BOM.

    "num_key_value_heads": 8,
    "rms_norm_eps": 1e-06,
    "rope_scaling": null,
    "rope_theta": 1000000,

medium — For consistency with other model configurations in the repository (e.g., qwen3-235B-A22B-eagle3.json), represent rope_theta as a float. The integer may work with the current implementation, but standardizing on the float type improves maintainability and readability, since the value is used in floating-point calculations.

Suggested change:

    "rope_theta": 1000000.0,

    "use_sliding_window": false,
    "vocab_size": 151936,
    "draft_vocab_size": 32000
} (no newline at end of file)

medium — The file is missing a final newline character. Ending text files with a newline is a standard convention (part of the POSIX definition of a line), and some tools may not process the last line correctly without it.
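The reviewer's point about rope_theta's type can be observed directly in how JSON numbers parse. A small sketch using Python's json module (other parsers behave similarly):

```python
import json

# "1000000" parses as an int, "1000000.0" as a float; downstream code that
# assumes a float for rope_theta may behave differently with an int config.
int_cfg = json.loads('{"rope_theta": 1000000}')
float_cfg = json.loads('{"rope_theta": 1000000.0}')

print(type(int_cfg["rope_theta"]).__name__)    # int
print(type(float_cfg["rope_theta"]).__name__)  # float
```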

(The bot posted the same three comments — UTF-8 BOM, rope_theta as a float, and missing final newline — on the second new config file.)
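Both mechanical issues (leading BOM, missing trailing newline) can be fixed with a small script. A sketch with a hypothetical helper, not part of this PR:

```python
from pathlib import Path
import os
import tempfile

def strip_bom_and_fix_newline(path):
    """Remove a leading UTF-8 BOM and ensure the file ends with a newline."""
    raw = Path(path).read_bytes()
    if raw.startswith(b"\xef\xbb\xbf"):  # UTF-8 BOM bytes
        raw = raw[3:]
    if not raw.endswith(b"\n"):
        raw += b"\n"
    Path(path).write_bytes(raw)

# Example on a temporary config file:
fd, tmp = tempfile.mkstemp(suffix=".json")
os.close(fd)
Path(tmp).write_bytes(b'\xef\xbb\xbf{"rope_theta": 1000000.0}')
strip_bom_and_fix_newline(tmp)
print(Path(tmp).read_bytes())  # b'{"rope_theta": 1000000.0}\n'
os.remove(tmp)
```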

@sleepcoo (Collaborator) commented Jul 29, 2025

Have you tried training? Can you share your training results? Could you also provide the training script?

@Ximingwang-09 (Author) replied:

> Have you tried training? Can you share your training results? Could you also provide the training script?

Yes, I'm training a Qwen3-32B EAGLE3 model; I'll share the results and the script once it finishes.

@Ximingwang-09 (Author) commented:

/gemini review

@Ximingwang-09 Ximingwang-09 changed the title Add more qwen3 model config. Add Qwen3-Dense model online training support. Jul 30, 2025
@gemini-code-assist bot commented:

Code Review

This pull request adds support for Qwen3 dense models, including new configuration files and a model implementation with tensor parallelism. The changes are well-structured, but I've identified a couple of critical issues in the tensor parallelism logic within the new Qwen3Attention module that could lead to incorrect model outputs or runtime errors. Addressing these is important for the correctness and robustness of the implementation.

if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_output, attn_weights = attention_interface(

critical

This line will cause a ValueError at runtime if self.config._attn_implementation is set to an implementation (like flash_attention_2) that returns more than two values (e.g., when use_cache=True). The code currently unpacks exactly two return values, but some attention backends return a third value (past_key_value).

To fix this, you should handle the variable number of return values from attention_interface.

A second, related issue is that the function's return type hint on line 207 is incorrect. It should be updated to match the actual return values, which are (attn_output, attn_weights).
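A minimal sketch of such a fix, with toy functions standing in for the real attention backends (names are illustrative, not from the PR):

```python
def call_attention(attention_interface, *args, **kwargs):
    """Unpack only (attn_output, attn_weights), tolerating backends that
    return extra values such as past_key_value."""
    attn_output, attn_weights, *_extra = attention_interface(*args, **kwargs)
    return attn_output, attn_weights

# Toy stand-ins for backends with different return arities:
def eager_attention(q):
    return q, "weights"                 # (output, weights)

def flash_attention(q):
    return q, None, "past_key_value"    # (output, weights, past_key_value)

print(call_attention(eager_attention, "out"))  # ('out', 'weights')
print(call_attention(flash_attention, "out"))  # ('out', None)
```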

Comment on lines +240 to +242
attn_output = self.o_proj(attn_output)
# Add all_reduce for TP
dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group)

critical

There is a critical issue here with how the bias of the output projection (o_proj) is handled in tensor parallelism. The o_proj is a RowParallelLinear layer, and if it has a bias (i.e., attention_bias=True in the config), the bias is added to the output on each TP rank before the all_reduce operation. This results in the bias being incorrectly scaled by the tensor parallel world size (tp_size).

While the configs in this PR set attention_bias: false, the code should be robust for cases where it is true (which is the default in transformers.Qwen3Config). The bias should be added after the all_reduce operation.

Suggested change:

    attn_output = self.o_proj(attn_output)
    # Add all_reduce for TP
    dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group)

to:

    attn_output = torch.nn.functional.linear(attn_output, self.o_proj.weight)
    dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group)
    if self.o_proj.bias is not None:
        attn_output += self.o_proj.bias
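Why the ordering matters can be seen with plain arithmetic, modeling the sum-all-reduce without any real TP:

```python
tp_size = 4
partial_matmul = 1.0  # each rank's partial row-parallel matmul result
bias = 0.5

# Buggy ordering: every rank adds the full bias before the sum-all-reduce,
# so the reduced output contains tp_size copies of the bias.
buggy = sum(partial_matmul + bias for _ in range(tp_size))
print(buggy)  # 6.0 -> bias counted 4 times

# Fixed ordering: all-reduce the partial matmuls first, add the bias once.
fixed = sum(partial_matmul for _ in range(tp_size)) + bias
print(fixed)  # 4.5 -> bias counted once
```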

    "use_sliding_window": false,
    "vocab_size": 151936,
    "draft_vocab_size": 32000
} (no newline at end of file)

medium — It's a common convention and good practice to end files with a newline character; many tools (like git and cat) and editors handle files better when they end with one. Please consider adding it here.

(The same missing-newline comment was posted on the second config file.)

纬杭 added 2 commits July 30, 2025 13:39
@Ximingwang-09 Ximingwang-09 changed the title Add Qwen3-Dense model online training support. Add TP support for Qwen3-Dense model online training. Jul 30, 2025
@Ximingwang-09 (Author) commented Jul 30, 2025

> Have you tried training? Can you share your training results? Could you also provide the training script?

I found that the current project doesn't support TP for Qwen dense models, which can lead to OOM, so I implemented it briefly. Training has now started successfully, and I'll share the results as soon as they are available.
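As a rough illustration of why TP avoids OOM here (illustrative numbers only — bf16 weights, ignoring activations, gradients, and optimizer state):

```python
params = 32e9        # e.g. a 32B-parameter target model
bytes_per_param = 2  # bf16

# Weight memory per GPU scales roughly as 1/tp_size.
for tp in (1, 4, 8):
    gib = params * bytes_per_param / tp / 2**30
    print(f"tp={tp}: ~{gib:.1f} GiB of weights per GPU")
```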

@Ximingwang-09 (Author) replied:

> Have you tried training? Can you share your training results? Could you also provide the training script?

I validated by training the Qwen3-4B model with 2000 ShareGPT samples on 8 H20 GPUs. The command is above; the results are below:

Train Epoch [1/4], position 0,  Acc: 0.27
Train Epoch [1/4], position 1,  Acc: 0.21
Train Epoch [1/4], position 2,  Acc: 0.18
Train Epoch [1/4], position 3,  Acc: 0.16
Train Epoch [1/4], position 4,  Acc: 0.15
Train Epoch [1/4], position 5,  Acc: 0.14
Train Epoch [1/4], position 6,  Acc: 0.14
Train Epoch [1/4], position 0, pLoss: 3.69
Train Epoch [1/4], position 1, pLoss: 4.05
Train Epoch [1/4], position 2, pLoss: 4.22
Train Epoch [1/4], position 3, pLoss: 4.33
Train Epoch [1/4], position 4, pLoss: 4.40
Train Epoch [1/4], position 5, pLoss: 4.46
Train Epoch [1/4], position 6, pLoss: 4.51

Train Epoch [2/4], position 0,  Acc: 0.45
Train Epoch [2/4], position 1,  Acc: 0.37
Train Epoch [2/4], position 2,  Acc: 0.33
Train Epoch [2/4], position 3,  Acc: 0.30
Train Epoch [2/4], position 4,  Acc: 0.29
Train Epoch [2/4], position 5,  Acc: 0.27
Train Epoch [2/4], position 6,  Acc: 0.26
Train Epoch [2/4], position 0, pLoss: 2.16
Train Epoch [2/4], position 1, pLoss: 2.53
Train Epoch [2/4], position 2, pLoss: 2.72
Train Epoch [2/4], position 3, pLoss: 2.86
Train Epoch [2/4], position 4, pLoss: 2.96
Train Epoch [2/4], position 5, pLoss: 3.04
Train Epoch [2/4], position 6, pLoss: 3.12

Train Epoch [3/4], position 0,  Acc: 0.53
Train Epoch [3/4], position 1,  Acc: 0.45
Train Epoch [3/4], position 2,  Acc: 0.40
Train Epoch [3/4], position 3,  Acc: 0.38
Train Epoch [3/4], position 4,  Acc: 0.36
Train Epoch [3/4], position 5,  Acc: 0.34
Train Epoch [3/4], position 6,  Acc: 0.33
Train Epoch [3/4], position 0, pLoss: 1.76
Train Epoch [3/4], position 1, pLoss: 2.07
Train Epoch [3/4], position 2, pLoss: 2.24
Train Epoch [3/4], position 3, pLoss: 2.37
Train Epoch [3/4], position 4, pLoss: 2.46
Train Epoch [3/4], position 5, pLoss: 2.55
Train Epoch [3/4], position 6, pLoss: 2.64

Train Epoch [4/4], position 0,  Acc: 0.55
Train Epoch [4/4], position 1,  Acc: 0.47
Train Epoch [4/4], position 2,  Acc: 0.43
Train Epoch [4/4], position 3,  Acc: 0.40
Train Epoch [4/4], position 4,  Acc: 0.38
Train Epoch [4/4], position 5,  Acc: 0.36
Train Epoch [4/4], position 6,  Acc: 0.34
Train Epoch [4/4], position 0, pLoss: 1.68
Train Epoch [4/4], position 1, pLoss: 1.97
Train Epoch [4/4], position 2, pLoss: 2.13
Train Epoch [4/4], position 3, pLoss: 2.25
Train Epoch [4/4], position 4, pLoss: 2.35
Train Epoch [4/4], position 5, pLoss: 2.44
Train Epoch [4/4], position 6, pLoss: 2.53

@Ximingwang-09 (Author) commented Aug 1, 2025

@sleepcoo I think it's ready. Could you help me review this? Thanks.

@yubofredwang yubofredwang merged commit c091803 into sgl-project:main Aug 12, 2025
1 check passed
qibaoyuan pushed a commit to qibaoyuan/SpecForge that referenced this pull request Aug 18, 2025
* add more qwen config

* support qwen3 dense tp

* fix

* fix

* fix

* lint

* lint fix

* fix lint

* fix lint

* lint

* fix

* fix

* fix

* clean

* lint fix

* align head_dim with qwen config

---------

Co-authored-by: 纬杭 <ximing.wxm@antgroup.com>