 from transformerlab.services.job_service import job_update_status
 from transformerlab.routers.auth.api_key_auth import get_user_or_api_key
 from transformerlab.services.auth import AuthenticatedIdentity, auth_service
+from lab.dirs import get_workspace_dir, get_job_checkpoints_dir


 router = APIRouter(prefix="/remote", tags=["remote"])
@@ -115,7 +116,7 @@ async def launch_remote(
     experimentId: str,
     identity: AuthenticatedIdentity = Depends(get_user_or_api_key),
     job_id: Optional[str] = Form(None),
-    cluster_name: str = Form(...),
+    cluster_name: Optional[str] = Form(None),
     command: str = Form("echo 'Hello World'"),
     task_name: Optional[str] = Form(None),
     cpus: Optional[str] = Form(None),
@@ -125,16 +126,91 @@ async def launch_remote(
     num_nodes: Optional[int] = Form(None),
     setup: Optional[str] = Form(None),
     uploaded_dir_path: Optional[str] = Form(None),
+    checkpoint: Optional[str] = Form(None),
+    parent_job_id: Optional[str] = Form(None),
 ):
     """
     Launch a remote instance via Lattice orchestrator. If job_id is provided, use existing job, otherwise create new one.
+    If checkpoint and parent_job_id are provided, resume training from the specified checkpoint.
     """
-    # If job_id is provided, use existing job, otherwise create a new one
     formatted_cluster_name = cluster_name
-    if job_id:
-        # Trust the frontend-provided cluster_name when re-launching an existing job.
-        pass
-    else:
+    # Handle resume from checkpoint logic
+    if checkpoint and parent_job_id:
+        # Get the parent job
+        parent_job = job_service.job_get(parent_job_id)
+        if not parent_job:
+            return {"status": "error", "message": f"Parent job {parent_job_id} not found"}
+
+        # Get the parent job data
+        parent_job_data = parent_job.get("job_data", {})
+
+        # Validate checkpoint existence
+        checkpoints_dir = get_job_checkpoints_dir(parent_job_id)
+        checkpoint_path = os.path.normpath(os.path.join(checkpoints_dir, checkpoint))
+
+        # Validate that the checkpoint path is within the checkpoints directory
+        if not checkpoint_path.startswith(os.path.abspath(checkpoints_dir) + os.sep):
+            return {"status": "error", "message": "Invalid checkpoint name (potential directory traversal detected)"}
+
+        if not os.path.exists(checkpoint_path):
+            return {"status": "error", "message": f"Checkpoint {checkpoint} not found at {checkpoint_path}"}
+
+        # Get the original command
+        command = parent_job_data.get("command", "")
+        if not command:
+            return {"status": "error", "message": "Original command not found in parent job data"}
+
+        # Create a simple, meaningful task name for the resumed training
+        task_name = f"resume_training_{parent_job_id}"
+
+        # Reuse all parameters from the parent job for the resume (no additional user input is required)
+        cluster_name = parent_job_data.get("cluster_name")
+        cpus = parent_job_data.get("cpus")
+        memory = parent_job_data.get("memory")
+        disk_space = parent_job_data.get("disk_space")
+        accelerators = parent_job_data.get("accelerators")
+        num_nodes = parent_job_data.get("num_nodes")
+        setup = parent_job_data.get("setup")
+        uploaded_dir_path = parent_job_data.get("uploaded_dir_path")
+
+        # Force creation of a new job for the resume (do not reuse an existing job_id)
+        job_id = None
+
+    # Validate required fields
+    if not cluster_name:
+        return {"status": "error", "message": "cluster_name is required"}
+
+    # Build a unified data structure with all parameters
+    data = {
+        "cluster_name": cluster_name,
+        "command": command,
+        "task_name": task_name,
+    }
+
+    # Add resume metadata if resuming from checkpoint
+    if checkpoint and parent_job_id:
+        data["resumed_from_checkpoint"] = checkpoint
+        data["checkpoint_path"] = checkpoint_path
+        data["parent_job_id"] = parent_job_id
+
+    # Add optional parameters if provided
+    if cpus:
+        data["cpus"] = cpus
+    if memory:
+        data["memory"] = memory
+    if disk_space:
+        data["disk_space"] = disk_space
+    if accelerators:
+        data["accelerators"] = accelerators
+    if num_nodes:
+        data["num_nodes"] = num_nodes
+    if setup:
+        data["setup"] = setup
+    if uploaded_dir_path:
+        data["uploaded_dir_path"] = uploaded_dir_path
+
+    # If job_id is provided, use existing job, otherwise create a new one
+    if not job_id:
        # Get user information from the authentication identity
        user_info_payload = auth_service.get_user_info(identity)

@@ -147,22 +223,18 @@ async def launch_remote(
        ]).strip()
        if user_info_payload.get("email"):
            user_info["email"] = user_info_payload["email"]
-
-        # Create a new REMOTE job
-        job_data = {"task_name": task_name, "command": command, "cluster_name": cluster_name}
-
+
        # Add user_info to job_data if we have any user information
        if user_info:
-            job_data["user_info"] = user_info
-
+            data["user_info"] = user_info
        try:
            job_id = job_service.job_create(
                type="REMOTE",
                status="LAUNCHING",
                experiment_id=experimentId,
            )
-            # Update the job data to add fields from job_data (this ensures default fields stay in the job)
-            for key, value in job_data.items():
+            # Store all data in the job (this ensures default fields stay in the job)
+            for key, value in data.items():
                job_service.job_update_job_data_insert_key_value(job_id, key, value, experimentId)

            # Format cluster_name as <user_value>-job-<job_id> and persist it
@@ -172,6 +244,9 @@ async def launch_remote(
            print(f"Failed to create job: {str(e)}")
            return {"status": "error", "message": "Failed to create job"}

+    # Use task_name as job_name if provided, otherwise fall back to cluster_name
+    job_name = task_name if task_name else formatted_cluster_name
+    data["job_name"] = job_name
    # Validate environment variables
    result = validate_gpu_orchestrator_env_vars()
    gpu_orchestrator_url, gpu_orchestrator_port = result
@@ -180,32 +255,14 @@ async def launch_remote(
    elif isinstance(gpu_orchestrator_port, dict):
        return gpu_orchestrator_port  # Error response

-    # Prepare the request data for Lattice orchestrator
-    request_data = {
-        "cluster_name": formatted_cluster_name,
-        "command": command,
-        "tlab_job_id": job_id,  # Pass the job_id to the orchestrator
-    }
+    # Prepare request data for orchestrator by copying the data and adding orchestrator-specific fields
+    request_data = data.copy()
+    request_data["tlab_job_id"] = job_id
+    request_data["cluster_name"] = formatted_cluster_name

    # Use task_name as job_name if provided, otherwise fall back to cluster_name
-    job_name = task_name if task_name else formatted_cluster_name
-    request_data["job_name"] = job_name
+    request_data["job_name"] = task_name if task_name else cluster_name

-    # Add optional parameters if provided
-    if cpus:
-        request_data["cpus"] = cpus
-    if memory:
-        request_data["memory"] = memory
-    if disk_space:
-        request_data["disk_space"] = disk_space
-    if accelerators:
-        request_data["accelerators"] = accelerators
-    if num_nodes:
-        request_data["num_nodes"] = num_nodes
-    if setup:
-        request_data["setup"] = setup
-    if uploaded_dir_path:
-        request_data["uploaded_dir_path"] = uploaded_dir_path

    gpu_orchestrator_url = f"{gpu_orchestrator_url}:{gpu_orchestrator_port}/api/v1/instances/launch"

@@ -243,7 +300,7 @@ async def launch_remote(
243300 "status" : "success" ,
244301 "data" : response_data ,
245302 "job_id" : job_id ,
246- "message" : "Remote instance launched successfully" ,
303+ "message" : f"Training resumed from checkpoint { checkpoint } " if checkpoint else "Remote instance launched successfully" ,
247304 }
248305 else :
249306 return {
@@ -333,7 +390,6 @@ async def upload_directory(
    Upload a directory to the remote Lattice orchestrator for later use in cluster launches.
    Files are stored locally first, then sent to orchestrator.
    """
-    from lab.dirs import get_workspace_dir

    # Validate environment variables
    result = validate_gpu_orchestrator_env_vars()
@@ -628,4 +684,4 @@ async def check_remote_jobs_status(request: Request):

    except Exception as e:
        print(f"Error checking remote job status: {str(e)}")
-        return {"status": "error", "message": "Error checking remote job status"}
+        return {"status": "error", "message": "Error checking remote job status"}
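
Note: for context, here is a minimal sketch of how a client might exercise the new resume path. It assumes the handler above is exposed as POST /remote/launch, that experimentId is passed as a query parameter, and that the API key goes in an Authorization header; none of these details are visible in this diff, so treat them as placeholders.

# Hypothetical client call (Python): resume training from a saved checkpoint.
# The endpoint copies cluster_name, resources, setup, and the original command
# from the parent job, so only the resume fields need to be supplied.
import requests

BASE_URL = "http://localhost:8338"   # assumed server address
API_KEY = "<your-api-key>"           # assumed auth scheme

response = requests.post(
    f"{BASE_URL}/remote/launch",     # assumed route for launch_remote
    params={"experimentId": "1"},
    headers={"Authorization": f"Bearer {API_KEY}"},
    data={
        "checkpoint": "checkpoint-500",  # directory under the parent job's checkpoints dir
        "parent_job_id": "42",
    },
)
print(response.json())
# On success, per the handler above, the response includes "status": "success",
# the new "job_id", and "message": "Training resumed from checkpoint checkpoint-500".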