3333from specforge .modeling .target .target_utils import TargetEmbeddingsAndHead
3434from specforge .optimizer import BF16Optimizer
3535from specforge .tracker import create_tracker
36- from specforge .utils import print_on_rank0 , print_with_rank
36+ from specforge .utils import get_last_checkpoint , print_on_rank0 , print_with_rank
3737
3838
3939def parse_args ():
@@ -108,6 +108,12 @@ def parse_args():
108108 training_group .add_argument ("--accumulation-steps" , type = int , default = 1 )
109109 training_group .add_argument ("--seed" , type = int , default = 42 )
110110 training_group .add_argument ("--resume" , action = "store_true" )
111+ training_group .add_argument (
112+ "--ckpt-dir" ,
113+ type = str ,
114+ default = None ,
115+ help = "Directory of the checkpoint to resume training from" ,
116+ )
111117
112118 output_group = parser .add_argument_group ("output" )
113119 output_group .add_argument ("--output-dir" , type = str , required = True )
@@ -162,25 +168,21 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]:
162168 draft_config = AutoConfig .from_pretrained (args .draft_config_path )
163169 print_on_rank0 (f"Loaded draft config from { args .draft_config_path } " )
164170 else :
165- # Load config from HF (needed for structure info even if backend is sglang)
166171 target_config = AutoConfig .from_pretrained (args .target_model_path )
167172 draft_config = AutoConfig .from_pretrained (args .target_model_path )
168173 draft_config .num_hidden_layers = args .num_draft_layers
169174 draft_config .block_size = args .block_size
170175 draft_config .num_target_layers = target_config .num_hidden_layers
171176 print_on_rank0 ("Auto-generated draft config from target model" )
172177
173- # Ensure dflash_config exists in config (for target_layer_ids / mask_token_id)
174178 if not hasattr (draft_config , "dflash_config" ) or draft_config .dflash_config is None :
175179 draft_config .dflash_config = {}
176180
177- # Set attention implementation based on backend
178181 draft_config ._attn_implementation = args .attention_backend
179182 print_on_rank0 (f"Using attention backend: { args .attention_backend } " )
180183
181184 draft_model = DFlashDraftModel (draft_config ).cuda ().to (torch .bfloat16 )
182185
183- # Set capture layers for target model based on draft model config
184186 target_model .set_capture_layers (draft_model .target_layer_ids )
185187
186188 print_on_rank0 (
@@ -199,7 +201,6 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]
199201 """Build train and eval dataloaders."""
200202 import hashlib
201203
202- # convert to dataloader
203204 cache_params_string = (
204205 f"{ args .train_data_path } -"
205206 f"{ args .max_length } -"
@@ -220,7 +221,6 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]
220221 num_proc = args .build_dataset_num_proc ,
221222 )
222223
223- # Filter out samples with too few loss tokens (DFlash requires >= 2 * block_size)
224224 min_loss_tokens = 2 * args .block_size
225225 original_size = len (train_eagle3_dataset )
226226 train_eagle3_dataset = train_eagle3_dataset .filter (
@@ -287,7 +287,6 @@ def save_checkpoint(args, epoch, step, dflash_model, draft_model, optimizer):
287287
288288 draft_model .save_pretrained (save_dir , state_dict = draft_state_dict )
289289
290- # Copy dflash.py for inference compatibility (matches auto_map in config)
291290 modeling_src = os .path .join (
292291 os .path .dirname (__file__ ),
293292 ".." ,
@@ -331,16 +330,13 @@ def record_metrics(
331330
332331
333332def main ():
334- # Configure logging to ensure we see INFO logs
333+
335334 logging .basicConfig (
336335 format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" ,
337336 datefmt = "%m/%d/%Y %H:%M:%S" ,
338337 level = logging .INFO ,
339338 )
340- # Force the root logger to INFO as well, just in case
341339 logging .getLogger ().setLevel (logging .INFO )
342-
343- # Filter annoying FSDP warnings
344340 warnings .filterwarnings (
345341 "ignore" ,
346342 "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed" ,
@@ -354,9 +350,45 @@ def main():
354350
355351 target_model , draft_model = build_models (args )
356352
353+ draft_model_last_checkpoint = None
354+ if args .ckpt_dir is not None :
355+ if os .path .isdir (args .ckpt_dir ):
356+ draft_model_last_checkpoint = args .ckpt_dir
357+ print_on_rank0 (f"Using checkpoint: { draft_model_last_checkpoint } " )
358+ else :
359+ raise ValueError (
360+ f"Provided ckpt dir { args .ckpt_dir } is not a valid directory."
361+ )
362+
363+ if args .resume and os .path .isdir (args .output_dir ):
364+ draft_model_last_checkpoint = get_last_checkpoint (
365+ args .output_dir , prefix = r"epoch_\d+_step"
366+ )
367+ print_on_rank0 (f"Last checkpoint detected: { draft_model_last_checkpoint } " )
368+
369+ resume_state = None
370+ if draft_model_last_checkpoint :
371+ loaded_model = DFlashDraftModel .from_pretrained (
372+ draft_model_last_checkpoint , torch_dtype = torch .bfloat16
373+ )
374+ draft_model .load_state_dict (loaded_model .state_dict ())
375+ del loaded_model
376+ print_on_rank0 ("Loaded draft model weights from checkpoint" )
377+
378+ training_state_path = os .path .join (
379+ draft_model_last_checkpoint , "training_state.pt"
380+ )
381+ if os .path .exists (training_state_path ):
382+ resume_state = torch .load (
383+ training_state_path , map_location = "cpu" , weights_only = False
384+ )
385+ print_on_rank0 (
386+ f"Will resume from epoch { resume_state ['epoch' ]} , "
387+ f"step { resume_state ['global_step' ]} "
388+ )
389+
357390 tokenizer = AutoTokenizer .from_pretrained (args .target_model_path )
358391
359- # Get mask_token_id
360392 if args .mask_token_id is not None :
361393 mask_token_id = args .mask_token_id
362394 elif tokenizer .mask_token_id is not None :
@@ -366,9 +398,6 @@ def main():
366398 mask_token_id = tokenizer .mask_token_id
367399 print_on_rank0 (f"Using mask_token_id: { mask_token_id } " )
368400
369- # Write mask_token_id and target_layer_ids into draft config so that
370- # save_pretrained produces a config.json compatible with the official
371- # dflash inference code (which reads from config.dflash_config).
372401 draft_model .mask_token_id = mask_token_id
373402 draft_model .config .dflash_config ["mask_token_id" ] = mask_token_id
374403 draft_model .config .dflash_config ["target_layer_ids" ] = draft_model .target_layer_ids
@@ -380,10 +409,7 @@ def main():
380409 total_steps = args .num_epochs * steps_per_epoch
381410 print_on_rank0 (f"Total training steps: { total_steps } " )
382411
383- # Note: We need embedding layer for DFlash wrapper.
384- # For SGLang backend, we can't easily get the embedding layer object.
385- # We use TargetEmbeddingsAndHead to efficiently load only needed weights.
386- print_on_rank0 ("Loading target embeddings and head efficiently..." )
412+ print_on_rank0 ("Loading target embeddings and head..." )
387413 target_components = TargetEmbeddingsAndHead .from_pretrained (
388414 args .target_model_path ,
389415 embed_key = "model.embed_tokens.weight" , # Adjust if Qwen/Llama differs
@@ -423,14 +449,25 @@ def main():
423449 total_steps = total_steps ,
424450 )
425451
452+ start_epoch = 0
453+ global_step = 0
454+ if resume_state is not None :
455+ optimizer .scheduler .load_state_dict (resume_state ["scheduler_state_dict" ])
456+ start_epoch = resume_state ["epoch" ]
457+ global_step = resume_state ["global_step" ]
458+ del resume_state
459+ print_on_rank0 (f"Restored scheduler, lr={ optimizer .get_learning_rate ():.6f} " )
460+
461+ skip_steps = global_step - start_epoch * len (train_dataloader )
462+
426463 print_on_rank0 (f"Initializing tracker (report_to={ args .report_to } )..." )
427464 tracker = create_tracker (args , args .output_dir )
428465 print_on_rank0 ("Tracker initialized successfully." )
429466
430- global_step = 0
431467 last_time = time .time ()
468+ print_on_rank0 (f"Starting training from epoch { start_epoch } , step { global_step } " )
432469
433- for epoch in range (args .num_epochs ):
470+ for epoch in range (start_epoch , args .num_epochs ):
434471 train_dataloader .sampler .set_epoch (epoch )
435472 draft_model .train ()
436473
@@ -441,21 +478,20 @@ def main():
441478 else :
442479 progress_bar = train_dataloader
443480
444- for data in progress_bar :
481+ for step_in_epoch , data in enumerate (progress_bar ):
482+ if epoch == start_epoch and step_in_epoch < skip_steps :
483+ continue
445484 global_step += 1
446485
447486 input_ids = data ["input_ids" ].cuda ()
448487 attention_mask = data ["attention_mask" ].cuda ()
449488 loss_mask = data ["loss_mask" ].cuda ()
450489
451- # Generate context from Target Model (SGLang or HF)
452- # This calls the backend to get hidden states
453490 target_output = target_model .generate_dflash_data (
454491 input_ids , attention_mask , loss_mask
455492 )
456493 hidden_states = target_output .hidden_states .cuda () # Ensure on GPU
457494
458- # Forward pass (Parallel Training)
459495 loss , accuracy = dflash_model (
460496 input_ids = input_ids ,
461497 attention_mask = attention_mask ,
0 commit comments