@@ -91,6 +91,73 @@ def update_progress(self, progress: int) -> None:
9191 # Check for wandb URL on every progress update
9292 self ._check_and_capture_wandb_url ()
9393
94+ # ------------- checkpoint resume support -------------
95+ def get_checkpoint_to_resume (self ) -> Optional [str ]:
96+ """
97+ Get the checkpoint path to resume training from.
98+
99+ This method checks for checkpoint resume information stored in the job data
100+ when resuming training from a checkpoint.
101+
102+ Returns:
103+ Optional[str]: The full path to the checkpoint to resume from, or None if no
104+ checkpoint resume is requested.
105+ """
106+ if not self ._job :
107+ return None
108+
109+ job_data = self ._job .get_job_data ()
110+ if not job_data :
111+ return None
112+
113+ parent_job_id = job_data .get ('parent_job_id' )
114+ checkpoint_name = job_data .get ('resumed_from_checkpoint' )
115+
116+ if not parent_job_id or not checkpoint_name :
117+ return None
118+
119+ # Build the checkpoint path from parent job's checkpoints directory
120+ checkpoint_path = self .get_parent_job_checkpoint_path (parent_job_id , checkpoint_name )
121+
122+ # Verify the checkpoint exists
123+ if checkpoint_path and os .path .exists (checkpoint_path ):
124+ return checkpoint_path
125+
126+ return None
127+
128+ def get_parent_job_checkpoint_path (self , parent_job_id : str , checkpoint_name : str ) -> Optional [str ]:
129+ """
130+ Get the full path to a checkpoint from a parent job.
131+
132+ This is a helper method that constructs the path to a specific checkpoint
133+ from a parent job's checkpoints directory.
134+
135+ Args:
136+ parent_job_id (str): The ID of the parent job that created the checkpoint
137+ checkpoint_name (str): The name of the checkpoint file or directory
138+
139+ Returns:
140+ Optional[str]: The full path to the checkpoint, or None if it doesn't exist
141+ """
142+ try :
143+ checkpoints_dir = dirs .get_job_checkpoints_dir (parent_job_id )
144+ checkpoint_path = os .path .join (checkpoints_dir , checkpoint_name )
145+
146+ # Security check: ensure the checkpoint path is within the checkpoints directory
147+ checkpoint_path_normalized = os .path .normpath (checkpoint_path )
148+ checkpoints_dir_normalized = os .path .normpath (checkpoints_dir )
149+
150+ if not checkpoint_path_normalized .startswith (checkpoints_dir_normalized + os .sep ):
151+ return None
152+
153+ if os .path .exists (checkpoint_path_normalized ):
154+ return checkpoint_path_normalized
155+
156+ return None
157+ except Exception as e :
158+ print (f"Error getting parent job checkpoint path: { str (e )} " )
159+ return None
160+
94161 # ------------- completion -------------
95162 def finish (
96163 self ,
@@ -506,8 +573,8 @@ def save_dataset(self, df, dataset_id: str, additional_metadata: Optional[Dict[s
506573 try :
507574 if hasattr (df , "to_pandas" ) and callable (getattr (df , "to_pandas" )):
508575 df = df .to_pandas ()
509- except Exception :
510- pass
576+ except Exception as e :
577+ print ( f"Warning: Failed to convert dataset to pandas DataFrame: { str ( e ) } " )
511578
512579 # Prepare dataset directory
513580 dataset_id_safe = dataset_id .strip ()
@@ -562,16 +629,17 @@ def save_dataset(self, df, dataset_id: str, additional_metadata: Optional[Dict[s
562629 )
563630 except Exception as e :
564631 # Do not fail the save if metadata write fails; log to job data
632+ print (f"Warning: Failed to create dataset metadata: { str (e )} " )
565633 try :
566634 self ._job .update_job_data_field ("dataset_metadata_error" , str (e )) # type: ignore[union-attr]
567- except Exception :
568- pass
635+ except Exception as e2 :
636+ print ( f"Warning: Failed to log dataset metadata error: { str ( e2 ) } " )
569637
570638 # Track dataset on the job for provenance
571639 try :
572640 self ._job .update_job_data_field ("dataset_id" , dataset_id_safe ) # type: ignore[union-attr]
573- except Exception :
574- pass
641+ except Exception as e :
642+ print ( f"Warning: Failed to track dataset in job_data: { str ( e ) } " )
575643
576644 self .log (f"Dataset saved to '{ output_path } ' and registered as generated dataset '{ dataset_id_safe } '" )
577645 return output_path
@@ -615,8 +683,8 @@ def save_checkpoint(self, source_path: str, name: Optional[str] = None) -> str:
615683 ckpt_list .append (dest )
616684 self ._job .update_job_data_field ("checkpoints" , ckpt_list )
617685 self ._job .update_job_data_field ("latest_checkpoint" , dest )
618- except Exception :
619- pass
686+ except Exception as e :
687+ print ( f"Warning: Failed to track checkpoint in job_data: { str ( e ) } " )
620688
621689 return dest
622690
0 commit comments