@@ -268,7 +268,7 @@ def build_target_model(
     if (
         args.is_vlm
         and draft_model_config.target_model_type == "qwen2_5_vl"
-        and args.tp_size == 1
+        and args.target_model_backend == "custom"
     ):
         from transformers import Qwen2_5_VLForConditionalGeneration
 
@@ -456,7 +456,6 @@ def build_dataloaders(
         ),
         is_vlm=args.is_vlm,
     )
-
     if args.eval_data_path is not None or args.eval_hidden_states_path is not None:
         if args.eval_data_path is not None:
             eval_dataset = load_dataset("json", data_files=args.eval_data_path)["train"]
@@ -547,7 +546,7 @@ def run_forward(
     target_model: Optional[Eagle3TargetModel] = None,
     is_online: bool = True,
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-    if args.is_vlm:
+    if args.is_vlm and args.target_model_backend == "custom":
         plosses, _, acces = eagle3_model(
             input_ids=data["input_ids"].cuda(),
             attention_mask=data["attention_mask"].cuda(),
@@ -556,13 +555,32 @@ def run_forward(
             image_grid_thw=data["image_grid_thw"].cuda(),
         )
     else:
+        image_grid_thw = None
         if is_online:
             # we generate the eagle3 using the target model in an online fashion
-            eagle3_data = target_model.generate_eagle3_data(
-                input_ids=data["input_ids"].cuda(),
-                attention_mask=data["attention_mask"].cuda(),
-                loss_mask=data["loss_mask"].cuda(),
-            )
+            # Handle VLM data: pixel_values and image_grid_thw are lists
+            # pixel_values = [pv.cuda() for pv in data["pixel_values"]] if args.is_vlm else None
+            if args.is_vlm:
+                image_grid_thw = (
+                    [thw.cuda().squeeze() for thw in data["image_grid_thw"]]
+                    if args.is_vlm
+                    else None
+                )
+                pixel_values = data["pixel_values"].cuda()
+                eagle3_data = target_model.generate_eagle3_data(
+                    input_ids=data["input_ids"].cuda(),
+                    attention_mask=data["attention_mask"].cuda(),
+                    loss_mask=data["loss_mask"].cuda(),
+                    is_vlm=args.is_vlm,
+                    pixel_values=pixel_values,
+                    image_grid_thw=image_grid_thw,
+                )
+            else:
+                eagle3_data = target_model.generate_eagle3_data(
+                    input_ids=data["input_ids"].cuda(),
+                    attention_mask=data["attention_mask"].cuda(),
+                    loss_mask=data["loss_mask"].cuda(),
+                )
 
             input_ids = get_dp_data_shard_from_tp(eagle3_data.input_ids)
             attention_mask = get_dp_data_shard_from_tp(eagle3_data.attention_mask)
@@ -579,13 +597,14 @@ def run_forward(
         input_ids, target, loss_mask = target_model.preprocess(
             input_ids, target, loss_mask
         )
-
         plosses, _, acces = eagle3_model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             loss_mask=loss_mask,
             target=target,
             hidden_states=hidden_states,
+            image_grid_thw=image_grid_thw,
+            is_vlm=args.is_vlm,
         )
     return plosses, acces
 
@@ -747,6 +766,8 @@ def main():
     if (
         args.is_vlm
         and getattr(draft_model_config, "target_model_type", None) == "qwen2_5_vl"
+        and args.tp_size == 1
+        and args.target_model_backend != "sglang"
     ):
         eagle3_model = QwenVLOnlineEagle3Model(
             target_model=target_model,
@@ -757,6 +778,7 @@ def main():
         )
     else:
         eagle3_model = OnlineEagle3Model(
+            target_model=target_model,
             draft_model=draft_model,
             length=args.ttt_length,
             attention_backend=args.attention_backend,
@@ -910,7 +932,6 @@ def main():
                     tracker,
                     mode="eval",
                 )
-
                 # ================================================
                 # 7.3 Save Checkpoints
                 # ================================================
@@ -923,7 +944,6 @@ def main():
 
             if args.max_num_steps is not None and global_step >= args.max_num_steps:
                 break
-
     # Save final checkpoint if training ended without saving
     if global_step % args.save_interval != 0:
         print_on_rank0(