Skip to content

Commit 3f67d9c

Browse files
authored
Merge pull request #638 from transformerlab/add/resume-training-checkpoint
Add first draft of resuming training from latest checkpoint
2 parents 325a2b1 + 15472e2 commit 3f67d9c

File tree

2 files changed

+96
-41
lines changed

2 files changed

+96
-41
lines changed

transformerlab/routers/experiment/jobs.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -633,7 +633,6 @@ async def get_checkpoints(job_id: str, request: Request):
633633

634634
# Sort checkpoints by filename in reverse (descending) order for consistent ordering
635635
checkpoints.sort(key=lambda x: x["filename"], reverse=True)
636-
# print(f"Sorted checkpoints: {checkpoints}")
637636

638637
return {
639638
"checkpoints": checkpoints,

transformerlab/routers/remote.py

Lines changed: 96 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from transformerlab.services.job_service import job_update_status
88
from transformerlab.routers.auth.api_key_auth import get_user_or_api_key
99
from transformerlab.services.auth import AuthenticatedIdentity, auth_service
10+
from lab.dirs import get_workspace_dir, get_job_checkpoints_dir
1011

1112

1213
router = APIRouter(prefix="/remote", tags=["remote"])
@@ -115,7 +116,7 @@ async def launch_remote(
115116
experimentId: str,
116117
identity: AuthenticatedIdentity = Depends(get_user_or_api_key),
117118
job_id: Optional[str] = Form(None),
118-
cluster_name: str = Form(...),
119+
cluster_name: Optional[str] = Form(None),
119120
command: str = Form("echo 'Hello World'"),
120121
task_name: Optional[str] = Form(None),
121122
cpus: Optional[str] = Form(None),
@@ -125,16 +126,91 @@ async def launch_remote(
125126
num_nodes: Optional[int] = Form(None),
126127
setup: Optional[str] = Form(None),
127128
uploaded_dir_path: Optional[str] = Form(None),
129+
checkpoint: Optional[str] = Form(None),
130+
parent_job_id: Optional[str] = Form(None),
128131
):
129132
"""
130133
Launch a remote instance via Lattice orchestrator. If job_id is provided, use existing job, otherwise create new one.
134+
If checkpoint and parent_job_id are provided, resume training from the specified checkpoint.
131135
"""
132-
# If job_id is provided, use existing job, otherwise create a new one
133136
formatted_cluster_name = cluster_name
134-
if job_id:
135-
# Trust the frontend-provided cluster_name when re-launching an existing job.
136-
pass
137-
else:
137+
# Handle resume from checkpoint logic
138+
if checkpoint and parent_job_id:
139+
# Get the parent job
140+
parent_job = job_service.job_get(parent_job_id)
141+
if not parent_job:
142+
return {"status": "error", "message": f"Parent job {parent_job_id} not found"}
143+
144+
# Get the parent job data
145+
parent_job_data = parent_job.get("job_data", {})
146+
147+
# Validate checkpoint existence
148+
checkpoints_dir = get_job_checkpoints_dir(parent_job_id)
149+
checkpoint_path = os.path.normpath(os.path.join(checkpoints_dir, checkpoint))
150+
151+
# Validate that the checkpoint path is within the checkpoints directory
152+
if not checkpoint_path.startswith(os.path.abspath(checkpoints_dir) + os.sep):
153+
return {"status": "error", "message": "Invalid checkpoint name (potential directory traversal detected)"}
154+
155+
if not os.path.exists(checkpoint_path):
156+
return {"status": "error", "message": f"Checkpoint {checkpoint} not found at {checkpoint_path}"}
157+
158+
# Get the original command
159+
command = parent_job_data.get("command", "")
160+
if not command:
161+
return {"status": "error", "message": "Original command not found in parent job data"}
162+
163+
# Create a simple, meaningful task name for the resumed training
164+
task_name = f"resume_training_{parent_job_id}"
165+
166+
# Use ALL parameters from parent job for resume (user just presses button)
167+
cluster_name = parent_job_data.get("cluster_name")
168+
cpus = parent_job_data.get("cpus")
169+
memory = parent_job_data.get("memory")
170+
disk_space = parent_job_data.get("disk_space")
171+
accelerators = parent_job_data.get("accelerators")
172+
num_nodes = parent_job_data.get("num_nodes")
173+
setup = parent_job_data.get("setup")
174+
uploaded_dir_path = parent_job_data.get("uploaded_dir_path")
175+
176+
# Force creation of new job for resume (don't use existing job_id)
177+
job_id = None
178+
179+
# Validate required fields
180+
if not cluster_name:
181+
return {"status": "error", "message": "cluster_name is required"}
182+
183+
# Build a unified data structure with all parameters
184+
data = {
185+
"cluster_name": cluster_name,
186+
"command": command,
187+
"task_name": task_name,
188+
}
189+
190+
# Add resume metadata if resuming from checkpoint
191+
if checkpoint and parent_job_id:
192+
data["resumed_from_checkpoint"] = checkpoint
193+
data["checkpoint_path"] = checkpoint_path
194+
data["parent_job_id"] = parent_job_id
195+
196+
# Add optional parameters if provided
197+
if cpus:
198+
data["cpus"] = cpus
199+
if memory:
200+
data["memory"] = memory
201+
if disk_space:
202+
data["disk_space"] = disk_space
203+
if accelerators:
204+
data["accelerators"] = accelerators
205+
if num_nodes:
206+
data["num_nodes"] = num_nodes
207+
if setup:
208+
data["setup"] = setup
209+
if uploaded_dir_path:
210+
data["uploaded_dir_path"] = uploaded_dir_path
211+
212+
# If job_id is provided, use existing job, otherwise create a new one
213+
if not job_id:
138214
# Get user information from the authentication identity
139215
user_info_payload = auth_service.get_user_info(identity)
140216

@@ -147,22 +223,18 @@ async def launch_remote(
147223
]).strip()
148224
if user_info_payload.get("email"):
149225
user_info["email"] = user_info_payload["email"]
150-
151-
# Create a new REMOTE job
152-
job_data = {"task_name": task_name, "command": command, "cluster_name": cluster_name}
153-
226+
154227
# Add user_info to job_data if we have any user information
155228
if user_info:
156-
job_data["user_info"] = user_info
157-
229+
data["user_info"] = user_info
158230
try:
159231
job_id = job_service.job_create(
160232
type="REMOTE",
161233
status="LAUNCHING",
162234
experiment_id=experimentId,
163235
)
164-
# Update the job data to add fields from job_data (this ensures default fields stay in the job)
165-
for key, value in job_data.items():
236+
# Store all data in the job (this ensures default fields stay in the job)
237+
for key, value in data.items():
166238
job_service.job_update_job_data_insert_key_value(job_id, key, value, experimentId)
167239

168240
# Format cluster_name as <user_value>-job-<job_id> and persist it
@@ -172,6 +244,9 @@ async def launch_remote(
172244
print(f"Failed to create job: {str(e)}")
173245
return {"status": "error", "message": "Failed to create job"}
174246

247+
# Use task_name as job_name if provided, otherwise fall back to cluster_name
248+
job_name = task_name if task_name else formatted_cluster_name
249+
data["job_name"] = job_name
175250
# Validate environment variables
176251
result = validate_gpu_orchestrator_env_vars()
177252
gpu_orchestrator_url, gpu_orchestrator_port = result
@@ -180,32 +255,14 @@ async def launch_remote(
180255
elif isinstance(gpu_orchestrator_port, dict):
181256
return gpu_orchestrator_port # Error response
182257

183-
# Prepare the request data for Lattice orchestrator
184-
request_data = {
185-
"cluster_name": formatted_cluster_name,
186-
"command": command,
187-
"tlab_job_id": job_id, # Pass the job_id to the orchestrator
188-
}
258+
# Prepare request data for orchestrator by copying the data and adding orchestrator-specific fields
259+
request_data = data.copy()
260+
request_data["tlab_job_id"] = job_id
261+
request_data["cluster_name"] = formatted_cluster_name
189262

190263
# Use task_name as job_name if provided, otherwise fall back to cluster_name
191-
job_name = task_name if task_name else formatted_cluster_name
192-
request_data["job_name"] = job_name
264+
request_data["job_name"] = task_name if task_name else cluster_name
193265

194-
# Add optional parameters if provided
195-
if cpus:
196-
request_data["cpus"] = cpus
197-
if memory:
198-
request_data["memory"] = memory
199-
if disk_space:
200-
request_data["disk_space"] = disk_space
201-
if accelerators:
202-
request_data["accelerators"] = accelerators
203-
if num_nodes:
204-
request_data["num_nodes"] = num_nodes
205-
if setup:
206-
request_data["setup"] = setup
207-
if uploaded_dir_path:
208-
request_data["uploaded_dir_path"] = uploaded_dir_path
209266

210267
gpu_orchestrator_url = f"{gpu_orchestrator_url}:{gpu_orchestrator_port}/api/v1/instances/launch"
211268

@@ -243,7 +300,7 @@ async def launch_remote(
243300
"status": "success",
244301
"data": response_data,
245302
"job_id": job_id,
246-
"message": "Remote instance launched successfully",
303+
"message": f"Training resumed from checkpoint {checkpoint}" if checkpoint else "Remote instance launched successfully",
247304
}
248305
else:
249306
return {
@@ -333,7 +390,6 @@ async def upload_directory(
333390
Upload a directory to the remote Lattice orchestrator for later use in cluster launches.
334391
Files are stored locally first, then sent to orchestrator.
335392
"""
336-
from lab.dirs import get_workspace_dir
337393

338394
# Validate environment variables
339395
result = validate_gpu_orchestrator_env_vars()
@@ -628,4 +684,4 @@ async def check_remote_jobs_status(request: Request):
628684

629685
except Exception as e:
630686
print(f"Error checking remote job status: {str(e)}")
631-
return {"status": "error", "message": "Error checking remote job status"}
687+
return {"status": "error", "message": "Error checking remote job status"}

0 commit comments

Comments
 (0)