@@ -268,7 +268,7 @@ def build_target_model(
268268 if (
269269 args .is_vlm
270270 and draft_model_config .target_model_type == "qwen2_5_vl"
271- and args .target_model_backend == "custom"
271+ and args .tp_size == 1
272272 ):
273273 from transformers import Qwen2_5_VLForConditionalGeneration
274274
@@ -456,6 +456,7 @@ def build_dataloaders(
456456 ),
457457 is_vlm = args .is_vlm ,
458458 )
459+
459460 if args .eval_data_path is not None or args .eval_hidden_states_path is not None :
460461 if args .eval_data_path is not None :
461462 eval_dataset = load_dataset ("json" , data_files = args .eval_data_path )["train" ]
@@ -546,7 +547,7 @@ def run_forward(
546547 target_model : Optional [Eagle3TargetModel ] = None ,
547548 is_online : bool = True ,
548549) -> Tuple [List [torch .Tensor ], List [torch .Tensor ]]:
549- if args .is_vlm and args . target_model_backend == "custom" :
550+ if args .is_vlm :
550551 plosses , _ , acces = eagle3_model (
551552 input_ids = data ["input_ids" ].cuda (),
552553 attention_mask = data ["attention_mask" ].cuda (),
@@ -557,20 +558,10 @@ def run_forward(
557558 else :
558559 if is_online :
559560 # we generate the eagle3 using the target model in an online fashion
560- # Handle VLM data: pixel_values and image_grid_thw are lists
561- # pixel_values = [pv.cuda() for pv in data["pixel_values"]] if args.is_vlm else None
562- image_grid_thw = (
563- [thw .cuda ().squeeze () for thw in data ["image_grid_thw" ]]
564- if args .is_vlm
565- else None
566- )
567561 eagle3_data = target_model .generate_eagle3_data (
568562 input_ids = data ["input_ids" ].cuda (),
569563 attention_mask = data ["attention_mask" ].cuda (),
570564 loss_mask = data ["loss_mask" ].cuda (),
571- is_vlm = args .is_vlm ,
572- pixel_values = data ["pixel_values" ].cuda (),
573- image_grid_thw = image_grid_thw ,
574565 )
575566
576567 input_ids = get_dp_data_shard_from_tp (eagle3_data .input_ids )
@@ -588,14 +579,13 @@ def run_forward(
588579 input_ids , target , loss_mask = target_model .preprocess (
589580 input_ids , target , loss_mask
590581 )
582+
591583 plosses , _ , acces = eagle3_model (
592584 input_ids = input_ids ,
593585 attention_mask = attention_mask ,
594586 loss_mask = loss_mask ,
595587 target = target ,
596588 hidden_states = hidden_states ,
597- image_grid_thw = image_grid_thw ,
598- is_vlm = args .is_vlm ,
599589 )
600590 return plosses , acces
601591
@@ -757,8 +747,6 @@ def main():
757747 if (
758748 args .is_vlm
759749 and getattr (draft_model_config , "target_model_type" , None ) == "qwen2_5_vl"
760- and args .tp_size == 1
761- and args .target_model_backend != "sglang"
762750 ):
763751 eagle3_model = QwenVLOnlineEagle3Model (
764752 target_model = target_model ,
@@ -769,7 +757,6 @@ def main():
769757 )
770758 else :
771759 eagle3_model = OnlineEagle3Model (
772- target_model = target_model ,
773760 draft_model = draft_model ,
774761 length = args .ttt_length ,
775762 attention_backend = args .attention_backend ,
@@ -923,6 +910,7 @@ def main():
923910 tracker ,
924911 mode = "eval" ,
925912 )
913+
926914 # ================================================
927915 # 7.3 Save Checkpoints
928916 # ================================================
@@ -935,6 +923,7 @@ def main():
935923
936924 if args .max_num_steps is not None and global_step >= args .max_num_steps :
937925 break
926+
938927 # Save final checkpoint if training ended without saving
939928 if global_step % args .save_interval != 0 :
940929 print_on_rank0 (
0 commit comments