Skip to content

Commit 15472e2

Browse files
committed
merge conflict
2 parents 6a3694a + 325a2b1 commit 15472e2

File tree

22 files changed

+502
-165
lines changed

22 files changed

+502
-165
lines changed

test/api/test_experiment_jobs.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,15 @@ def test_job_edge_cases(client):
127127

128128
resp = client.get("/experiment/alpha/jobs/create?type=DOWNLOAD_MODEL&status=QUEUED&data={}")
129129
assert resp.status_code == 200
130+
131+
132+
def test_train_sweep_results(client):
133+
resp = client.get("/experiment/alpha/jobs/1/sweep_results")
134+
assert resp.status_code == 200
135+
data = resp.json()
136+
assert "status" in data
137+
assert data["status"] in ("success", "error")
138+
if data["status"] == "success":
139+
assert "data" in data
140+
else:
141+
assert "message" in data

test/api/test_remote.py

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
import pytest
2+
from unittest.mock import patch, AsyncMock, MagicMock
3+
from io import BytesIO
4+
5+
6+
# Use the client fixture from conftest.py - no need to create our own
7+
8+
9+
@pytest.fixture
10+
def gpu_orchestration_env_vars(monkeypatch):
11+
"""Set up GPU orchestration environment variables for testing"""
12+
monkeypatch.setenv("GPU_ORCHESTRATION_SERVER", "http://test-orchestrator.example.com")
13+
monkeypatch.setenv("GPU_ORCHESTRATION_SERVER_PORT", "8080")
14+
yield
15+
# Cleanup - remove env vars after test
16+
monkeypatch.delenv("GPU_ORCHESTRATION_SERVER", raising=False)
17+
monkeypatch.delenv("GPU_ORCHESTRATION_SERVER_PORT", raising=False)
18+
19+
20+
@pytest.fixture
21+
def no_gpu_orchestration_env(monkeypatch):
22+
"""Ensure GPU orchestration environment variables are not set"""
23+
monkeypatch.delenv("GPU_ORCHESTRATION_SERVER", raising=False)
24+
monkeypatch.delenv("GPU_ORCHESTRATION_SERVER_PORT", raising=False)
25+
26+
27+
@pytest.fixture
28+
def mock_experiment_id(client):
29+
"""Create a test experiment and return its ID, cleaning up after the test"""
30+
import os
31+
import time
32+
import uuid
33+
from transformerlab.services import experiment_service
34+
35+
# Use a unique name to avoid conflicts - add UUID for better uniqueness
36+
unique_name = f"test_exp_remote_{os.getpid()}_{int(time.time())}_{uuid.uuid4().hex[:8]}"
37+
38+
# Check if experiment already exists and delete it if it does
39+
existing = experiment_service.experiment_get(unique_name)
40+
if existing:
41+
experiment_service.experiment_delete(unique_name)
42+
43+
# Create the experiment
44+
exp_id = experiment_service.experiment_create(unique_name, {})
45+
46+
yield exp_id
47+
48+
# Cleanup: delete all jobs in the experiment, then delete the experiment
49+
try:
50+
from transformerlab.services import job_service
51+
job_service.job_delete_all(exp_id)
52+
except Exception:
53+
pass
54+
55+
try:
56+
experiment_service.experiment_delete(exp_id)
57+
except Exception:
58+
# Ignore errors during cleanup
59+
pass
60+
61+
62+
@pytest.fixture
63+
def job_cleanup():
64+
"""Fixture to track and cleanup jobs created during tests"""
65+
created_jobs = [] # List of (job_id, experiment_id) tuples
66+
67+
yield created_jobs
68+
69+
# Cleanup: delete all tracked jobs
70+
from transformerlab.services import job_service
71+
for job_id, experiment_id in created_jobs:
72+
try:
73+
job_service.job_delete(job_id, experiment_id)
74+
except Exception:
75+
# Ignore errors during cleanup
76+
pass
77+
78+
79+
class TestValidateGPUOrchestratorEnvVars:
80+
"""Test the validate_gpu_orchestrator_env_vars function"""
81+
82+
def test_validate_env_vars_with_both_set(self, gpu_orchestration_env_vars):
83+
"""Test validation when both env vars are set"""
84+
from transformerlab.routers.remote import validate_gpu_orchestrator_env_vars
85+
86+
url, port = validate_gpu_orchestrator_env_vars()
87+
assert url == "http://test-orchestrator.example.com"
88+
assert port == "8080"
89+
90+
def test_validate_env_vars_missing_url(self, monkeypatch):
91+
"""Test validation when GPU_ORCHESTRATION_SERVER is missing"""
92+
monkeypatch.delenv("GPU_ORCHESTRATION_SERVER", raising=False)
93+
monkeypatch.setenv("GPU_ORCHESTRATION_SERVER_PORT", "8080")
94+
95+
from transformerlab.routers.remote import validate_gpu_orchestrator_env_vars
96+
97+
result = validate_gpu_orchestrator_env_vars()
98+
url, error_response = result
99+
assert url is None
100+
assert isinstance(error_response, dict)
101+
assert error_response["status"] == "error"
102+
assert "GPU_ORCHESTRATION_SERVER" in error_response["message"]
103+
104+
def test_validate_env_vars_missing_port(self, monkeypatch):
105+
"""Test validation when GPU_ORCHESTRATION_SERVER_PORT is missing"""
106+
monkeypatch.setenv("GPU_ORCHESTRATION_SERVER", "http://test-orchestrator.example.com")
107+
monkeypatch.delenv("GPU_ORCHESTRATION_SERVER_PORT", raising=False)
108+
109+
from transformerlab.routers.remote import validate_gpu_orchestrator_env_vars
110+
111+
result = validate_gpu_orchestrator_env_vars()
112+
url, error_response = result
113+
# When port is missing, the function returns None as the URL
114+
assert url is None
115+
assert isinstance(error_response, dict)
116+
assert error_response["status"] == "error"
117+
assert "GPU_ORCHESTRATION_SERVER_PORT" in error_response["message"]
118+
119+
120+
121+
122+
class TestStopRemote:
123+
"""Test the /remote/stop endpoint"""
124+
125+
126+
def test_stop_remote_missing_env_vars(self, client, no_gpu_orchestration_env):
127+
"""Test stopping when GPU orchestration env vars are not set"""
128+
response = client.post(
129+
"/remote/stop",
130+
data={
131+
"job_id": "test-job-id",
132+
"cluster_name": "test-cluster",
133+
},
134+
)
135+
assert response.status_code == 200
136+
data = response.json()
137+
assert data["status"] == "error"
138+
assert "GPU_ORCHESTRATION_SERVER" in data["message"]
139+
140+
141+
class TestUploadDirectory:
142+
"""Test the /remote/upload endpoint"""
143+
144+
@patch("transformerlab.routers.remote.httpx.AsyncClient")
145+
def test_upload_directory_success(self, mock_client_class, client, gpu_orchestration_env_vars):
146+
"""Test uploading a directory successfully"""
147+
# Create test files using BytesIO for proper file-like objects
148+
files = [
149+
("dir_files", ("test1.txt", BytesIO(b"content1"), "text/plain")),
150+
("dir_files", ("test2.txt", BytesIO(b"content2"), "text/plain")),
151+
]
152+
153+
# Mock the async client and response
154+
mock_response = MagicMock()
155+
mock_response.status_code = 200
156+
mock_response.json.return_value = {
157+
"status": "uploaded",
158+
"upload_path": "/remote/path",
159+
}
160+
161+
mock_httpx_client = AsyncMock()
162+
mock_httpx_client.post = AsyncMock(return_value=mock_response)
163+
mock_httpx_client.__aenter__.return_value = mock_httpx_client
164+
mock_httpx_client.__aexit__.return_value = None
165+
mock_client_class.return_value = mock_httpx_client
166+
167+
response = client.post(
168+
"/remote/upload",
169+
files=files,
170+
data={"dir_name": "test-dir"},
171+
)
172+
assert response.status_code == 200
173+
data = response.json()
174+
assert data["status"] == "success"
175+
assert "local_storage_path" in data
176+
177+
def test_upload_directory_missing_env_vars(self, client, no_gpu_orchestration_env):
178+
"""Test uploading when GPU orchestration env vars are not set"""
179+
files = [("dir_files", ("test.txt", BytesIO(b"content"), "text/plain"))]
180+
181+
response = client.post(
182+
"/remote/upload",
183+
files=files,
184+
)
185+
assert response.status_code == 200
186+
data = response.json()
187+
assert data["status"] == "error"
188+
assert "GPU_ORCHESTRATION_SERVER" in data["message"]
189+
190+
191+
class TestCheckRemoteJobStatus:
192+
"""Test the /remote/check-status endpoint"""
193+
194+
@patch("transformerlab.routers.remote.httpx.AsyncClient")
195+
def test_check_status_no_launching_jobs(self, mock_client_class, client, gpu_orchestration_env_vars):
196+
"""Test checking status when there are no LAUNCHING jobs"""
197+
response = client.get("/remote/check-status")
198+
assert response.status_code == 200
199+
data = response.json()
200+
assert "updated_jobs" in data
201+
assert data["updated_jobs"] == []
202+
203+
204+
class TestGetOrchestratorLogs:
205+
"""Test the /remote/logs/{request_id} endpoint"""
206+
207+
@patch("transformerlab.routers.remote.httpx.AsyncClient")
208+
def test_get_logs_success(self, mock_client_class, client, gpu_orchestration_env_vars):
209+
"""Test getting logs successfully"""
210+
request_id = "test-request-123"
211+
212+
# Mock streaming response
213+
mock_stream_response = MagicMock()
214+
mock_stream_response.status_code = 200
215+
mock_stream_response.aiter_bytes = AsyncMock(return_value=iter([b"log line 1\n", b"log line 2\n"]))
216+
217+
mock_httpx_client = AsyncMock()
218+
mock_httpx_client.stream = AsyncMock(return_value=mock_stream_response)
219+
mock_httpx_client.__aenter__.return_value = mock_httpx_client
220+
mock_httpx_client.__aexit__.return_value = None
221+
mock_client_class.return_value = mock_httpx_client
222+
223+
response = client.get(f"/remote/logs/{request_id}")
224+
# Streaming responses return 200 with appropriate headers
225+
assert response.status_code == 200
226+
227+
def test_get_logs_missing_env_vars(self, client, no_gpu_orchestration_env):
228+
"""Test getting logs when GPU orchestration env vars are not set"""
229+
response = client.get("/remote/logs/test-request-123")
230+
# The endpoint should return an error response
231+
assert response.status_code in [200, 500] # May return error as JSON or raise exception
232+

test/api/test_server_info.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,50 @@ def fake_check_output(*args, **kwargs):
6161
from transformerlab.routers import serverinfo
6262

6363
assert serverinfo.is_wsl() is False
64+
65+
66+
def test_healthz_local_mode(client, monkeypatch):
67+
"""Test healthz endpoint in local mode (no GPU orchestration)"""
68+
# Ensure GPU orchestration env vars are not set
69+
monkeypatch.delenv("GPU_ORCHESTRATION_SERVER", raising=False)
70+
monkeypatch.delenv("GPU_ORCHESTRATION_SERVER_PORT", raising=False)
71+
72+
response = client.get("/healthz")
73+
assert response.status_code == 200
74+
data = response.json()
75+
assert data["message"] == "OK"
76+
assert data["mode"] == "local"
77+
assert data["gpu_orchestration_server"] == ""
78+
assert data["gpu_orchestration_server_port"] == ""
79+
80+
81+
def test_healthz_gpu_orchestration_mode(client, monkeypatch):
82+
"""Test healthz endpoint in GPU orchestration mode"""
83+
# Set GPU orchestration env vars
84+
monkeypatch.setenv("GPU_ORCHESTRATION_SERVER", "http://orchestrator.example.com")
85+
monkeypatch.setenv("GPU_ORCHESTRATION_SERVER_PORT", "8080")
86+
87+
# The healthz endpoint reads env vars at request time, so monkeypatch should work
88+
response = client.get("/healthz")
89+
assert response.status_code == 200
90+
data = response.json()
91+
assert data["message"] == "OK"
92+
assert data["mode"] == "gpu_orchestration"
93+
assert data["gpu_orchestration_server"] == "http://orchestrator.example.com"
94+
assert data["gpu_orchestration_server_port"] == "8080"
95+
96+
97+
def test_healthz_gpu_orchestration_mode_no_port(client, monkeypatch):
98+
"""Test healthz endpoint in GPU orchestration mode without port"""
99+
# Set only GPU orchestration server URL
100+
monkeypatch.setenv("GPU_ORCHESTRATION_SERVER", "http://orchestrator.example.com")
101+
monkeypatch.delenv("GPU_ORCHESTRATION_SERVER_PORT", raising=False)
102+
103+
# The healthz endpoint reads env vars at request time, so monkeypatch should work
104+
response = client.get("/healthz")
105+
assert response.status_code == 200
106+
data = response.json()
107+
assert data["message"] == "OK"
108+
assert data["mode"] == "gpu_orchestration"
109+
assert data["gpu_orchestration_server"] == "http://orchestrator.example.com"
110+
assert data["gpu_orchestration_server_port"] == ""

test/api/test_train.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,3 @@ def test_train_create_template(client):
1313
data = {"name": "test_template", "description": "desc", "type": "test", "config": "{}"}
1414
resp = client.post("/train/template/create", data=data)
1515
assert resp.status_code in (200, 422, 400)
16-
17-
18-
def test_train_sweep_results(client):
19-
resp = client.get("/train/job/1/sweep_results")
20-
assert resp.status_code == 200
21-
data = resp.json()
22-
assert "status" in data
23-
assert data["status"] in ("success", "error")
24-
if data["status"] == "success":
25-
assert "data" in data
26-
else:
27-
assert "message" in data

transformerlab/plugins/mlx_exporter/index.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"description": "Exports the current model to MLX format so it can be run on Apple Silicon.",
55
"plugin-format": "python",
66
"type": "exporter",
7-
"version": "1.0.22",
7+
"version": "1.0.23",
88
"model_architectures": [
99
"CohereForCausalLM",
1010
"DeepseekV2ForCausalLM",
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
#!/usr/bin/env bash
2-
uv pip install mlx==0.27.1 --upgrade
3-
uv pip install "mlx-lm==0.26.3" --upgrade
2+
uv pip install mlx==0.29.3 --upgrade
3+
uv pip install "mlx-lm==0.28.3" --upgrade

transformerlab/plugins/mlx_lora_trainer/index.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"description": "MLX Machine learning research on your laptop or in a data center - by Apple",
55
"plugin-format": "python",
66
"type": "trainer",
7-
"version": "0.4.21",
7+
"version": "0.4.22",
88
"model_architectures": [
99
"MLX",
1010
"CohereForCausalLM",
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
#!/usr/bin/env bash
22
uv pip install trl
3-
uv pip install mlx==0.27.1 --upgrade
4-
uv pip install "mlx-lm==0.26.3" --upgrade
3+
uv pip install mlx==0.29.3 --upgrade
4+
uv pip install "mlx-lm==0.28.3" --upgrade

transformerlab/plugins/mlx_rlaif_trainer/index.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"description": "MLX PPO (Proximal Policy Optimization) Reinforcement Learning from AI Feedback (RLAIF) trainer for MLX models.",
55
"plugin-format": "python",
66
"type": "trainer",
7-
"version": "0.1.6",
7+
"version": "0.1.7",
88
"model_architectures": [
99
"MLX",
1010
"LlamaForCausalLM",
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env bash
2-
uv pip install mlx==0.27.1 --upgrade
3-
uv pip install "mlx-lm==0.26.3" --upgrade
2+
uv pip install mlx==0.29.3 --upgrade
3+
uv pip install "mlx-lm==0.28.3" --upgrade
44
uv pip install "mlx_embedding_models==0.0.11"
55
uv pip install --upgrade wandb

0 commit comments

Comments
 (0)