Skip to content

Commit cdd81c3

Browse files
committed
Support EAGLE3 training for Qwen2.5-VL-32B with the SGLang target-model backend
1 parent b85f89c commit cdd81c3

File tree

8 files changed

+718
-31
lines changed

8 files changed

+718
-31
lines changed

configs/qwen2.5-vl-32b-eagle3.json

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
{
  "architectures": ["LlamaForCausalLMEagle3"],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 5120,
  "initializer_range": 0.02,
  "intermediate_size": 18944,
  "max_position_embeddings": 8192,
  "max_window_layers": 28,
  "model_type": "llama",
  "target_model_type": "qwen2_5_vl",
  "num_attention_heads": 28,
  "num_hidden_layers": 1,
  "num_key_value_heads": 4,
  "rms_norm_eps": 1e-06,
  "pretraining_tp": 1,
  "rope_scaling": {
    "type": "mrope",
    "mrope_section": [16, 24, 24]
  },
  "rope_theta": 1000000,
  "sliding_window": 32768,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.51.0",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 152064,
  "draft_vocab_size": 32000
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
#!/bin/bash

# Train an EAGLE3 draft model for Qwen/Qwen2.5-VL-32B-Instruct using the
# SGLang backend for the target model, sharded with tensor parallelism
# (--tp-size 4).
#
# Usage: ./this_script.sh [NUM_GPUS]
#   NUM_GPUS                 number of torchrun processes (default: 4)
#   BUILD_DATASET_NUM_PROC   env var, dataset preprocessing workers (default: 64)

# Resolve the repository root relative to this script's location.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname "$SCRIPT_DIR")

# The target model uses --tp-size 4, so at least 4 ranks/GPUs are required.
# (The previous default of 1 was a leftover from the 7B tp1 script.)
NUM_GPUS=${1:-4}
BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64}

torchrun \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    "$ROOT_DIR/scripts/train_eagle3.py" \
    --target-model-path Qwen/Qwen2.5-VL-32B-Instruct \
    --draft-model-config "$ROOT_DIR/configs/qwen2.5-vl-32b-eagle3.json" \
    --train-data-path "$ROOT_DIR/cache/allava4v_train.jsonl" \
    --build-dataset-num-proc "$BUILD_DATASET_NUM_PROC" \
    --output-dir "$ROOT_DIR/outputs/qwen2.5-vl-32b-eagle3" \
    --num-epochs 10 \
    --batch-size 1 \
    --learning-rate 1e-4 \
    --max-length 4096 \
    --dist-timeout 360 \
    --chat-template qwen2-vl \
    --target-model-backend sglang \
    --cache-dir "$ROOT_DIR/cache" \
    --embedding-key model.embed_tokens.weight \
    --tp-size 4 \
    --sglang-mem-fraction-static 0.5 \
    --is-vlm \
    --min-pixels 200704 \
    --max-pixels 1003520

scripts/train_eagle3.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def build_target_model(
268268
if (
269269
args.is_vlm
270270
and draft_model_config.target_model_type == "qwen2_5_vl"
271-
and args.tp_size == 1
271+
and args.target_model_backend == "custom"
272272
):
273273
from transformers import Qwen2_5_VLForConditionalGeneration
274274

@@ -456,7 +456,6 @@ def build_dataloaders(
456456
),
457457
is_vlm=args.is_vlm,
458458
)
459-
460459
if args.eval_data_path is not None or args.eval_hidden_states_path is not None:
461460
if args.eval_data_path is not None:
462461
eval_dataset = load_dataset("json", data_files=args.eval_data_path)["train"]
@@ -547,7 +546,7 @@ def run_forward(
547546
target_model: Optional[Eagle3TargetModel] = None,
548547
is_online: bool = True,
549548
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
550-
if args.is_vlm:
549+
if args.is_vlm and args.target_model_backend == "custom":
551550
plosses, _, acces = eagle3_model(
552551
input_ids=data["input_ids"].cuda(),
553552
attention_mask=data["attention_mask"].cuda(),
@@ -556,13 +555,32 @@ def run_forward(
556555
image_grid_thw=data["image_grid_thw"].cuda(),
557556
)
558557
else:
558+
image_grid_thw = None
559559
if is_online:
560560
# we generate the eagle3 using the target model in an online fashion
561-
eagle3_data = target_model.generate_eagle3_data(
562-
input_ids=data["input_ids"].cuda(),
563-
attention_mask=data["attention_mask"].cuda(),
564-
loss_mask=data["loss_mask"].cuda(),
565-
)
561+
# Handle VLM data: pixel_values and image_grid_thw are lists
562+
# pixel_values = [pv.cuda() for pv in data["pixel_values"]] if args.is_vlm else None
563+
if args.is_vlm:
564+
image_grid_thw = (
565+
[thw.cuda().squeeze() for thw in data["image_grid_thw"]]
566+
if args.is_vlm
567+
else None
568+
)
569+
pixel_values = data["pixel_values"].cuda()
570+
eagle3_data = target_model.generate_eagle3_data(
571+
input_ids=data["input_ids"].cuda(),
572+
attention_mask=data["attention_mask"].cuda(),
573+
loss_mask=data["loss_mask"].cuda(),
574+
is_vlm=args.is_vlm,
575+
pixel_values=pixel_values,
576+
image_grid_thw=image_grid_thw,
577+
)
578+
else:
579+
eagle3_data = target_model.generate_eagle3_data(
580+
input_ids=data["input_ids"].cuda(),
581+
attention_mask=data["attention_mask"].cuda(),
582+
loss_mask=data["loss_mask"].cuda(),
583+
)
566584

567585
input_ids = get_dp_data_shard_from_tp(eagle3_data.input_ids)
568586
attention_mask = get_dp_data_shard_from_tp(eagle3_data.attention_mask)
@@ -579,13 +597,14 @@ def run_forward(
579597
input_ids, target, loss_mask = target_model.preprocess(
580598
input_ids, target, loss_mask
581599
)
582-
583600
plosses, _, acces = eagle3_model(
584601
input_ids=input_ids,
585602
attention_mask=attention_mask,
586603
loss_mask=loss_mask,
587604
target=target,
588605
hidden_states=hidden_states,
606+
image_grid_thw=image_grid_thw,
607+
is_vlm=args.is_vlm,
589608
)
590609
return plosses, acces
591610

@@ -747,6 +766,8 @@ def main():
747766
if (
748767
args.is_vlm
749768
and getattr(draft_model_config, "target_model_type", None) == "qwen2_5_vl"
769+
and args.tp_size == 1
770+
and args.target_model_backend != "sglang"
750771
):
751772
eagle3_model = QwenVLOnlineEagle3Model(
752773
target_model=target_model,
@@ -757,6 +778,7 @@ def main():
757778
)
758779
else:
759780
eagle3_model = OnlineEagle3Model(
781+
target_model=target_model,
760782
draft_model=draft_model,
761783
length=args.ttt_length,
762784
attention_backend=args.attention_backend,
@@ -910,7 +932,6 @@ def main():
910932
tracker,
911933
mode="eval",
912934
)
913-
914935
# ================================================
915936
# 7.3 Save Checkpoints
916937
# ================================================
@@ -923,7 +944,6 @@ def main():
923944

924945
if args.max_num_steps is not None and global_step >= args.max_num_steps:
925946
break
926-
927947
# Save final checkpoint if training ended without saving
928948
if global_step % args.save_interval != 0:
929949
print_on_rank0(

specforge/core/eagle3.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(
5959
draft_model: Eagle3DraftModel,
6060
length: int = 7,
6161
attention_backend="sdpa",
62+
target_model: Optional[Eagle3Model] = None,
6263
):
6364
"""
6465
Args:
@@ -70,6 +71,7 @@ def __init__(
7071
self.draft_model = draft_model
7172
self.length = length
7273
self.attention_backend = attention_backend
74+
self.target_model = target_model
7375

7476
if self.attention_backend == "usp":
7577
self.extract_func = EXTRACT_FUNC_DICT["basic"]
@@ -98,6 +100,8 @@ def forward(
98100
hidden_states: torch.Tensor,
99101
past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
100102
position_ids: Optional[torch.Tensor] = None,
103+
image_grid_thw: Optional[torch.Tensor] = None,
104+
is_vlm: bool = False,
101105
**kwargs,
102106
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
103107
"""
@@ -132,14 +136,22 @@ def forward(
132136
past_key_values_length = past_key_values[0][0].shape[2]
133137
seq_length_with_past = seq_length_with_past + past_key_values_length
134138
if position_ids is None:
135-
device = hidden_states.device
136-
position_ids = torch.arange(
137-
past_key_values_length,
138-
seq_length + past_key_values_length,
139-
dtype=torch.long,
140-
device=device,
141-
)
142-
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
139+
if is_vlm:
140+
mrope_positions_ids, mrope_position_delta = (
141+
self.target_model.get_rope_index(
142+
input_ids=input_ids, image_grid_thw=image_grid_thw
143+
)
144+
)
145+
position_ids = mrope_positions_ids
146+
else:
147+
device = hidden_states.device
148+
position_ids = torch.arange(
149+
past_key_values_length,
150+
seq_length + past_key_values_length,
151+
dtype=torch.long,
152+
device=device,
153+
)
154+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
143155
else:
144156
position_ids = position_ids.view(-1, seq_length).long()
145157

0 commit comments

Comments
 (0)