Skip to content

Commit 41df212

Browse files
authored
Merge branch 'main' into fix/logging-print
2 parents 10a8351 + d2dd64c commit 41df212

File tree

3 files changed

+86
-259
lines changed

3 files changed

+86
-259
lines changed

test/api/test_remote.py

Lines changed: 0 additions & 244 deletions
Original file line numberDiff line numberDiff line change
@@ -117,211 +117,11 @@ def test_validate_env_vars_missing_port(self, monkeypatch):
117117
assert "GPU_ORCHESTRATION_SERVER_PORT" in error_response["message"]
118118

119119

120-
class TestCreateRemoteJob:
121-
"""Test the /remote/create-job endpoint"""
122-
123-
def test_create_remote_job_success(self, client, gpu_orchestration_env_vars, mock_experiment_id, job_cleanup):
124-
"""Test creating a remote job successfully"""
125-
response = client.post(
126-
f"/remote/create-job?experimentId={mock_experiment_id}",
127-
data={
128-
"cluster_name": "test-cluster",
129-
"command": "echo 'test'",
130-
"task_name": "test-task",
131-
},
132-
)
133-
assert response.status_code == 200
134-
data = response.json()
135-
assert data["status"] == "success"
136-
assert "job_id" in data
137-
assert data["message"] == "Remote job created successfully"
138-
# Track job for cleanup
139-
job_cleanup.append((data["job_id"], mock_experiment_id))
140-
141-
def test_create_remote_job_with_optional_params(self, client, gpu_orchestration_env_vars, mock_experiment_id, job_cleanup):
142-
"""Test creating a remote job with optional parameters"""
143-
response = client.post(
144-
f"/remote/create-job?experimentId={mock_experiment_id}",
145-
data={
146-
"cluster_name": "test-cluster",
147-
"command": "echo 'test'",
148-
"cpus": "4",
149-
"memory": "8GB",
150-
"disk_space": "100GB",
151-
"accelerators": "1xV100",
152-
"num_nodes": 2,
153-
},
154-
)
155-
assert response.status_code == 200
156-
data = response.json()
157-
assert data["status"] == "success"
158-
assert "job_id" in data
159-
# Track job for cleanup
160-
job_cleanup.append((data["job_id"], mock_experiment_id))
161-
162-
163-
class TestLaunchRemote:
164-
"""Test the /remote/launch endpoint"""
165-
166-
@patch("transformerlab.routers.remote.httpx.AsyncClient")
167-
def test_launch_remote_success(self, mock_client_class, client, gpu_orchestration_env_vars, mock_experiment_id, job_cleanup):
168-
"""Test launching a remote job successfully"""
169-
# Mock the async client and response
170-
mock_response = MagicMock()
171-
mock_response.status_code = 200
172-
mock_response.json.return_value = {
173-
"request_id": "test-request-123",
174-
"cluster_name": "test-cluster",
175-
"status": "launched",
176-
}
177-
178-
# Set up the async context manager protocol for httpx.AsyncClient
179-
mock_httpx_client = AsyncMock()
180-
mock_httpx_client.post = AsyncMock(return_value=mock_response)
181-
# AsyncMock automatically handles __aenter__ and __aexit__, but we can be explicit
182-
mock_httpx_client.__aenter__.return_value = mock_httpx_client
183-
mock_httpx_client.__aexit__.return_value = None
184-
mock_client_class.return_value = mock_httpx_client
185-
186-
response = client.post(
187-
f"/remote/launch?experimentId={mock_experiment_id}",
188-
data={
189-
"cluster_name": "test-cluster",
190-
"command": "echo 'test'",
191-
},
192-
)
193-
assert response.status_code == 200
194-
data = response.json()
195-
assert data["status"] == "success"
196-
assert "job_id" in data
197-
assert data["data"]["request_id"] == "test-request-123"
198-
# Track job for cleanup
199-
job_cleanup.append((data["job_id"], mock_experiment_id))
200-
201-
@patch("transformerlab.routers.remote.httpx.AsyncClient")
202-
def test_launch_remote_with_existing_job_id(self, mock_client_class, client, gpu_orchestration_env_vars, mock_experiment_id, job_cleanup):
203-
"""Test launching with an existing job_id"""
204-
# First create a job
205-
create_response = client.post(
206-
f"/remote/create-job?experimentId={mock_experiment_id}",
207-
data={
208-
"cluster_name": "test-cluster",
209-
"command": "echo 'test'",
210-
},
211-
)
212-
job_id = create_response.json()["job_id"]
213-
# Track job for cleanup
214-
job_cleanup.append((job_id, mock_experiment_id))
215-
216-
# Mock the async client and response
217-
mock_response = MagicMock()
218-
mock_response.status_code = 200
219-
mock_response.json.return_value = {
220-
"request_id": "test-request-456",
221-
"cluster_name": "test-cluster",
222-
}
223-
224-
mock_httpx_client = AsyncMock()
225-
mock_httpx_client.post = AsyncMock(return_value=mock_response)
226-
mock_httpx_client.__aenter__.return_value = mock_httpx_client
227-
mock_httpx_client.__aexit__.return_value = None
228-
mock_client_class.return_value = mock_httpx_client
229-
230-
response = client.post(
231-
f"/remote/launch?experimentId={mock_experiment_id}",
232-
data={
233-
"job_id": job_id,
234-
"cluster_name": "test-cluster",
235-
"command": "echo 'test'",
236-
},
237-
)
238-
assert response.status_code == 200
239-
data = response.json()
240-
assert data["status"] == "success"
241-
assert data["job_id"] == str(job_id)
242-
243-
def test_launch_remote_missing_env_vars(self, client, no_gpu_orchestration_env, mock_experiment_id):
244-
"""Test launching when GPU orchestration env vars are not set"""
245-
response = client.post(
246-
f"/remote/launch?experimentId={mock_experiment_id}",
247-
data={
248-
"cluster_name": "test-cluster",
249-
"command": "echo 'test'",
250-
},
251-
)
252-
assert response.status_code == 200
253-
data = response.json()
254-
assert data["status"] == "error"
255-
assert "GPU_ORCHESTRATION_SERVER" in data["message"]
256-
257-
@patch("transformerlab.routers.remote.httpx.AsyncClient")
258-
def test_launch_remote_orchestrator_error(self, mock_client_class, client, gpu_orchestration_env_vars, mock_experiment_id, job_cleanup):
259-
"""Test handling orchestrator error response"""
260-
mock_response = MagicMock()
261-
mock_response.status_code = 500
262-
mock_response.text = "Internal Server Error"
263-
264-
mock_httpx_client = AsyncMock()
265-
mock_httpx_client.post = AsyncMock(return_value=mock_response)
266-
mock_httpx_client.__aenter__.return_value = mock_httpx_client
267-
mock_httpx_client.__aexit__.return_value = None
268-
mock_client_class.return_value = mock_httpx_client
269-
270-
response = client.post(
271-
f"/remote/launch?experimentId={mock_experiment_id}",
272-
data={
273-
"cluster_name": "test-cluster",
274-
"command": "echo 'test'",
275-
},
276-
)
277-
assert response.status_code == 200
278-
data = response.json()
279-
assert data["status"] == "error"
280-
assert "500" in data["message"]
281-
# Even if launch failed, a job might have been created, so track it if present
282-
if "job_id" in data:
283-
job_cleanup.append((data["job_id"], mock_experiment_id))
284120

285121

286122
class TestStopRemote:
287123
"""Test the /remote/stop endpoint"""
288124

289-
@patch("transformerlab.routers.remote.httpx.AsyncClient")
290-
def test_stop_remote_success(self, mock_client_class, client, gpu_orchestration_env_vars, mock_experiment_id, job_cleanup):
291-
"""Test stopping a remote job successfully"""
292-
# Create a job first
293-
create_response = client.post(
294-
f"/remote/create-job?experimentId={mock_experiment_id}",
295-
data={
296-
"cluster_name": "test-cluster",
297-
"command": "echo 'test'",
298-
},
299-
)
300-
job_id = create_response.json()["job_id"]
301-
# Track job for cleanup
302-
job_cleanup.append((job_id, mock_experiment_id))
303-
304-
# Mock the async client and response
305-
mock_response = MagicMock()
306-
mock_response.status_code = 200
307-
mock_response.json.return_value = {"status": "stopped"}
308-
309-
mock_httpx_client = AsyncMock()
310-
mock_httpx_client.post = AsyncMock(return_value=mock_response)
311-
mock_httpx_client.__aenter__.return_value = mock_httpx_client
312-
mock_httpx_client.__aexit__.return_value = None
313-
mock_client_class.return_value = mock_httpx_client
314-
315-
response = client.post(
316-
"/remote/stop",
317-
data={
318-
"job_id": job_id,
319-
"cluster_name": "test-cluster",
320-
},
321-
)
322-
assert response.status_code == 200
323-
data = response.json()
324-
assert data["status"] == "success"
325125

326126
def test_stop_remote_missing_env_vars(self, client, no_gpu_orchestration_env):
327127
"""Test stopping when GPU orchestration env vars are not set"""
@@ -400,50 +200,6 @@ def test_check_status_no_launching_jobs(self, mock_client_class, client, gpu_orc
400200
assert "updated_jobs" in data
401201
assert data["updated_jobs"] == []
402202

403-
@patch("transformerlab.routers.remote.httpx.AsyncClient")
404-
def test_check_status_with_jobs(self, mock_client_class, client, gpu_orchestration_env_vars, mock_experiment_id, job_cleanup):
405-
"""Test checking status with LAUNCHING jobs"""
406-
# Create a remote job in LAUNCHING state
407-
create_response = client.post(
408-
f"/remote/create-job?experimentId={mock_experiment_id}",
409-
data={
410-
"cluster_name": "test-cluster",
411-
"command": "echo 'test'",
412-
},
413-
)
414-
# Track job for cleanup
415-
job_id = None
416-
if create_response.status_code == 200:
417-
job_data = create_response.json()
418-
if "job_id" in job_data:
419-
job_id = job_data["job_id"]
420-
job_cleanup.append((job_id, mock_experiment_id))
421-
422-
# Verify job was created
423-
assert job_id is not None, "Job should have been created"
424-
425-
# Mock the async client and response for status check
426-
# This mocks the call to check_remote_job_status which calls the orchestrator
427-
mock_response = MagicMock()
428-
mock_response.status_code = 200
429-
mock_response.json.return_value = {
430-
"jobs": [
431-
{"status": "SUCCEEDED"},
432-
],
433-
}
434-
435-
mock_httpx_client = AsyncMock()
436-
mock_httpx_client.get = AsyncMock(return_value=mock_response)
437-
mock_httpx_client.__aenter__.return_value = mock_httpx_client
438-
mock_httpx_client.__aexit__.return_value = None
439-
mock_client_class.return_value = mock_httpx_client
440-
441-
response = client.get("/remote/check-status")
442-
assert response.status_code == 200
443-
data = response.json()
444-
assert data["status"] == "success"
445-
assert "updated_jobs" in data
446-
447203

448204
class TestGetOrchestratorLogs:
449205
"""Test the /remote/logs/{request_id} endpoint"""

transformerlab/routers/experiment/jobs.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,6 @@ async def get_checkpoints(job_id: str, request: Request):
525525
return {"checkpoints": []}
526526

527527
job_data = job["job_data"]
528-
529528
# First try to use the new SDK method to get checkpoints
530529
try:
531530
from lab.job import Job
@@ -534,15 +533,22 @@ async def get_checkpoints(job_id: str, request: Request):
534533
sdk_job = Job(job_id)
535534
checkpoint_paths = sdk_job.get_checkpoint_paths()
536535

536+
537537
if checkpoint_paths and len(checkpoint_paths) > 0:
538538
checkpoints = []
539539
for checkpoint_path in checkpoint_paths:
540540
try:
541-
stat = os.stat(checkpoint_path)
542-
modified_time = stat.st_mtime
543-
filesize = stat.st_size
544-
# Format the timestamp as ISO 8601 string
545-
formatted_time = datetime.fromtimestamp(modified_time).isoformat()
541+
if os.path.isdir(checkpoint_path):
542+
# Don't set formatted_time and filesize for directories (os.stat messes it up for fused filesystems)
543+
formatted_time = None
544+
filesize = None
545+
else:
546+
stat = os.stat(checkpoint_path)
547+
modified_time = stat.st_mtime
548+
filesize = stat.st_size
549+
# Format the timestamp as ISO 8601 string
550+
formatted_time = datetime.fromtimestamp(modified_time).isoformat()
551+
546552
filename = os.path.basename(checkpoint_path)
547553
checkpoints.append({"filename": filename, "date": formatted_time, "size": filesize})
548554
except Exception as e:
@@ -576,18 +582,35 @@ async def get_checkpoints(job_id: str, request: Request):
576582
default_adaptor_dir = os.path.join(workspace_dir, "adaptors", secure_filename(model_name), adaptor_name)
577583

578584
# print(f"Default adaptor directory: {default_adaptor_dir}")
579-
580-
checkpoints_dir = job_data.get("checkpoints_dir", default_adaptor_dir)
585+
# Get checkpoints directory from the job data (falls back to the SDK default below)
586+
checkpoints_dir = job_data.get("checkpoints_dir")
587+
if not checkpoints_dir:
588+
from lab.dirs import get_job_checkpoints_dir
589+
checkpoints_dir = get_job_checkpoints_dir(job_id)
581590
if not checkpoints_dir or not os.path.exists(checkpoints_dir):
582-
# print(f"Checkpoints directory does not exist: {checkpoints_dir}")
583591
return {"checkpoints": []}
592+
elif os.path.isdir(checkpoints_dir):
593+
checkpoints = []
594+
if len(os.listdir(checkpoints_dir)) > 0:
595+
for filename in os.listdir(checkpoints_dir):
596+
if fnmatch(filename, "*_adapters.safetensors"):
597+
file_path = os.path.join(checkpoints_dir, filename)
598+
stat = os.stat(file_path)
599+
modified_time = stat.st_mtime
600+
filesize = stat.st_size
601+
checkpoints.append({"filename": filename, "date": modified_time, "size": filesize})
602+
# allow directories too
603+
elif os.path.isdir(os.path.join(checkpoints_dir, filename)):
604+
checkpoints.append({"filename": filename, "date": None, "size": None})
605+
return {"checkpoints": checkpoints}
606+
584607

608+
# Fallback to using default adaptor directory as checkpoints directory
609+
checkpoints_dir = default_adaptor_dir
585610
checkpoints_file_filter = job_data.get("checkpoints_file_filter", "*_adapters.safetensors")
586611
if not checkpoints_file_filter:
587612
checkpoints_file_filter = "*_adapters.safetensors"
588613

589-
# print(f"Checkpoints directory: {checkpoints_dir}")
590-
# print(f"Checkpoints file filter: {checkpoints_file_filter}")
591614

592615
checkpoints = []
593616
try:

0 commit comments

Comments
 (0)