Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
29bec2f
Move jobs into experiments dir
deep1401 Mar 23, 2026
f469059
imports
deep1401 Mar 23, 2026
9cd8470
Merge branch 'main' into fix/jobs-list-fetcher
deep1401 Mar 23, 2026
7a2d476
fix pytest
deep1401 Mar 23, 2026
0335fff
profiling fixes
deep1401 Mar 23, 2026
3fd52a9
sdk
deep1401 Mar 23, 2026
97c0395
Merge branch 'main' into fix/jobs-list-fetcher
deep1401 Mar 23, 2026
83221fa
fixes for missed calls and have short and long uuids
deep1401 Mar 23, 2026
d6591d7
Merge branch 'fix/jobs-list-fetcher' of https://github.com/transforme…
deep1401 Mar 23, 2026
c48aa9b
delete async_run_python_script_and_update_status
deep1401 Mar 23, 2026
6d5f4c0
Merge branch 'main' of https://github.com/transformerlab/transformerl…
deep1401 Mar 23, 2026
6987d3a
ruff
deep1401 Mar 23, 2026
81b55db
fix more missed calls
deep1401 Mar 23, 2026
04c7068
fix more jobslist
deep1401 Mar 23, 2026
382494d
jobs migration old to new places
deep1401 Mar 23, 2026
f6c046c
fix sweep tests
deep1401 Mar 23, 2026
9a3f5af
Merge branch 'main' into fix/jobs-list-fetcher
deep1401 Mar 23, 2026
b69064a
Merge branch 'main' into fix/jobs-list-fetcher
deep1401 Mar 23, 2026
bcf3728
Merge branch 'main' into fix/jobs-list-fetcher
deep1401 Mar 24, 2026
649ede2
fix all job_get and job_get_cached calls
deep1401 Mar 24, 2026
68c48e6
sdk version
deep1401 Mar 24, 2026
9a9cb59
avoid double checks
deep1401 Mar 24, 2026
ba0ec97
Merge branch 'main' into fix/jobs-list-fetcher
deep1401 Mar 24, 2026
0844dd4
change parseint calls for job id to string
deep1401 Mar 24, 2026
1d5e4d4
Merge branch 'fix/jobs-list-fetcher' of https://github.com/transforme…
deep1401 Mar 24, 2026
49fe4b0
more cleanup
deep1401 Mar 24, 2026
6c908a2
fix job sorting
deep1401 Mar 24, 2026
b7393b4
remove disable env var
deep1401 Mar 24, 2026
1bd62df
Ruff
dadmobile Mar 24, 2026
88b2894
frontend conflict
deep1401 Mar 25, 2026
8ff3f2c
fix
deep1401 Mar 25, 2026
fa80831
Merge branch 'main' into fix/jobs-list-fetcher
deep1401 Mar 25, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,5 @@ test-results/

# Secrets
**/.secrets

docs/superpowers/*
9 changes: 9 additions & 0 deletions api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ async def lifespan(app: FastAPI):
# Seed default admin user
await seed_default_admin_user()

# One-time migration: legacy workspace/jobs -> workspace/experiments/<exp_id>/jobs
# Runs in the background so it doesn't delay the API startup.
from transformerlab.services.migrate_jobs_to_experiment_dirs import start_jobs_migration_worker

await start_jobs_migration_worker()

# Create buckets/folders for all existing teams if cloud or localfs storage is enabled
tfl_remote_storage_enabled = os.getenv("TFL_REMOTE_STORAGE_ENABLED", "false").lower() == "true"
if tfl_remote_storage_enabled or (os.getenv("TFL_STORAGE_PROVIDER") == "localfs" and os.getenv("TFL_STORAGE_URI")):
Expand Down Expand Up @@ -163,6 +169,9 @@ async def lifespan(app: FastAPI):
await stop_sweep_status_worker()
await stop_remote_job_status_worker()
await stop_notification_worker()
from transformerlab.services.migrate_jobs_to_experiment_dirs import stop_jobs_migration_worker

await stop_jobs_migration_worker()
await db.close()
# Run the clean up function
cleanup_at_exit()
Expand Down
2 changes: 1 addition & 1 deletion api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies = [
"soundfile==0.13.1",
"tensorboardX==2.6.2.2",
"timm==1.0.15",
"transformerlab==0.1.5",
"transformerlab==0.1.6",
"transformerlab-inference==0.2.52",
"transformers==4.57.1",
"wandb==0.23.1",
Expand Down
12 changes: 9 additions & 3 deletions api/test/api/test_compute_provider_sweep_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ async def fake_jobs_get_all(experiment_id, type="", status=""):

@pytest.mark.asyncio
async def test_check_sweep_status_contract(monkeypatch):
async def fake_job_get(job_id):
async def fake_job_get(job_id, experiment_id=None):
assert job_id == "123"
assert experiment_id == "exp-1"
return {
"id": "123",
"type": "SWEEP",
Expand All @@ -52,6 +53,7 @@ async def fake_job_get(job_id):

response = await compute_provider.check_sweep_status(
job_id="123",
experiment_id="exp-1",
user_and_team={"team_id": "team-1"},
session=None,
)
Expand All @@ -70,7 +72,8 @@ async def fake_job_get(job_id):

@pytest.mark.asyncio
async def test_check_sweep_status_non_sweep_raises(monkeypatch):
async def fake_job_get(_job_id):
async def fake_job_get(_job_id, experiment_id=None):
assert experiment_id == "exp-1"
return {
"id": "456",
"type": "REMOTE",
Expand All @@ -84,6 +87,7 @@ async def fake_job_get(_job_id):
with pytest.raises(HTTPException) as exc_info:
await compute_provider.check_sweep_status(
job_id="456",
experiment_id="exp-1",
user_and_team={"team_id": "team-1"},
session=None,
)
Expand All @@ -94,8 +98,9 @@ async def fake_job_get(_job_id):

@pytest.mark.asyncio
async def test_check_sweep_status_all_complete_true(monkeypatch):
async def fake_job_get(job_id):
async def fake_job_get(job_id, experiment_id=None):
assert job_id == "789"
assert experiment_id == "exp-1"
return {
"id": "789",
"type": "SWEEP",
Expand All @@ -116,6 +121,7 @@ async def fake_job_get(job_id):

response = await compute_provider.check_sweep_status(
job_id="789",
experiment_id="exp-1",
user_and_team={"team_id": "team-1"},
session=None,
)
Expand Down
2 changes: 1 addition & 1 deletion api/test/api/test_job_save_to_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def tmp_workspace(monkeypatch, tmp_path):
async def mock_get_workspace_dir():
return str(workspace)

async def mock_get_jobs_dir():
async def mock_get_jobs_dir(experiment_id: str):
return str(jobs_dir)

async def mock_get_datasets_dir():
Expand Down
6 changes: 1 addition & 5 deletions api/test/api/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,8 @@ def make_mock_adapter_info(overrides={}):
@pytest.mark.skip(reason="")
@pytest.mark.asyncio
@patch("transformerlab.routers.model.huggingfacemodel.get_model_details_from_huggingface", new_callable=AsyncMock)
@patch("transformerlab.routers.model.shared.async_run_python_script_and_update_status", new_callable=AsyncMock)
async def test_install_peft_mock(mock_run_script, mock_get_details, client):
async def test_install_peft_mock(mock_get_details, client):
mock_get_details.return_value = {"name": "dummy_adapter"}
mock_process = AsyncMock()
mock_process.returncode = 0
mock_run_script.return_value = mock_process

test_model_id = "unsloth_Llama-3.2-1B-Instruct"
test_peft_id = "dummy_adapter"
Expand Down
8 changes: 4 additions & 4 deletions api/test/api/test_remote_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_download_all_artifacts_endpoint():
"""
mock_job_service = Mock()

async def mock_get_all_artifact_paths(job_id, storage):
async def mock_get_all_artifact_paths(job_id, experiment_id, storage):
return ["path/to/artifact1.txt", "path/to/artifact2.png"]

mock_job_service.get_all_artifact_paths = mock_get_all_artifact_paths
Expand All @@ -44,7 +44,7 @@ async def mock_create_zip(file_paths, storage):
# Test 1: Successful download
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
response = loop.run_until_complete(download_all_artifacts("test_job_id"))
response = loop.run_until_complete(download_all_artifacts("test_job_id", "exp_1"))
loop.close()

assert response.status_code == 200
Expand All @@ -56,15 +56,15 @@ async def mock_create_zip(file_paths, storage):
assert len(create_zip_calls) == 1

# Test 2: No artifacts found
async def mock_get_all_artifact_paths_empty(job_id, storage):
async def mock_get_all_artifact_paths_empty(job_id, experiment_id, storage):
return []

mock_job_service.get_all_artifact_paths = mock_get_all_artifact_paths_empty
create_zip_calls.clear() # Reset call tracking

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
response_empty = loop.run_until_complete(download_all_artifacts("test_job_id_empty"))
response_empty = loop.run_until_complete(download_all_artifacts("test_job_id_empty", "exp_1"))
loop.close()

assert response_empty.status_code == 404
Expand Down
37 changes: 23 additions & 14 deletions api/transformerlab/routers/compute_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,7 +1256,10 @@ async def _launch_sweep_jobs(

# Create unique cluster name for this run
run_suffix = f"sweep-{i + 1}"
formatted_cluster_name = f"{_sanitize_cluster_basename(base_name)}-{run_suffix}-job-{parent_job_id}"
parent_job_short_id = job_service.get_short_job_id(parent_job_id)
formatted_cluster_name = (
f"{_sanitize_cluster_basename(base_name)}-{run_suffix}-job-{parent_job_short_id}"
)

# Create child job
child_job_id = await job_service.job_create(
Expand Down Expand Up @@ -1692,7 +1695,8 @@ async def launch_template_on_provider(
)

base_name = request.cluster_name or request.task_name or provider.name
formatted_cluster_name = f"{_sanitize_cluster_basename(base_name)}-job-{job_id}"
job_short_id = job_service.get_short_job_id(job_id)
formatted_cluster_name = f"{_sanitize_cluster_basename(base_name)}-job-{job_short_id}"

user_info = {}
if getattr(user, "first_name", None) or getattr(user, "last_name", None):
Expand Down Expand Up @@ -1850,7 +1854,7 @@ async def launch_template_on_provider(
if request.enable_trackio:
env_vars["TLAB_TRACKIO_AUTO_INIT"] = "true"
project_name = (request.trackio_project_name or "").strip() or str(request.experiment_id)
trackio_run_name = f"{request.task_name or 'task'}-job-{job_id}"
trackio_run_name = f"{request.task_name or 'task'}-job-{job_short_id}"
trackio_project_name_for_job = project_name
trackio_run_name_for_job = trackio_run_name
env_vars["TLAB_TRACKIO_PROJECT_NAME"] = project_name
Expand Down Expand Up @@ -2026,7 +2030,7 @@ async def launch_template_on_provider(
task_dir_root = await get_task_dir()
task_src = storage.join(task_dir_root, secure_filename(str(request.task_id)))
if await storage.isdir(task_src):
workspace_job_dir = await get_job_dir(job_id)
workspace_job_dir = await get_job_dir(job_id, request.experiment_id)
await _copy_task_files_to_dir(task_src, workspace_job_dir)

job_data = {
Expand Down Expand Up @@ -2210,6 +2214,7 @@ async def launch_template_on_provider(
@router.get("/jobs/{job_id}/check-status")
async def check_provider_job_status(
job_id: str,
experiment_id: str = Query(..., description="Experiment ID for this job"),
user_and_team=Depends(get_user_and_team),
):
"""
Expand All @@ -2221,7 +2226,7 @@ async def check_provider_job_status(
intentionally side-effect-free so that frequent frontend polling
never blocks on provider latency or downtime.
"""
job = await job_service.job_get(job_id)
job = await job_service.job_get(job_id, experiment_id=experiment_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")

Expand Down Expand Up @@ -2254,7 +2259,9 @@ async def ensure_quota_recorded_for_completed_jobs(
if job_id:
# Check specific job
# Pass team_id from user_and_team context
quota_recorded = await quota_service.ensure_quota_recorded_for_completed_job(session, job_id, team_id=team_id)
quota_recorded = await quota_service.ensure_quota_recorded_for_completed_job(
session, job_id, experiment_id=experiment_id, team_id=team_id
)
return {
"status": "success",
"job_id": job_id,
Expand Down Expand Up @@ -2284,7 +2291,7 @@ async def ensure_quota_recorded_for_completed_jobs(
if job_id_str:
# Pass team_id from user_and_team context
quota_recorded = await quota_service.ensure_quota_recorded_for_completed_job(
session, job_id_str, team_id=team_id
session, job_id_str, experiment_id=experiment_id, team_id=team_id
)
if quota_recorded:
jobs_recorded += 1
Expand Down Expand Up @@ -2323,14 +2330,15 @@ async def check_sweep_status_all(
@router.get("/jobs/{job_id}/sweep-status")
async def check_sweep_status(
job_id: str,
experiment_id: str = Query(..., description="Experiment ID for this sweep job"),
user_and_team=Depends(get_user_and_team),
session: AsyncSession = Depends(get_async_session),
):
"""
Check status of a specific sweep job from current persisted values.
Returns current sweep status with counts and job data.
"""
job = await job_service.job_get(job_id)
job = await job_service.job_get(job_id, experiment_id=experiment_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")

Expand Down Expand Up @@ -2359,6 +2367,7 @@ async def check_sweep_status(
@router.get("/jobs/{job_id}/sweep-results")
async def get_sweep_results(
job_id: str,
experiment_id: str = Query(..., description="Experiment ID for this sweep job"),
user_and_team=Depends(get_user_and_team),
session: AsyncSession = Depends(get_async_session),
):
Expand All @@ -2368,7 +2377,7 @@ async def get_sweep_results(
"""

# Get the parent sweep job
job = await job_service.job_get(job_id)
job = await job_service.job_get(job_id, experiment_id=experiment_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")

Expand All @@ -2379,7 +2388,6 @@ async def get_sweep_results(
if not job_data.get("sweep_parent"):
raise HTTPException(status_code=400, detail="Job is not a sweep parent")

experiment_id = job.get("experiment_id")
sweep_job_ids = job_data.get("sweep_job_ids", [])
sweep_metric = job_data.get("sweep_metric", "eval/loss")
lower_is_better = job_data.get("lower_is_better", True)
Expand All @@ -2392,7 +2400,7 @@ async def get_sweep_results(
best_job_id = None

for child_job_id in sweep_job_ids:
child_job = await job_service.job_get(child_job_id)
child_job = await job_service.job_get(child_job_id, experiment_id=experiment_id)
if not child_job:
continue

Expand Down Expand Up @@ -2486,7 +2494,7 @@ async def resume_from_checkpoint(
import time

# Get the original job
original_job = await job_service.job_get(job_id)
original_job = await job_service.job_get(job_id, experiment_id=experimentId)
if not original_job or str(original_job.get("experiment_id")) != str(experimentId):
raise HTTPException(status_code=404, detail="Job not found")

Expand All @@ -2512,7 +2520,7 @@ async def resume_from_checkpoint(
)

# Verify checkpoint exists using workspace-aware path resolution
checkpoints_dir = await get_job_checkpoints_dir(job_id)
checkpoints_dir = await get_job_checkpoints_dir(job_id, experimentId)
checkpoint_path = storage.join(checkpoints_dir, request.checkpoint)
if not await storage.exists(checkpoint_path):
raise HTTPException(status_code=404, detail=f"Checkpoint '{request.checkpoint}' not found")
Expand Down Expand Up @@ -2579,7 +2587,8 @@ async def resume_from_checkpoint(

# Build cluster name
base_name = job_data.get("task_name") or provider.name
formatted_cluster_name = f"{_sanitize_cluster_basename(base_name)}-job-{new_job_id}"
new_job_short_id = job_service.get_short_job_id(new_job_id)
formatted_cluster_name = f"{_sanitize_cluster_basename(base_name)}-job-{new_job_short_id}"

# Get user info
user = user_and_team.get("user")
Expand Down
Loading
Loading