Skip to content

Commit 947b28a

Browse files
(feat)Draft model supports Qwen3 MoE
1 parent 3e34e19 commit 947b28a

File tree

8 files changed

+695
-7
lines changed

8 files changed

+695
-7
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Eagle3 for Qwen3 MoE - Offline
2+
3+
## Introduction
4+
5+
This document provides a step-by-step guide on how to train the EAGLE3 draft model for a Qwen3 MoE target model (e.g. Qwen3-30B-A3B) in an offline manner. In offline training, we generate the hidden states required by the EAGLE3 draft model beforehand and store them to disk. During training, we load them back into GPU memory. As offline training requires a lot of disk space, we do not recommend running this on large datasets such as Perfect-Blend.
6+
7+
## Training on ShareGPT dataset
8+
9+
### **Step 1. Prepare ShareGPT dataset**
10+
11+
First of all, we should download the dataset.
12+
13+
```shell
14+
python ./scripts/prepare_data.py --dataset sharegpt
15+
```
16+
17+
### **Step 2. Prepare Hidden States**
18+
19+
We need to prepare the hidden states for the training.
20+
21+
```shell
22+
torchrun --nproc_per_node=8 \
23+
    scripts/prepare_hidden_states.py \
24+
    --target-model-path /home/data/weights/Qwen3-30B-A3B/ \
25+
    --enable-aux-hidden-states \
26+
    --data-path ./cache/dataset/sharegpt_train.jsonl \
27+
    --chat-template qwen \
28+
    --max-length 2048 \
29+
    --tp-size 8 \
30+
    --batch-size 32 \
31+
    --num-samples 20 \
32+
    --output-path ./cache/hidden_states
33+
```
34+
35+
The hidden states will be saved to the disk in the `output-path` directory.
36+
37+
### **Step 3. Start Training**
38+
39+
```shell
40+
torchrun \
41+
--standalone \
42+
--nproc_per_node $NUM_GPUS \
43+
$ROOT_DIR/scripts/train_eagle3.py \
44+
--target-model-path /home/data/weights/Qwen3-30B-A3B/ \
45+
--draft-model-config $ROOT_DIR/configs/qwen3-30B-A3B-eagle3_moe.json \
46+
--train-data-path ./cache/dataset/sharegpt_train.jsonl \
47+
--train-hidden-states-path ./cache/hidden_states \
48+
--output-dir ./outputs/qwen3-moe-30b-a3b-eagle3-sharegpt-offline \
49+
--num-epochs 4 \
50+
--batch-size 1 \
51+
--learning-rate 1e-4 \
52+
--max-length 2048 \
53+
--save-interval 1984 \
54+
--chat-template qwen \
55+
--cache-dir $ROOT_DIR/cache \
56+
--embedding-key model.embed_tokens.weight \
57+
--sp-ulysses-size 4 \
58+
--attention-backend "usp" \
59+
--target-model-backend sglang
60+
61+
```

specforge/core/eagle3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ def _compute_target_p_padded(target, t2d, loss_mask, length):
545545
return target_p_padded, position_mask
546546

547547

548-
@torch.compile(dynamic=None)
548+
# @torch.compile(dynamic=None)
549549
def _compute_target_p(target, t2d, loss_mask):
550550
target_head = target
551551
target_max_token = target_head.argmax(-1)
@@ -559,7 +559,7 @@ def _compute_target_p(target, t2d, loss_mask):
559559
return target_p, position_mask
560560

561561

562-
@torch.compile(dynamic=None)
562+
# @torch.compile(dynamic=None)
563563
def _compute_metric_acc(logits, target_p, position_mask, loss_mask):
564564
return (
565565
(logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1)

specforge/core/loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
# Reference implementation
15-
@torch.compile(dynamic=None)
15+
# @torch.compile(dynamic=None)
1616
def _compute_loss(logits, target_p, position_mask):
1717
logits = logits.float()
1818
out_logp = nn.LogSoftmax(dim=2)(logits)
@@ -30,7 +30,7 @@ def _calculate_settings(n):
3030
raise RuntimeError(
3131
f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}."
3232
)
33-
33+
BLOCK_SIZE = 2048
3434
num_warps = 4
3535
if BLOCK_SIZE >= 32768:
3636
num_warps = 32

specforge/modeling/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# from .auto import AutoDistributedTargetModel, AutoDraftModelConfig, AutoEagle3DraftModel
22
from .auto import AutoDraftModelConfig, AutoEagle3DraftModel
33
from .draft.llama3_eagle import LlamaForCausalLMEagle3
4+
from .draft.qwen3_moe_eagle import Qwen3MoEForCausalLMEagle3
5+
from .draft.router_moe import Qwen3MoERouterForCausalLMEagle3
46
from .target.eagle3_target_model import (
57
CustomEagle3TargetModel,
68
HFEagle3TargetModel,
@@ -10,6 +12,7 @@
1012

1113
__all__ = [
1214
"LlamaForCausalLMEagle3",
15+
"Qwen3MoERouterForCausalLMEagle3",
1316
"SGLangEagle3TargetModel",
1417
"HFEagle3TargetModel",
1518
"CustomEagle3TargetModel",

specforge/modeling/auto.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
)
2020

2121
from .draft.llama3_eagle import LlamaForCausalLMEagle3
22+
from .draft.qwen3_moe_eagle import Qwen3MoEForCausalLMEagle3
23+
from .draft.router_moe import Qwen3MoERouterForCausalLMEagle3
2224
from .target.custom_backend import (
2325
GptOssForCausalLM,
2426
Llama4ForCausalLM,
@@ -34,6 +36,7 @@ class AutoEagle3DraftModel(AutoModelForCausalLMBase):
3436
# the model mapping is currently hardcoded, we should support lazy model mapping via registry
3537
_model_mapping = {
3638
LlamaConfig: LlamaForCausalLMEagle3,
39+
Qwen3MoeConfig: Qwen3MoERouterForCausalLMEagle3,
3740
}
3841

3942
@classmethod
@@ -133,6 +136,7 @@ class AutoDraftModelConfig:
133136

134137
_config_mapping = {
135138
"LlamaForCausalLMEagle3": LlamaConfig,
139+
"Qwen3MoERouterForCausalLMEagle3": Qwen3MoeConfig,
136140
}
137141

138142
@classmethod

specforge/modeling/draft/llama3_eagle.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def rotate_half(x):
102102
return torch.cat((-x2, x1), dim=-1)
103103

104104

105-
@torch.compile(dynamic=True)
105+
# @torch.compile(dynamic=True)
106106
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
107107
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
108108
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
@@ -272,7 +272,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
272272
"sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
273273
)
274274

275-
@torch.compile(dynamic=True)
275+
# @torch.compile(dynamic=True)
276276
def forward(self, x, seq_len=None):
277277
# x: [bs, num_attention_heads, seq_len, head_size]
278278
if seq_len and seq_len > self.max_seq_len_cached:
@@ -1347,7 +1347,7 @@ def __init__(self, hidden_size, eps=1e-6):
13471347
self.weight = nn.Parameter(torch.ones(hidden_size))
13481348
self.variance_epsilon = eps
13491349

1350-
@torch.compile(dynamic=True)
1350+
# @torch.compile(dynamic=True)
13511351
def forward(self, hidden_states):
13521352
input_dtype = hidden_states.dtype
13531353
hidden_states = hidden_states.to(torch.float32)
@@ -1474,6 +1474,7 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None:
14741474
d2t = torch.zeros(self.draft_vocab_size, dtype=torch.int64)
14751475
self.register_buffer("t2d", t2d)
14761476
self.register_buffer("d2t", d2t)
1477+
self.save_idx = 1
14771478

14781479
def forward(
14791480
self,
@@ -1528,6 +1529,12 @@ def forward(
15281529

15291530
# norm
15301531
hidden_states = self.norm(hidden_states)
1532+
self.save_idx += 1
1533+
save_path = f"./router/{self.save_idx}.pt"
1534+
# 2. Save the tensor (move to CPU to avoid holding a GPU reference; detach to decouple it from the autograd graph)
1535+
torch.save(hidden_states.detach().cpu(), save_path)
1536+
if self.save_idx >= 16:
1537+
exit(0)
15311538

15321539
return hidden_states
15331540

@@ -1563,3 +1570,5 @@ def backbone(
15631570
output_attentions=False,
15641571
use_cache=False,
15651572
)
1573+
1574+
#

0 commit comments

Comments
 (0)