Commit d2dd64c

Merge pull request #651 from transformerlab/add/userinfo-job-remote
Store user info in a REMOTE job and Job ID should be formatted inside the cluster name
2 parents 9301f1f + 04f1749 commit d2dd64c
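
The PR title describes embedding the job ID inside the orchestrator cluster name; that change lives in the third changed file, which is not shown below. A minimal, hypothetical sketch of such a naming scheme (the function name and separator are assumptions, not taken from the repo):

    def format_cluster_name(cluster_name: str, job_id: str) -> str:
        # Embed the local TransformerLab job ID in the remote cluster name so the
        # orchestrator cluster can be traced back to the job record later.
        return f"{cluster_name}-{job_id}"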

File tree

3 files changed (+89, -262 lines)

test/api/test_remote.py

Lines changed: 0 additions & 244 deletions
@@ -117,211 +117,11 @@ def test_validate_env_vars_missing_port(self, monkeypatch):
         assert "GPU_ORCHESTRATION_SERVER_PORT" in error_response["message"]


-class TestCreateRemoteJob:
-    """Test the /remote/create-job endpoint"""
-
-    def test_create_remote_job_success(self, client, gpu_orchestration_env_vars, mock_experiment_id, job_cleanup):
-        """Test creating a remote job successfully"""
-        response = client.post(
-            f"/remote/create-job?experimentId={mock_experiment_id}",
-            data={
-                "cluster_name": "test-cluster",
-                "command": "echo 'test'",
-                "task_name": "test-task",
-            },
-        )
-        assert response.status_code == 200
-        data = response.json()
-        assert data["status"] == "success"
-        assert "job_id" in data
-        assert data["message"] == "Remote job created successfully"
-        # Track job for cleanup
-        job_cleanup.append((data["job_id"], mock_experiment_id))
-
-    def test_create_remote_job_with_optional_params(self, client, gpu_orchestration_env_vars, mock_experiment_id, job_cleanup):
-        """Test creating a remote job with optional parameters"""
-        response = client.post(
-            f"/remote/create-job?experimentId={mock_experiment_id}",
-            data={
-                "cluster_name": "test-cluster",
-                "command": "echo 'test'",
-                "cpus": "4",
-                "memory": "8GB",
-                "disk_space": "100GB",
-                "accelerators": "1xV100",
-                "num_nodes": 2,
-            },
-        )
-        assert response.status_code == 200
-        data = response.json()
-        assert data["status"] == "success"
-        assert "job_id" in data
-        # Track job for cleanup
-        job_cleanup.append((data["job_id"], mock_experiment_id))
-
-
-class TestLaunchRemote:
-    """Test the /remote/launch endpoint"""
-
-    @patch("transformerlab.routers.remote.httpx.AsyncClient")
-    def test_launch_remote_success(self, mock_client_class, client, gpu_orchestration_env_vars, mock_experiment_id, job_cleanup):
-        """Test launching a remote job successfully"""
-        # Mock the async client and response
-        mock_response = MagicMock()
-        mock_response.status_code = 200
-        mock_response.json.return_value = {
-            "request_id": "test-request-123",
-            "cluster_name": "test-cluster",
-            "status": "launched",
-        }
-
-        # Set up the async context manager protocol for httpx.AsyncClient
-        mock_httpx_client = AsyncMock()
-        mock_httpx_client.post = AsyncMock(return_value=mock_response)
-        # AsyncMock automatically handles __aenter__ and __aexit__, but we can be explicit
-        mock_httpx_client.__aenter__.return_value = mock_httpx_client
-        mock_httpx_client.__aexit__.return_value = None
-        mock_client_class.return_value = mock_httpx_client
-
-        response = client.post(
-            f"/remote/launch?experimentId={mock_experiment_id}",
-            data={
-                "cluster_name": "test-cluster",
-                "command": "echo 'test'",
-            },
-        )
-        assert response.status_code == 200
-        data = response.json()
-        assert data["status"] == "success"
-        assert "job_id" in data
-        assert data["data"]["request_id"] == "test-request-123"
-        # Track job for cleanup
-        job_cleanup.append((data["job_id"], mock_experiment_id))
-
-    @patch("transformerlab.routers.remote.httpx.AsyncClient")
-    def test_launch_remote_with_existing_job_id(self, mock_client_class, client, gpu_orchestration_env_vars, mock_experiment_id, job_cleanup):
-        """Test launching with an existing job_id"""
-        # First create a job
-        create_response = client.post(
-            f"/remote/create-job?experimentId={mock_experiment_id}",
-            data={
-                "cluster_name": "test-cluster",
-                "command": "echo 'test'",
-            },
-        )
-        job_id = create_response.json()["job_id"]
-        # Track job for cleanup
-        job_cleanup.append((job_id, mock_experiment_id))
-
-        # Mock the async client and response
-        mock_response = MagicMock()
-        mock_response.status_code = 200
-        mock_response.json.return_value = {
-            "request_id": "test-request-456",
-            "cluster_name": "test-cluster",
-        }
-
-        mock_httpx_client = AsyncMock()
-        mock_httpx_client.post = AsyncMock(return_value=mock_response)
-        mock_httpx_client.__aenter__.return_value = mock_httpx_client
-        mock_httpx_client.__aexit__.return_value = None
-        mock_client_class.return_value = mock_httpx_client
-
-        response = client.post(
-            f"/remote/launch?experimentId={mock_experiment_id}",
-            data={
-                "job_id": job_id,
-                "cluster_name": "test-cluster",
-                "command": "echo 'test'",
-            },
-        )
-        assert response.status_code == 200
-        data = response.json()
-        assert data["status"] == "success"
-        assert data["job_id"] == str(job_id)
-
-    def test_launch_remote_missing_env_vars(self, client, no_gpu_orchestration_env, mock_experiment_id):
-        """Test launching when GPU orchestration env vars are not set"""
-        response = client.post(
-            f"/remote/launch?experimentId={mock_experiment_id}",
-            data={
-                "cluster_name": "test-cluster",
-                "command": "echo 'test'",
-            },
-        )
-        assert response.status_code == 200
-        data = response.json()
-        assert data["status"] == "error"
-        assert "GPU_ORCHESTRATION_SERVER" in data["message"]
-
-    @patch("transformerlab.routers.remote.httpx.AsyncClient")
-    def test_launch_remote_orchestrator_error(self, mock_client_class, client, gpu_orchestration_env_vars, mock_experiment_id, job_cleanup):
-        """Test handling orchestrator error response"""
-        mock_response = MagicMock()
-        mock_response.status_code = 500
-        mock_response.text = "Internal Server Error"
-
-        mock_httpx_client = AsyncMock()
-        mock_httpx_client.post = AsyncMock(return_value=mock_response)
-        mock_httpx_client.__aenter__.return_value = mock_httpx_client
-        mock_httpx_client.__aexit__.return_value = None
-        mock_client_class.return_value = mock_httpx_client
-
-        response = client.post(
-            f"/remote/launch?experimentId={mock_experiment_id}",
-            data={
-                "cluster_name": "test-cluster",
-                "command": "echo 'test'",
-            },
-        )
-        assert response.status_code == 200
-        data = response.json()
-        assert data["status"] == "error"
-        assert "500" in data["message"]
-        # Even if launch failed, a job might have been created, so track it if present
-        if "job_id" in data:
-            job_cleanup.append((data["job_id"], mock_experiment_id))


 class TestStopRemote:
     """Test the /remote/stop endpoint"""

-    @patch("transformerlab.routers.remote.httpx.AsyncClient")
-    def test_stop_remote_success(self, mock_client_class, client, gpu_orchestration_env_vars, mock_experiment_id, job_cleanup):
-        """Test stopping a remote job successfully"""
-        # Create a job first
-        create_response = client.post(
-            f"/remote/create-job?experimentId={mock_experiment_id}",
-            data={
-                "cluster_name": "test-cluster",
-                "command": "echo 'test'",
-            },
-        )
-        job_id = create_response.json()["job_id"]
-        # Track job for cleanup
-        job_cleanup.append((job_id, mock_experiment_id))
-
-        # Mock the async client and response
-        mock_response = MagicMock()
-        mock_response.status_code = 200
-        mock_response.json.return_value = {"status": "stopped"}
-
-        mock_httpx_client = AsyncMock()
-        mock_httpx_client.post = AsyncMock(return_value=mock_response)
-        mock_httpx_client.__aenter__.return_value = mock_httpx_client
-        mock_httpx_client.__aexit__.return_value = None
-        mock_client_class.return_value = mock_httpx_client
-
-        response = client.post(
-            "/remote/stop",
-            data={
-                "job_id": job_id,
-                "cluster_name": "test-cluster",
-            },
-        )
-        assert response.status_code == 200
-        data = response.json()
-        assert data["status"] == "success"

     def test_stop_remote_missing_env_vars(self, client, no_gpu_orchestration_env):
         """Test stopping when GPU orchestration env vars are not set"""
@@ -400,50 +200,6 @@ def test_check_status_no_launching_jobs(self, mock_client_class, client, gpu_orc
         assert "updated_jobs" in data
         assert data["updated_jobs"] == []

-    @patch("transformerlab.routers.remote.httpx.AsyncClient")
-    def test_check_status_with_jobs(self, mock_client_class, client, gpu_orchestration_env_vars, mock_experiment_id, job_cleanup):
-        """Test checking status with LAUNCHING jobs"""
-        # Create a remote job in LAUNCHING state
-        create_response = client.post(
-            f"/remote/create-job?experimentId={mock_experiment_id}",
-            data={
-                "cluster_name": "test-cluster",
-                "command": "echo 'test'",
-            },
-        )
-        # Track job for cleanup
-        job_id = None
-        if create_response.status_code == 200:
-            job_data = create_response.json()
-            if "job_id" in job_data:
-                job_id = job_data["job_id"]
-                job_cleanup.append((job_id, mock_experiment_id))
-
-        # Verify job was created
-        assert job_id is not None, "Job should have been created"
-
-        # Mock the async client and response for status check
-        # This mocks the call to check_remote_job_status which calls the orchestrator
-        mock_response = MagicMock()
-        mock_response.status_code = 200
-        mock_response.json.return_value = {
-            "jobs": [
-                {"status": "SUCCEEDED"},
-            ],
-        }
-
-        mock_httpx_client = AsyncMock()
-        mock_httpx_client.get = AsyncMock(return_value=mock_response)
-        mock_httpx_client.__aenter__.return_value = mock_httpx_client
-        mock_httpx_client.__aexit__.return_value = None
-        mock_client_class.return_value = mock_httpx_client
-
-        response = client.get("/remote/check-status")
-        assert response.status_code == 200
-        data = response.json()
-        assert data["status"] == "success"
-        assert "updated_jobs" in data
-

 class TestGetOrchestratorLogs:
     """Test the /remote/logs/{request_id} endpoint"""

transformerlab/routers/experiment/jobs.py

Lines changed: 37 additions & 14 deletions
@@ -526,7 +526,6 @@ async def get_checkpoints(job_id: str, request: Request):
         return {"checkpoints": []}

     job_data = job["job_data"]
-
     # First try to use the new SDK method to get checkpoints
     try:
         from lab.job import Job
@@ -535,26 +534,33 @@ async def get_checkpoints(job_id: str, request: Request):
         sdk_job = Job(job_id)
         checkpoint_paths = sdk_job.get_checkpoint_paths()

+
         if checkpoint_paths and len(checkpoint_paths) > 0:
             checkpoints = []
             for checkpoint_path in checkpoint_paths:
                 try:
-                    stat = os.stat(checkpoint_path)
-                    modified_time = stat.st_mtime
-                    filesize = stat.st_size
-                    # Format the timestamp as ISO 8601 string
-                    formatted_time = datetime.fromtimestamp(modified_time).isoformat()
+                    if os.path.isdir(checkpoint_path):
+                        # Dont set formatted_time and filesize for directories (os.stat messes it up for fused filesystems)
+                        formatted_time = None
+                        filesize = None
+                    else:
+                        stat = os.stat(checkpoint_path)
+                        modified_time = stat.st_mtime
+                        filesize = stat.st_size
+                        # Format the timestamp as ISO 8601 string
+                        formatted_time = datetime.fromtimestamp(modified_time).isoformat()
+
                     filename = os.path.basename(checkpoint_path)
                     checkpoints.append({"filename": filename, "date": formatted_time, "size": filesize})
                 except Exception as e:
-                    logging.error(f"Error getting stat for checkpoint {checkpoint_path}: {e}")
+                    print(f"Error getting stat for checkpoint {checkpoint_path}: {e}")
                     continue

             # Sort checkpoints by filename in reverse (descending) order for consistent ordering
             checkpoints.sort(key=lambda x: x["filename"], reverse=True)
             return {"checkpoints": checkpoints}
     except Exception as e:
-        logging.info(f"SDK checkpoint method failed for job {job_id}, falling back to legacy method: {e}")
+        print(f"SDK checkpoint method failed for job {job_id}, falling back to legacy method: {e}")

     # Fallback to the original logic if SDK method doesn't work or returns nothing
     # Check if the job has a supports_checkpoints flag
@@ -577,18 +583,35 @@ async def get_checkpoints(job_id: str, request: Request):
     default_adaptor_dir = os.path.join(workspace_dir, "adaptors", secure_filename(model_name), adaptor_name)

     # print(f"Default adaptor directory: {default_adaptor_dir}")
-
-    checkpoints_dir = job_data.get("checkpoints_dir", default_adaptor_dir)
+    # Get job directory from t
+    checkpoints_dir = job_data.get("checkpoints_dir")
+    if not checkpoints_dir:
+        from lab.dirs import get_job_checkpoints_dir
+        checkpoints_dir = get_job_checkpoints_dir(job_id)
     if not checkpoints_dir or not os.path.exists(checkpoints_dir):
-        # print(f"Checkpoints directory does not exist: {checkpoints_dir}")
         return {"checkpoints": []}
+    elif os.path.isdir(checkpoints_dir):
+        checkpoints = []
+        if len(os.listdir(checkpoints_dir)) > 0:
+            for filename in os.listdir(checkpoints_dir):
+                if fnmatch(filename, "*_adapters.safetensors"):
+                    file_path = os.path.join(checkpoints_dir, filename)
+                    stat = os.stat(file_path)
+                    modified_time = stat.st_mtime
+                    filesize = stat.st_size
+                    checkpoints.append({"filename": filename, "date": modified_time, "size": filesize})
+                # allow directories too
+                elif os.path.isdir(os.path.join(checkpoints_dir, filename)):
+                    checkpoints.append({"filename": filename, "date": None, "size": None})
+        return {"checkpoints": checkpoints}
+

+    # Fallback to using default adaptor directory as checkpoints directory
+    checkpoints_dir = default_adaptor_dir
     checkpoints_file_filter = job_data.get("checkpoints_file_filter", "*_adapters.safetensors")
     if not checkpoints_file_filter:
         checkpoints_file_filter = "*_adapters.safetensors"

-    # print(f"Checkpoints directory: {checkpoints_dir}")
-    # print(f"Checkpoints file filter: {checkpoints_file_filter}")

     checkpoints = []
     try:
@@ -607,7 +630,7 @@ async def get_checkpoints(job_id: str, request: Request):
                 filesize = None
             checkpoints.append({"filename": filename, "date": formatted_time, "size": filesize})
     except OSError as e:
-        logging.error(f"Error reading checkpoints directory {checkpoints_dir}: {e}")
+        print(f"Error reading checkpoints directory {checkpoints_dir}: {e}")

     # Sort checkpoints by filename in reverse (descending) order for consistent ordering
     checkpoints.sort(key=lambda x: x["filename"], reverse=True)
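
The new `os.path.isdir` guard in this diff skips os.stat() metadata for directories because stat results can be unreliable on fused or remote filesystems. A standalone sketch of that guard, with an illustrative helper name (not the repo's API):

    import os
    from datetime import datetime

    def describe_checkpoint(path: str) -> dict:
        # Only regular files get a date and size; directories report None for both,
        # mirroring the guard added to get_checkpoints above.
        if os.path.isdir(path):
            formatted_time, filesize = None, None
        else:
            stat = os.stat(path)
            formatted_time = datetime.fromtimestamp(stat.st_mtime).isoformat()
            filesize = stat.st_size
        return {"filename": os.path.basename(path), "date": formatted_time, "size": filesize}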
