
Commit d7ed845

Merge branch 'main' into fix_requirements
2 parents 963ade9 + ee29561


8 files changed: +26 -705 lines changed


configs/qwen2.5-vl-32b-eagle3.json

Lines changed: 0 additions & 40 deletions
This file was deleted.

examples/run_qwen2.5_32b_vl_eagle3_online.sh

Lines changed: 0 additions & 33 deletions
This file was deleted.

scripts/train_eagle3.py

Lines changed: 6 additions & 17 deletions
@@ -268,7 +268,7 @@ def build_target_model(
     if (
         args.is_vlm
         and draft_model_config.target_model_type == "qwen2_5_vl"
-        and args.target_model_backend == "custom"
+        and args.tp_size == 1
     ):
         from transformers import Qwen2_5_VLForConditionalGeneration

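Note on this hunk: the Hugging Face Qwen2_5_VLForConditionalGeneration path is now gated on tensor-parallel size rather than on the target-model backend. A standalone toy check of the revised condition (not repository code; SimpleNamespace stands in for the parsed CLI args and the draft-model config):

from types import SimpleNamespace

# Stand-ins for the parsed CLI args and the draft-model config used by the real script.
args = SimpleNamespace(is_vlm=True, tp_size=1, target_model_backend="sglang")
draft_model_config = SimpleNamespace(target_model_type="qwen2_5_vl")

use_hf_qwen_vl = (
    args.is_vlm
    and draft_model_config.target_model_type == "qwen2_5_vl"
    and args.tp_size == 1
)
print(use_hf_qwen_vl)  # True: the backend no longer has to be "custom"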
@@ -456,6 +456,7 @@ def build_dataloaders(
         ),
         is_vlm=args.is_vlm,
     )
+
     if args.eval_data_path is not None or args.eval_hidden_states_path is not None:
         if args.eval_data_path is not None:
             eval_dataset = load_dataset("json", data_files=args.eval_data_path)["train"]

@@ -546,7 +547,7 @@ def run_forward(
     target_model: Optional[Eagle3TargetModel] = None,
     is_online: bool = True,
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-    if args.is_vlm and args.target_model_backend == "custom":
+    if args.is_vlm:
         plosses, _, acces = eagle3_model(
             input_ids=data["input_ids"].cuda(),
             attention_mask=data["attention_mask"].cuda(),

@@ -557,20 +558,10 @@ def run_forward(
     else:
         if is_online:
             # we generate the eagle3 using the target model in an online fashion
-            # Handle VLM data: pixel_values and image_grid_thw are lists
-            # pixel_values = [pv.cuda() for pv in data["pixel_values"]] if args.is_vlm else None
-            image_grid_thw = (
-                [thw.cuda().squeeze() for thw in data["image_grid_thw"]]
-                if args.is_vlm
-                else None
-            )
             eagle3_data = target_model.generate_eagle3_data(
                 input_ids=data["input_ids"].cuda(),
                 attention_mask=data["attention_mask"].cuda(),
                 loss_mask=data["loss_mask"].cuda(),
-                is_vlm=args.is_vlm,
-                pixel_values=data["pixel_values"].cuda(),
-                image_grid_thw=image_grid_thw,
             )

             input_ids = get_dp_data_shard_from_tp(eagle3_data.input_ids)
@@ -588,14 +579,13 @@ def run_forward(
         input_ids, target, loss_mask = target_model.preprocess(
             input_ids, target, loss_mask
         )
+
         plosses, _, acces = eagle3_model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             loss_mask=loss_mask,
             target=target,
             hidden_states=hidden_states,
-            image_grid_thw=image_grid_thw,
-            is_vlm=args.is_vlm,
         )
     return plosses, acces

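For context on the deletions above: the old online path squeezed each per-image image_grid_thw descriptor before passing it, together with pixel_values, to generate_eagle3_data. After this commit the text-only online path no longer touches either field; VLM batches take the dedicated branch at the top of run_forward instead. A CPU-only toy of what the removed preprocessing did (made-up tensor values; the original also called .cuda()):

import torch

# Two made-up (1, 3) grid descriptors (temporal, height, width patch counts), one per image.
data = {"image_grid_thw": [torch.tensor([[1, 34, 52]]), torch.tensor([[1, 16, 28]])]}

image_grid_thw = [thw.squeeze() for thw in data["image_grid_thw"]]
print([t.shape for t in image_grid_thw])  # [torch.Size([3]), torch.Size([3])]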
@@ -757,8 +747,6 @@ def main():
     if (
         args.is_vlm
         and getattr(draft_model_config, "target_model_type", None) == "qwen2_5_vl"
-        and args.tp_size == 1
-        and args.target_model_backend != "sglang"
     ):
         eagle3_model = QwenVLOnlineEagle3Model(
             target_model=target_model,

@@ -769,7 +757,6 @@ def main():
         )
     else:
         eagle3_model = OnlineEagle3Model(
-            target_model=target_model,
             draft_model=draft_model,
             length=args.ttt_length,
             attention_backend=args.attention_backend,
@@ -923,6 +910,7 @@ def main():
                 tracker,
                 mode="eval",
             )
+
         # ================================================
         # 7.3 Save Checkpoints
         # ================================================

@@ -935,6 +923,7 @@ def main():

         if args.max_num_steps is not None and global_step >= args.max_num_steps:
             break
+
     # Save final checkpoint if training ended without saving
     if global_step % args.save_interval != 0:
         print_on_rank0(
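The guard shown in the last hunk triggers a final save only when training stopped between periodic saves. A quick arithmetic check with arbitrary values (not repository code):

save_interval = 500  # arbitrary value for the demo

for global_step in (1000, 1234):
    needs_final_save = global_step % save_interval != 0
    print(global_step, needs_final_save)  # 1000 -> False, 1234 -> True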

specforge/core/eagle3.py

Lines changed: 8 additions & 20 deletions
@@ -59,7 +59,6 @@ def __init__(
         draft_model: Eagle3DraftModel,
         length: int = 7,
         attention_backend="sdpa",
-        target_model: Optional[Eagle3Model] = None,
     ):
         """
         Args:

@@ -71,7 +70,6 @@ def __init__(
         self.draft_model = draft_model
         self.length = length
         self.attention_backend = attention_backend
-        self.target_model = target_model

         if self.attention_backend == "usp":
             self.extract_func = EXTRACT_FUNC_DICT["basic"]

@@ -100,8 +98,6 @@ def forward(
         hidden_states: torch.Tensor,
         past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         position_ids: Optional[torch.Tensor] = None,
-        image_grid_thw: Optional[torch.Tensor] = None,
-        is_vlm: bool = False,
         **kwargs,
     ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
         """

@@ -136,22 +132,14 @@ def forward(
             past_key_values_length = past_key_values[0][0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
         if position_ids is None:
-            if is_vlm:
-                mrope_positions_ids, mrope_position_delta = (
-                    self.target_model.get_rope_index(
-                        input_ids=input_ids, image_grid_thw=image_grid_thw
-                    )
-                )
-                position_ids = mrope_positions_ids
-            else:
-                device = hidden_states.device
-                position_ids = torch.arange(
-                    past_key_values_length,
-                    seq_length + past_key_values_length,
-                    dtype=torch.long,
-                    device=device,
-                )
-                position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+            device = hidden_states.device
+            position_ids = torch.arange(
+                past_key_values_length,
+                seq_length + past_key_values_length,
+                dtype=torch.long,
+                device=device,
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
         else:
             position_ids = position_ids.view(-1, seq_length).long()

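With image_grid_thw and is_vlm removed from forward, OnlineEagle3Model always builds plain 1-D position ids; the multimodal rotary index previously obtained via get_rope_index is presumably computed inside the Qwen VL wrapper instead. A runnable illustration of the default path that remains (dummy shapes, no KV cache):

import torch

seq_length = 6
past_key_values_length = 0                       # no cached keys/values in this toy example
hidden_states = torch.zeros(1, seq_length, 32)   # dummy (batch, seq, hidden) tensor

device = hidden_states.device
position_ids = torch.arange(
    past_key_values_length,
    seq_length + past_key_values_length,
    dtype=torch.long,
    device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
print(position_ids)  # tensor([[0, 1, 2, 3, 4, 5]])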