Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions configs/qwen2.5-vl-32b-eagle3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"architectures": [
"LlamaForCausalLMEagle3"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 5120,
"initializer_range": 0.02,
"intermediate_size": 18944,
"max_position_embeddings": 8192,
"max_window_layers": 28,
"model_type": "llama",
"target_model_type": "qwen2_5_vl",
"num_attention_heads": 28,
"num_hidden_layers": 1,
"num_key_value_heads": 4,
"rms_norm_eps": 1e-06,
"pretraining_tp": 1,
"rope_scaling": {
"type": "mrope",
"mrope_section": [
16,
24,
24
]
},
"rope_theta": 1000000,
"sliding_window": 32768,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.51.0",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 152064,
"draft_vocab_size": 32000
}
32 changes: 32 additions & 0 deletions examples/run_qwen2.5_32b_vl_eagle3_online.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/bin/bash

# Launch online EAGLE3 draft-model training for Qwen/Qwen2.5-VL-32B-Instruct.
# The target model is served through the sglang backend with tensor
# parallelism 4 (--tp-size 4); the first positional argument selects how
# many GPUs torchrun spawns for the training processes (default: 1).
#
# NOTE(review): the original header comment claimed "tp1 ... qwen2.5-vl-7b",
# which contradicted the 32B model path and --tp-size 4 below; corrected.

# Resolve the repository root relative to this script's location so the
# script works regardless of the caller's CWD. Quote expansions to survive
# paths containing spaces.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname "$SCRIPT_DIR")

NUM_GPUS=${1:-1}
# Parallel workers for dataset preprocessing; override via env var.
BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64}

torchrun \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    "$ROOT_DIR/scripts/train_eagle3.py" \
    --target-model-path Qwen/Qwen2.5-VL-32B-Instruct \
    --draft-model-config "$ROOT_DIR/configs/qwen2.5-vl-32b-eagle3.json" \
    --train-data-path "$ROOT_DIR/cache/allava4v_train.jsonl" \
    --build-dataset-num-proc "$BUILD_DATASET_NUM_PROC" \
    --output-dir "$ROOT_DIR/outputs/qwen2.5-vl-32b-eagle3" \
    --num-epochs 10 \
    --batch-size 1 \
    --learning-rate 1e-4 \
    --max-length 4096 \
    --dist-timeout 360 \
    --chat-template qwen2-vl \
    --target-model-backend sglang \
    --cache-dir "$ROOT_DIR/cache" \
    --embedding-key model.embed_tokens.weight \
    --tp-size 4 \
    --sglang-mem-fraction-static 0.5 \
    --is-vlm \
    --min-pixels 200704 \
    --max-pixels 1003520
61 changes: 44 additions & 17 deletions scripts/train_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def build_target_model(
if (
args.is_vlm
and draft_model_config.target_model_type == "qwen2_5_vl"
and args.tp_size == 1
and args.target_model_backend == "custom"
):
from transformers import Qwen2_5_VLForConditionalGeneration

Expand Down Expand Up @@ -456,7 +456,6 @@ def build_dataloaders(
),
is_vlm=args.is_vlm,
)

if args.eval_data_path is not None or args.eval_hidden_states_path is not None:
if args.eval_data_path is not None:
eval_dataset = load_dataset("json", data_files=args.eval_data_path)["train"]
Expand Down Expand Up @@ -547,7 +546,7 @@ def run_forward(
target_model: Optional[Eagle3TargetModel] = None,
is_online: bool = True,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
if args.is_vlm:
if args.is_vlm and args.target_model_backend == "custom":
plosses, _, acces = eagle3_model(
input_ids=data["input_ids"].cuda(),
attention_mask=data["attention_mask"].cuda(),
Expand All @@ -556,13 +555,32 @@ def run_forward(
image_grid_thw=data["image_grid_thw"].cuda(),
)
else:
image_grid_thw = None
if is_online:
# we generate the eagle3 using the target model in an online fashion
eagle3_data = target_model.generate_eagle3_data(
input_ids=data["input_ids"].cuda(),
attention_mask=data["attention_mask"].cuda(),
loss_mask=data["loss_mask"].cuda(),
)
# Handle VLM data: pixel_values and image_grid_thw are lists
# pixel_values = [pv.cuda() for pv in data["pixel_values"]] if args.is_vlm else None
if args.is_vlm:
image_grid_thw = (
[thw.cuda().squeeze() for thw in data["image_grid_thw"]]
if args.is_vlm
else None
)
pixel_values = data["pixel_values"].cuda()
eagle3_data = target_model.generate_eagle3_data(
input_ids=data["input_ids"].cuda(),
attention_mask=data["attention_mask"].cuda(),
loss_mask=data["loss_mask"].cuda(),
is_vlm=args.is_vlm,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
)
else:
eagle3_data = target_model.generate_eagle3_data(
input_ids=data["input_ids"].cuda(),
attention_mask=data["attention_mask"].cuda(),
loss_mask=data["loss_mask"].cuda(),
)

input_ids = get_dp_data_shard_from_tp(eagle3_data.input_ids)
attention_mask = get_dp_data_shard_from_tp(eagle3_data.attention_mask)
Expand All @@ -579,13 +597,14 @@ def run_forward(
input_ids, target, loss_mask = target_model.preprocess(
input_ids, target, loss_mask
)

plosses, _, acces = eagle3_model(
input_ids=input_ids,
attention_mask=attention_mask,
loss_mask=loss_mask,
target=target,
hidden_states=hidden_states,
image_grid_thw=image_grid_thw,
is_vlm=args.is_vlm,
)
return plosses, acces

Expand Down Expand Up @@ -747,6 +766,8 @@ def main():
if (
args.is_vlm
and getattr(draft_model_config, "target_model_type", None) == "qwen2_5_vl"
and args.tp_size == 1
and args.target_model_backend != "sglang"
):
eagle3_model = QwenVLOnlineEagle3Model(
target_model=target_model,
Expand All @@ -756,12 +777,20 @@ def main():
attention_backend=args.attention_backend,
)
else:
eagle3_model = OnlineEagle3Model(
draft_model=draft_model,
length=args.ttt_length,
attention_backend=args.attention_backend,
)

if is_online:
eagle3_model = OnlineEagle3Model(
target_model=target_model,
draft_model=draft_model,
length=args.ttt_length,
attention_backend=args.attention_backend,
)
else:
# offline: the target_model is TargetHead not a model
eagle3_model = OnlineEagle3Model(
draft_model=draft_model,
length=args.ttt_length,
attention_backend=args.attention_backend,
)
eagle3_model = FSDP(
eagle3_model,
use_orig_params=True,
Expand Down Expand Up @@ -910,7 +939,6 @@ def main():
tracker,
mode="eval",
)

# ================================================
# 7.3 Save Checkpoints
# ================================================
Expand All @@ -923,7 +951,6 @@ def main():

if args.max_num_steps is not None and global_step >= args.max_num_steps:
break

# Save final checkpoint if training ended without saving
if global_step % args.save_interval != 0:
print_on_rank0(
Expand Down
28 changes: 20 additions & 8 deletions specforge/core/eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
draft_model: Eagle3DraftModel,
length: int = 7,
attention_backend="sdpa",
target_model: Optional[Eagle3Model] = None,
):
"""
Args:
Expand All @@ -70,6 +71,7 @@ def __init__(
self.draft_model = draft_model
self.length = length
self.attention_backend = attention_backend
self.target_model = target_model

if self.attention_backend == "usp":
self.extract_func = EXTRACT_FUNC_DICT["basic"]
Expand Down Expand Up @@ -98,6 +100,8 @@ def forward(
hidden_states: torch.Tensor,
past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
position_ids: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.Tensor] = None,
is_vlm: bool = False,
**kwargs,
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
"""
Expand Down Expand Up @@ -132,14 +136,22 @@ def forward(
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = hidden_states.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
if is_vlm:
mrope_positions_ids, mrope_position_delta = (
self.target_model.get_rope_index(
input_ids=input_ids, image_grid_thw=image_grid_thw
)
)
position_ids = mrope_positions_ids
else:
device = hidden_states.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()

Expand Down
Loading