diff --git a/.gitignore b/.gitignore index 442d30330..3906306ae 100644 --- a/.gitignore +++ b/.gitignore @@ -224,3 +224,5 @@ test-results/ # Secrets **/.secrets + +docs/superpowers/* diff --git a/api/api.py b/api/api.py index c90ac6167..1bf04741b 100644 --- a/api/api.py +++ b/api/api.py @@ -119,6 +119,12 @@ async def lifespan(app: FastAPI): # Seed default admin user await seed_default_admin_user() + # One-time migration: legacy workspace/jobs -> workspace/experiments//jobs + # Runs in the background so it doesn't delay the API startup. + from transformerlab.services.migrate_jobs_to_experiment_dirs import start_jobs_migration_worker + + await start_jobs_migration_worker() + # Create buckets/folders for all existing teams if cloud or localfs storage is enabled tfl_remote_storage_enabled = os.getenv("TFL_REMOTE_STORAGE_ENABLED", "false").lower() == "true" if tfl_remote_storage_enabled or (os.getenv("TFL_STORAGE_PROVIDER") == "localfs" and os.getenv("TFL_STORAGE_URI")): @@ -163,6 +169,9 @@ async def lifespan(app: FastAPI): await stop_sweep_status_worker() await stop_remote_job_status_worker() await stop_notification_worker() + from transformerlab.services.migrate_jobs_to_experiment_dirs import stop_jobs_migration_worker + + await stop_jobs_migration_worker() await db.close() # Run the clean up function cleanup_at_exit() diff --git a/api/pyproject.toml b/api/pyproject.toml index f8f60b599..42c28155b 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -33,7 +33,7 @@ dependencies = [ "soundfile==0.13.1", "tensorboardX==2.6.2.2", "timm==1.0.15", - "transformerlab==0.1.5", + "transformerlab==0.1.6", "transformerlab-inference==0.2.52", "transformers==4.57.1", "wandb==0.23.1", diff --git a/api/test/api/test_compute_provider_sweep_contract.py b/api/test/api/test_compute_provider_sweep_contract.py index 110c77ad6..b5b284e38 100644 --- a/api/test/api/test_compute_provider_sweep_contract.py +++ b/api/test/api/test_compute_provider_sweep_contract.py @@ -30,8 +30,9 @@ async def fake_jobs_get_all(experiment_id, type="", status=""): @pytest.mark.asyncio async def test_check_sweep_status_contract(monkeypatch): - async def fake_job_get(job_id): + async def fake_job_get(job_id, experiment_id=None): assert job_id == "123" + assert experiment_id == "exp-1" return { "id": "123", "type": "SWEEP", @@ -52,6 +53,7 @@ async def fake_job_get(job_id): response = await compute_provider.check_sweep_status( job_id="123", + experiment_id="exp-1", user_and_team={"team_id": "team-1"}, session=None, ) @@ -70,7 +72,8 @@ async def fake_job_get(job_id): @pytest.mark.asyncio async def test_check_sweep_status_non_sweep_raises(monkeypatch): - async def fake_job_get(_job_id): + async def fake_job_get(_job_id, experiment_id=None): + assert experiment_id == "exp-1" return { "id": "456", "type": "REMOTE", @@ -84,6 +87,7 @@ async def fake_job_get(_job_id): with pytest.raises(HTTPException) as exc_info: await compute_provider.check_sweep_status( job_id="456", + experiment_id="exp-1", user_and_team={"team_id": "team-1"}, session=None, ) @@ -94,8 +98,9 @@ async def fake_job_get(_job_id): @pytest.mark.asyncio async def test_check_sweep_status_all_complete_true(monkeypatch): - async def fake_job_get(job_id): + async def fake_job_get(job_id, experiment_id=None): assert job_id == "789" + assert experiment_id == "exp-1" return { "id": "789", "type": "SWEEP", @@ -116,6 +121,7 @@ async def fake_job_get(job_id): response = await compute_provider.check_sweep_status( job_id="789", + experiment_id="exp-1", user_and_team={"team_id": "team-1"}, session=None, ) diff --git a/api/test/api/test_job_save_to_registry.py b/api/test/api/test_job_save_to_registry.py index 196bf093c..0036d60bc 100644 --- a/api/test/api/test_job_save_to_registry.py +++ b/api/test/api/test_job_save_to_registry.py @@ -19,7 +19,7 @@ def tmp_workspace(monkeypatch, tmp_path): async def mock_get_workspace_dir(): return str(workspace) - async def mock_get_jobs_dir(): + async def mock_get_jobs_dir(experiment_id: str): return str(jobs_dir) async def mock_get_datasets_dir(): diff --git a/api/test/api/test_model.py b/api/test/api/test_model.py index 3f4e5d2d4..e3dd42cea 100644 --- a/api/test/api/test_model.py +++ b/api/test/api/test_model.py @@ -27,12 +27,8 @@ def make_mock_adapter_info(overrides={}): @pytest.mark.skip(reason="") @pytest.mark.asyncio @patch("transformerlab.routers.model.huggingfacemodel.get_model_details_from_huggingface", new_callable=AsyncMock) -@patch("transformerlab.routers.model.shared.async_run_python_script_and_update_status", new_callable=AsyncMock) -async def test_install_peft_mock(mock_run_script, mock_get_details, client): +async def test_install_peft_mock(mock_get_details, client): mock_get_details.return_value = {"name": "dummy_adapter"} - mock_process = AsyncMock() - mock_process.returncode = 0 - mock_run_script.return_value = mock_process test_model_id = "unsloth_Llama-3.2-1B-Instruct" test_peft_id = "dummy_adapter" diff --git a/api/test/api/test_remote_workspace.py b/api/test/api/test_remote_workspace.py index 28bb8f880..8f01b40d2 100644 --- a/api/test/api/test_remote_workspace.py +++ b/api/test/api/test_remote_workspace.py @@ -20,7 +20,7 @@ def test_download_all_artifacts_endpoint(): """ mock_job_service = Mock() - async def mock_get_all_artifact_paths(job_id, storage): + async def mock_get_all_artifact_paths(job_id, experiment_id, storage): return ["path/to/artifact1.txt", "path/to/artifact2.png"] mock_job_service.get_all_artifact_paths = mock_get_all_artifact_paths @@ -44,7 +44,7 @@ async def mock_create_zip(file_paths, storage): # Test 1: Successful download loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - response = loop.run_until_complete(download_all_artifacts("test_job_id")) + response = loop.run_until_complete(download_all_artifacts("test_job_id", "exp_1")) loop.close() assert response.status_code == 200 @@ -56,7 +56,7 @@ async def mock_create_zip(file_paths, storage): assert len(create_zip_calls) == 1 # Test 2: No artifacts found - async def mock_get_all_artifact_paths_empty(job_id, storage): + async def mock_get_all_artifact_paths_empty(job_id, experiment_id, storage): return [] mock_job_service.get_all_artifact_paths = mock_get_all_artifact_paths_empty @@ -64,7 +64,7 @@ async def mock_get_all_artifact_paths_empty(job_id, storage): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - response_empty = loop.run_until_complete(download_all_artifacts("test_job_id_empty")) + response_empty = loop.run_until_complete(download_all_artifacts("test_job_id_empty", "exp_1")) loop.close() assert response_empty.status_code == 404 diff --git a/api/transformerlab/routers/compute_provider.py b/api/transformerlab/routers/compute_provider.py index 9f4206349..b906d1de8 100644 --- a/api/transformerlab/routers/compute_provider.py +++ b/api/transformerlab/routers/compute_provider.py @@ -1256,7 +1256,10 @@ async def _launch_sweep_jobs( # Create unique cluster name for this run run_suffix = f"sweep-{i + 1}" - formatted_cluster_name = f"{_sanitize_cluster_basename(base_name)}-{run_suffix}-job-{parent_job_id}" + parent_job_short_id = job_service.get_short_job_id(parent_job_id) + formatted_cluster_name = ( + f"{_sanitize_cluster_basename(base_name)}-{run_suffix}-job-{parent_job_short_id}" + ) # Create child job child_job_id = await job_service.job_create( @@ -1692,7 +1695,8 @@ async def launch_template_on_provider( ) base_name = request.cluster_name or request.task_name or provider.name - formatted_cluster_name = f"{_sanitize_cluster_basename(base_name)}-job-{job_id}" + job_short_id = job_service.get_short_job_id(job_id) + formatted_cluster_name = f"{_sanitize_cluster_basename(base_name)}-job-{job_short_id}" user_info = {} if getattr(user, "first_name", None) or getattr(user, "last_name", None): @@ -1850,7 +1854,7 @@ async def launch_template_on_provider( if request.enable_trackio: env_vars["TLAB_TRACKIO_AUTO_INIT"] = "true" project_name = (request.trackio_project_name or "").strip() or str(request.experiment_id) - trackio_run_name = f"{request.task_name or 'task'}-job-{job_id}" + trackio_run_name = f"{request.task_name or 'task'}-job-{job_short_id}" trackio_project_name_for_job = project_name trackio_run_name_for_job = trackio_run_name env_vars["TLAB_TRACKIO_PROJECT_NAME"] = project_name @@ -2026,7 +2030,7 @@ async def launch_template_on_provider( task_dir_root = await get_task_dir() task_src = storage.join(task_dir_root, secure_filename(str(request.task_id))) if await storage.isdir(task_src): - workspace_job_dir = await get_job_dir(job_id) + workspace_job_dir = await get_job_dir(job_id, request.experiment_id) await _copy_task_files_to_dir(task_src, workspace_job_dir) job_data = { @@ -2210,6 +2214,7 @@ async def launch_template_on_provider( @router.get("/jobs/{job_id}/check-status") async def check_provider_job_status( job_id: str, + experiment_id: str = Query(..., description="Experiment ID for this job"), user_and_team=Depends(get_user_and_team), ): """ @@ -2221,7 +2226,7 @@ async def check_provider_job_status( intentionally side-effect-free so that frequent frontend polling never blocks on provider latency or downtime. """ - job = await job_service.job_get(job_id) + job = await job_service.job_get(job_id, experiment_id=experiment_id) if not job: raise HTTPException(status_code=404, detail="Job not found") @@ -2254,7 +2259,9 @@ async def ensure_quota_recorded_for_completed_jobs( if job_id: # Check specific job # Pass team_id from user_and_team context - quota_recorded = await quota_service.ensure_quota_recorded_for_completed_job(session, job_id, team_id=team_id) + quota_recorded = await quota_service.ensure_quota_recorded_for_completed_job( + session, job_id, experiment_id=experiment_id, team_id=team_id + ) return { "status": "success", "job_id": job_id, @@ -2284,7 +2291,7 @@ async def ensure_quota_recorded_for_completed_jobs( if job_id_str: # Pass team_id from user_and_team context quota_recorded = await quota_service.ensure_quota_recorded_for_completed_job( - session, job_id_str, team_id=team_id + session, job_id_str, experiment_id=experiment_id, team_id=team_id ) if quota_recorded: jobs_recorded += 1 @@ -2323,6 +2330,7 @@ async def check_sweep_status_all( @router.get("/jobs/{job_id}/sweep-status") async def check_sweep_status( job_id: str, + experiment_id: str = Query(..., description="Experiment ID for this sweep job"), user_and_team=Depends(get_user_and_team), session: AsyncSession = Depends(get_async_session), ): @@ -2330,7 +2338,7 @@ async def check_sweep_status( Check status of a specific sweep job from current persisted values. Returns current sweep status with counts and job data. """ - job = await job_service.job_get(job_id) + job = await job_service.job_get(job_id, experiment_id=experiment_id) if not job: raise HTTPException(status_code=404, detail="Job not found") @@ -2359,6 +2367,7 @@ async def check_sweep_status( @router.get("/jobs/{job_id}/sweep-results") async def get_sweep_results( job_id: str, + experiment_id: str = Query(..., description="Experiment ID for this sweep job"), user_and_team=Depends(get_user_and_team), session: AsyncSession = Depends(get_async_session), ): @@ -2368,7 +2377,7 @@ async def get_sweep_results( """ # Get the parent sweep job - job = await job_service.job_get(job_id) + job = await job_service.job_get(job_id, experiment_id=experiment_id) if not job: raise HTTPException(status_code=404, detail="Job not found") @@ -2379,7 +2388,6 @@ async def get_sweep_results( if not job_data.get("sweep_parent"): raise HTTPException(status_code=400, detail="Job is not a sweep parent") - experiment_id = job.get("experiment_id") sweep_job_ids = job_data.get("sweep_job_ids", []) sweep_metric = job_data.get("sweep_metric", "eval/loss") lower_is_better = job_data.get("lower_is_better", True) @@ -2392,7 +2400,7 @@ async def get_sweep_results( best_job_id = None for child_job_id in sweep_job_ids: - child_job = await job_service.job_get(child_job_id) + child_job = await job_service.job_get(child_job_id, experiment_id=experiment_id) if not child_job: continue @@ -2486,7 +2494,7 @@ async def resume_from_checkpoint( import time # Get the original job - original_job = await job_service.job_get(job_id) + original_job = await job_service.job_get(job_id, experiment_id=experimentId) if not original_job or str(original_job.get("experiment_id")) != str(experimentId): raise HTTPException(status_code=404, detail="Job not found") @@ -2512,7 +2520,7 @@ async def resume_from_checkpoint( ) # Verify checkpoint exists using workspace-aware path resolution - checkpoints_dir = await get_job_checkpoints_dir(job_id) + checkpoints_dir = await get_job_checkpoints_dir(job_id, experimentId) checkpoint_path = storage.join(checkpoints_dir, request.checkpoint) if not await storage.exists(checkpoint_path): raise HTTPException(status_code=404, detail=f"Checkpoint '{request.checkpoint}' not found") @@ -2579,7 +2587,8 @@ async def resume_from_checkpoint( # Build cluster name base_name = job_data.get("task_name") or provider.name - formatted_cluster_name = f"{_sanitize_cluster_basename(base_name)}-job-{new_job_id}" + new_job_short_id = job_service.get_short_job_id(new_job_id) + formatted_cluster_name = f"{_sanitize_cluster_basename(base_name)}-job-{new_job_short_id}" # Get user info user = user_and_team.get("user") diff --git a/api/transformerlab/routers/experiment/jobs.py b/api/transformerlab/routers/experiment/jobs.py index 90b02a97a..ce3d1e027 100644 --- a/api/transformerlab/routers/experiment/jobs.py +++ b/api/transformerlab/routers/experiment/jobs.py @@ -109,21 +109,21 @@ async def job_delete_all(experimentId: str): @router.get("/{job_id}") -async def get_training_job(job_id: str): - job = await job_service.job_get_cached(job_id) +async def get_training_job(job_id: str, experimentId: str): + job = await job_service.job_get_cached(job_id, experiment_id=experimentId) if job is None: return Response("Job not found", status_code=404) return job @router.get("/{job_id}/tasks_output") -async def get_tasks_job_output(job_id: str, sweeps: bool = False): +async def get_tasks_job_output(job_id: str, experimentId: str, sweeps: bool = False): """ Get Tasks job output with robust error handling. Uses the same logic as stream_job_output but returns content directly. """ try: - job = await job_service.job_get_cached(job_id) + job = await job_service.job_get_cached(job_id, experiment_id=experimentId) if job is None: return "Job not found" @@ -144,10 +144,10 @@ async def get_tasks_job_output(job_id: str, sweeps: bool = False): output_file_name = output_file else: # Fall back to regular output file logic - output_file_name = await shared.get_job_output_file_name(job_id) + output_file_name = await shared.get_job_output_file_name(job_id, experiment_name=experimentId) else: # Try to get output file name with fallback logic - output_file_name = await shared.get_job_output_file_name(job_id) + output_file_name = await shared.get_job_output_file_name(job_id, experiment_name=experimentId) # Read and return the file content as JSON array of lines if await storage.exists(output_file_name): @@ -164,7 +164,7 @@ async def get_tasks_job_output(job_id: str, sweeps: bool = False): print(f"Output file not found for job {job_id}, retrying in 4 seconds...") await asyncio.sleep(4) try: - output_file_name = await shared.get_job_output_file_name(job_id) + output_file_name = await shared.get_job_output_file_name(job_id, experiment_name=experimentId) if await storage.exists(output_file_name): async with await storage.open(output_file_name, "r") as f: content = await f.read() @@ -175,7 +175,11 @@ async def get_tasks_job_output(job_id: str, sweeps: bool = False): # If still no file after retry, create an empty one in the jobs directory print(f"Still no output file found for job {job_id} after retry, creating empty file: {retry_e}") # Use the Job class to get the proper directory and create the file - job_obj = Job(job_id) + job_dict = await job_service.job_get_cached(job_id, experiment_id=experimentId) + experiment_id = job_dict.get("experiment_id") if job_dict else None + if not experiment_id: + return [] + job_obj = Job(job_id, experiment_id) output_file_name = await job_obj.get_log_path() # Get directory by removing filename from path using storage.join output_dir = storage.join(*output_file_name.split("/")[:-1]) if "/" in output_file_name else "." @@ -215,7 +219,7 @@ async def get_provider_job_logs( 2. Otherwise, fall back to provider-native log retrieval (existing behavior). """ - job = await job_service.job_get_cached(job_id) + job = await job_service.job_get_cached(job_id, experiment_id=experimentId) if not job or str(job.get("experiment_id")) != str(experimentId): raise HTTPException(status_code=404, detail="Job not found") @@ -244,7 +248,7 @@ async def get_provider_job_logs( try: from lab.dirs import get_job_dir - job_dir = await get_job_dir(job_id) + job_dir = await get_job_dir(job_id, experimentId) provider_logs_path = storage.join(job_dir, "provider_logs.txt") if await storage.exists(provider_logs_path): async with await storage.open(provider_logs_path, "r", encoding="utf-8") as f: @@ -383,7 +387,7 @@ async def get_tunnel_info_for_job( choose the parser. Supports: 'vscode', 'jupyter', 'vllm', 'ollama', 'ssh'. """ - job = await job_service.job_get(job_id) + job = await job_service.job_get(job_id, experiment_id=experimentId) if not job or str(job.get("experiment_id")) != str(experimentId): raise HTTPException(status_code=404, detail="Job not found") @@ -566,13 +570,13 @@ async def get_tunnel_info_for_job( @router.get("/{job_id}/stream_output") -async def stream_job_output(job_id: str, sweeps: bool = False): +async def stream_job_output(job_id: str, experimentId: str, sweeps: bool = False): """ Stream job output with robust error handling and retry logic. Enhanced version combining the best of both train and jobs routers. """ try: - job = await job_service.job_get(job_id) + job = await job_service.job_get(job_id, experiment_id=experimentId) job_data = job.get("job_data", {}) @@ -591,10 +595,10 @@ async def stream_job_output(job_id: str, sweeps: bool = False): output_file_name = output_file else: # Fall back to regular output file logic - output_file_name = await shared.get_job_output_file_name(job_id) + output_file_name = await shared.get_job_output_file_name(job_id, experiment_name=experimentId) else: # Try to get output file name with fallback logic - output_file_name = await shared.get_job_output_file_name(job_id) + output_file_name = await shared.get_job_output_file_name(job_id, experiment_name=experimentId) except ValueError as e: # If the value error starts with "No output file found for job" then wait 4 seconds and try again @@ -603,12 +607,24 @@ async def stream_job_output(job_id: str, sweeps: bool = False): print(f"Output file not found for job {job_id}, retrying in 4 seconds...") await asyncio.sleep(4) try: - output_file_name = await shared.get_job_output_file_name(job_id) + output_file_name = await shared.get_job_output_file_name(job_id, experiment_name=experimentId) except Exception as retry_e: # If still no file after retry, create an empty one in the jobs directory print(f"Still no output file found for job {job_id} after retry, creating empty file: {retry_e}") # Use the Job class to get the proper directory and create the file - job_obj = Job(job_id) + job_dict = await job_service.job_get_cached(job_id, experiment_id=experimentId) + experiment_id = job_dict.get("experiment_id") if job_dict else None + if not experiment_id: + return StreamingResponse( + iter(["data: Error: An internal error has occurred!\n\n"]), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Access-Control-Allow-Origin": "*", + }, + ) + job_obj = Job(job_id, experiment_id) output_file_name = await job_obj.get_log_path() # Get directory by removing filename from path using storage.join output_dir = storage.join(*output_file_name.split("/")[:-1]) if "/" in output_file_name else "." @@ -670,8 +686,8 @@ async def stream_detailed_json_report(job_id: str, file_name: str): @router.get("/{job_id}/get_additional_details") -async def stream_job_additional_details(job_id: str, task: str = "view"): - job = await job_service.job_get(job_id) +async def stream_job_additional_details(job_id: str, experimentId: str, task: str = "view"): + job = await job_service.job_get(job_id, experiment_id=experimentId) if job is None: return Response("Job not found", status_code=404) job_data = job["job_data"] @@ -706,8 +722,8 @@ async def stream_job_additional_details(job_id: str, task: str = "view"): @router.get("/{job_id}/get_figure_json") -async def get_figure_path(job_id: str): - job = await job_service.job_get(job_id) +async def get_figure_path(job_id: str, experimentId: str): + job = await job_service.job_get(job_id, experiment_id=experimentId) if job is None: return Response("Job not found", status_code=404) job_data = job["job_data"] @@ -723,8 +739,8 @@ async def get_figure_path(job_id: str): @router.get("/{job_id}/get_generated_dataset") -async def get_generated_dataset(job_id: str): - job = await job_service.job_get(job_id) +async def get_generated_dataset(job_id: str, experimentId: str): + job = await job_service.job_get(job_id, experiment_id=experimentId) if job is None: return Response("Job not found", status_code=404) # Get experiment name @@ -751,9 +767,9 @@ async def get_generated_dataset(job_id: str): @router.get("/{job_id}/get_eval_results") -async def get_eval_results(job_id: str, task: str = "view", file_index: int = 0): +async def get_eval_results(job_id: str, experimentId: str, task: str = "view", file_index: int = 0): """Get evaluation results for a job""" - job = await job_service.job_get(job_id) + job = await job_service.job_get(job_id, experiment_id=experimentId) if job is None: return Response("Job not found", status_code=404) job_data = job["job_data"] @@ -834,9 +850,9 @@ async def generate(): @router.get("/{job_id}/get_eval_images") -async def get_eval_images(job_id: str): +async def get_eval_images(job_id: str, experimentId: str): """Get list of evaluation images for a job""" - job = await job_service.job_get(job_id) + job = await job_service.job_get(job_id, experiment_id=experimentId) if job is None: return Response("Job not found", status_code=404) job_data = job["job_data"] @@ -921,9 +937,9 @@ async def get_eval_images(job_id: str): @router.get("/{job_id}/image/{filename}") -async def get_eval_image(job_id: str, filename: str): +async def get_eval_image(job_id: str, filename: str, experimentId: str): """Serve individual evaluation image files""" - job = await job_service.job_get(job_id) + job = await job_service.job_get(job_id, experiment_id=experimentId) if job is None: return Response("Job not found", status_code=404) job_data = job["job_data"] @@ -977,38 +993,40 @@ async def get_eval_image(job_id: str, filename: str): @router.get("/{job_id}/checkpoints") -async def get_checkpoints(job_id: str, request: Request): +async def get_checkpoints(job_id: str, experimentId: str, request: Request): if job_id is None or job_id == "" or job_id == "-1": return {"checkpoints": []} """Get list of checkpoints for a job""" - job = await job_service.job_get(job_id) + job = await job_service.job_get(job_id, experiment_id=experimentId) if job is None: return {"checkpoints": []} + experiment_id = job.get("experiment_id") job_data = job["job_data"] # First try to use the new SDK method to get checkpoints try: from lab.job import Job # Get checkpoints using the SDK method - sdk_job = Job(job_id) - checkpoint_paths = await sdk_job.get_checkpoint_paths() - - if checkpoint_paths and len(checkpoint_paths) > 0: - checkpoints = [] - for checkpoint_path in checkpoint_paths: - try: - # Get filename from path - filename = checkpoint_path.split("/")[-1] if "/" in checkpoint_path else checkpoint_path - checkpoints.append({"filename": filename}) - except Exception as e: - print(f"Error processing checkpoint {checkpoint_path}: {e}") - continue - - # Sort checkpoints by filename in reverse (descending) order for consistent ordering - checkpoints.sort(key=lambda x: x["filename"], reverse=True) - return {"checkpoints": checkpoints} + if experiment_id: + sdk_job = Job(job_id, experiment_id) + checkpoint_paths = await sdk_job.get_checkpoint_paths() + + if checkpoint_paths and len(checkpoint_paths) > 0: + checkpoints = [] + for checkpoint_path in checkpoint_paths: + try: + # Get filename from path + filename = checkpoint_path.split("/")[-1] if "/" in checkpoint_path else checkpoint_path + checkpoints.append({"filename": filename}) + except Exception as e: + print(f"Error processing checkpoint {checkpoint_path}: {e}") + continue + + # Sort checkpoints by filename in reverse (descending) order for consistent ordering + checkpoints.sort(key=lambda x: x["filename"], reverse=True) + return {"checkpoints": checkpoints} except Exception as e: print(f"SDK checkpoint method failed for job {job_id}, falling back to legacy method: {e}") @@ -1037,7 +1055,9 @@ async def get_checkpoints(job_id: str, request: Request): if not checkpoints_dir: from lab.dirs import get_job_checkpoints_dir - checkpoints_dir = await get_job_checkpoints_dir(job_id) + if not experiment_id: + return {"checkpoints": []} + checkpoints_dir = await get_job_checkpoints_dir(job_id, experiment_id) if not checkpoints_dir or not await storage.exists(checkpoints_dir): return {"checkpoints": []} elif await storage.isdir(checkpoints_dir): @@ -1087,7 +1107,7 @@ async def get_checkpoints(job_id: str, request: Request): @router.get("/{job_id}/artifacts") -async def get_artifacts(job_id: str, request: Request): +async def get_artifacts(job_id: str, experimentId: str, request: Request): """Get list of artifacts for a job""" # Validate job_id @@ -1100,7 +1120,12 @@ async def get_artifacts(job_id: str, request: Request): try: from lab.dirs import get_job_artifacts_dir - artifacts_dir = await get_job_artifacts_dir(job_id) + job_dict = await job_service.job_get_cached(job_id, experiment_id=experimentId) + experiment_id = job_dict.get("experiment_id") if job_dict else None + if not experiment_id: + return {"artifacts": []} + + artifacts_dir = await get_job_artifacts_dir(job_id, experiment_id) artifacts = await get_artifacts_from_directory(artifacts_dir) except Exception as e: print(f"Error getting artifacts for job {job_id}: {e}") @@ -1113,12 +1138,12 @@ async def get_artifacts(job_id: str, request: Request): @router.get("/{job_id}/artifacts/download_all") -async def download_all_artifacts(job_id: str): +async def download_all_artifacts(job_id: str, experimentId: str): """ Download a zip file containing all artifacts for a job. """ # 1. Gather all artifact file paths using service - all_file_paths = await job_service.get_all_artifact_paths(job_id, storage) + all_file_paths = await job_service.get_all_artifact_paths(job_id, experimentId, storage) if not all_file_paths: return Response("No artifacts found for this job", status_code=404) @@ -1141,7 +1166,7 @@ async def download_all_artifacts(job_id: str): @router.get("/{job_id}/artifact/{filename}") -async def get_artifact(job_id: str, filename: str, task: str = "view"): +async def get_artifact(job_id: str, experimentId: str, filename: str, task: str = "view"): """ Serve individual artifact files for viewing or downloading. @@ -1150,11 +1175,12 @@ async def get_artifact(job_id: str, filename: str, task: str = "view"): filename: The artifact filename task: Either "view" or "download" (default: "view") """ - job = await job_service.job_get(job_id) + job = await job_service.job_get(job_id, experiment_id=experimentId) if job is None: return Response("Job not found", status_code=404) job_data = job["job_data"] + experiment_id = job.get("experiment_id") # First try to use the new SDK method to get artifact paths artifact_file_path = None @@ -1162,18 +1188,19 @@ async def get_artifact(job_id: str, filename: str, task: str = "view"): from lab.job import Job # Get artifacts using the SDK method - sdk_job = Job(job_id) - artifact_paths = await sdk_job.get_artifact_paths() - - if artifact_paths: - # Look for the file in the artifact paths - filename_secure = secure_filename(filename) - for artifact_path in artifact_paths: - # Check if this path matches the filename - path_filename = artifact_path.split("/")[-1] if "/" in artifact_path else artifact_path - if path_filename == filename_secure: - artifact_file_path = artifact_path - break + if experiment_id: + sdk_job = Job(job_id, experiment_id) + artifact_paths = await sdk_job.get_artifact_paths() + + if artifact_paths: + # Look for the file in the artifact paths + filename_secure = secure_filename(filename) + for artifact_path in artifact_paths: + # Check if this path matches the filename + path_filename = artifact_path.split("/")[-1] if "/" in artifact_path else artifact_path + if path_filename == filename_secure: + artifact_file_path = artifact_path + break except Exception as e: print(f"Error using SDK method to get artifact paths: {e}") @@ -1272,14 +1299,14 @@ async def generate(): @router.get("/{job_id}") -async def get_training_job_by_path(job_id: str): - return await job_service.job_get(job_id) +async def get_training_job_by_path(job_id: str, experimentId: str): + return await job_service.job_get(job_id, experiment_id=experimentId) @router.get("/{job_id}/output") -async def get_training_job_output_jobpath(job_id: str, sweeps: bool = False): +async def get_training_job_output_jobpath(job_id: str, experimentId: str, sweeps: bool = False): try: - job = await job_service.job_get(job_id) + job = await job_service.job_get(job_id, experiment_id=experimentId) if job is None: return "Job not found" @@ -1325,9 +1352,9 @@ async def get_training_job_output_jobpath(job_id: str, sweeps: bool = False): @router.get("/{job_id}/sweep_results") -async def sweep_results(job_id: str): +async def sweep_results(job_id: str, experimentId: str): try: - job = await job_service.job_get(job_id) + job = await job_service.job_get(job_id, experiment_id=experimentId) if job is None: return {"status": "error", "message": "Job not found."} @@ -1361,7 +1388,7 @@ async def sweep_results(job_id: str): @router.get("/{job_id}/datasets") -async def get_job_datasets(job_id: str, request: Request): +async def get_job_datasets(job_id: str, experimentId: str, request: Request): """Get list of datasets in the job's datasets directory""" if not job_id or job_id in ("", "-1"): @@ -1370,7 +1397,7 @@ async def get_job_datasets(job_id: str, request: Request): try: from lab.dirs import get_job_datasets_dir - datasets_dir = await get_job_datasets_dir(job_id) + datasets_dir = await get_job_datasets_dir(job_id, experimentId) datasets = await job_service.get_datasets_from_directory(datasets_dir) except Exception as e: print(f"Error getting datasets for job {job_id}: {e}") @@ -1382,7 +1409,7 @@ async def get_job_datasets(job_id: str, request: Request): @router.get("/{job_id}/models") -async def get_job_models(job_id: str, request: Request): +async def get_job_models(job_id: str, experimentId: str, request: Request): """Get list of models in the job's models directory""" if not job_id or job_id in ("", "-1"): @@ -1391,7 +1418,7 @@ async def get_job_models(job_id: str, request: Request): try: from lab.dirs import get_job_models_dir - models_dir = await get_job_models_dir(job_id) + models_dir = await get_job_models_dir(job_id, experimentId) models = await job_service.get_models_from_directory(models_dir) except Exception as e: print(f"Error getting models for job {job_id}: {e}") @@ -1423,7 +1450,8 @@ async def save_dataset_to_registry( try: dataset_name_secure = secure_filename(dataset_name) - job_datasets_dir = await get_job_datasets_dir(job_id) + # Get source path (job's datasets directory) + job_datasets_dir = await get_job_datasets_dir(job_id, experimentId) source_path = storage.join(job_datasets_dir, dataset_name_secure) if not await storage.exists(source_path): @@ -1556,7 +1584,8 @@ async def save_model_to_registry( try: model_name_secure = secure_filename(model_name) - job_models_dir = await get_job_models_dir(job_id) + # Get source path (job's models directory) + job_models_dir = await get_job_models_dir(job_id, experimentId) source_path = storage.join(job_models_dir, model_name_secure) if not await storage.exists(source_path): @@ -1671,11 +1700,11 @@ async def _save_model_to_registry( @router.get("/{job_id}/files") -async def list_job_files(job_id: str, subpath: str = ""): +async def list_job_files(job_id: str, experimentId: str, subpath: str = ""): """List files and directories in a job's directory.""" from lab.dirs import get_job_dir - job_dir = await get_job_dir(job_id) + job_dir = await get_job_dir(job_id, experimentId) if not await storage.exists(job_dir): return {"files": [], "path": subpath} @@ -1715,11 +1744,11 @@ async def list_job_files(job_id: str, subpath: str = ""): @router.get("/{job_id}/file/{file_path:path}") -async def get_job_file(job_id: str, file_path: str): +async def get_job_file(job_id: str, file_path: str, experimentId: str): """Serve a file from a job's directory.""" from lab.dirs import get_job_dir - job_dir = await get_job_dir(job_id) + job_dir = await get_job_dir(job_id, experimentId) target = storage.join(job_dir, file_path) if not await storage.exists(target) or not await storage.isfile(target): @@ -1817,7 +1846,7 @@ async def get_profiling_report( """ from lab.dirs import get_job_profiling_dir - profiling_dir = await get_job_profiling_dir(job_id) + profiling_dir = await get_job_profiling_dir(job_id, experimentId) report_path = storage.join(profiling_dir, "profiling_report.json") if not await storage.exists(report_path): diff --git a/api/transformerlab/services/job_service.py b/api/transformerlab/services/job_service.py index a3d640b49..6ad301f4c 100644 --- a/api/transformerlab/services/job_service.py +++ b/api/transformerlab/services/job_service.py @@ -26,6 +26,54 @@ "SWEEP", ] +SHORT_JOB_ID_LEN = 8 + + +def get_short_job_id(job_id: str | int, length: int = SHORT_JOB_ID_LEN) -> str: + return str(job_id)[:length] + + +async def _resolve_full_job_id(job_id: str, experiment_id: str) -> Optional[str]: + """ + Resolve an incoming job identifier within an experiment. + + - If exact job dir exists, return as-is. + - Otherwise treat it as a prefix and return the unique match. + - If none or ambiguous, return None. + """ + from lab.dirs import get_jobs_dir + + jobs_dir = await get_jobs_dir(experiment_id) + exact_path = storage.join(jobs_dir, str(job_id)) + if await storage.exists(exact_path): + return str(job_id) + + try: + entries = await storage.ls(jobs_dir, detail=False) + except Exception: + return None + + matches: list[str] = [] + prefix = str(job_id) + for entry in entries: + entry_path = entry if isinstance(entry, str) else str(entry) + entry_id = entry_path.rstrip("/").split("/")[-1] + if entry_id.startswith(prefix): + matches.append(entry_id) + + if len(matches) == 1: + return matches[0] + return None + + +def _add_short_id(job_dict: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + if not job_dict: + return job_dict + job_id = job_dict.get("id") + if job_id is not None: + job_dict["short_id"] = get_short_job_id(str(job_id)) + return job_dict + async def job_create(type, status, experiment_id, job_data="{}"): # check if type is allowed @@ -58,7 +106,8 @@ async def job_create(type, status, experiment_id, job_data="{}"): async def jobs_get_all(experiment_id, type="", status=""): exp_obj = Experiment(experiment_id) - return await exp_obj.get_jobs(type, status) + jobs = await exp_obj.get_jobs(type, status) + return [_add_short_id(job) for job in jobs] async def jobs_get_all_by_experiment_and_type(experiment_id, job_type): @@ -91,23 +140,28 @@ def _job_cache_key(job_id: str) -> str: return f"jobs:{job_id}" -async def _job_get_live(job_id: str) -> Optional[Dict[str, Any]]: +async def _job_get_live(job_id: str, experiment_id: Optional[str]) -> Optional[Dict[str, Any]]: + if not experiment_id: + raise ValueError(f"experiment_id is required for job lookup (job_id={job_id})") try: - job = await Job.get(job_id) - return await job.get_json_data(uncached=True) + resolved_job_id = await _resolve_full_job_id(str(job_id), str(experiment_id)) + if not resolved_job_id: + return None + job = await Job.get(resolved_job_id, experiment_id) + return _add_short_id(await job.get_json_data(uncached=True)) except Exception as e: print("Error getting job data", e) return None -async def job_get(job_id: str) -> Optional[Dict[str, Any]]: +async def job_get(job_id: str, experiment_id: Optional[str] = None) -> Optional[Dict[str, Any]]: """ Backward-compatible live getter that bypasses the per-node cache. """ - return await _job_get_live(job_id) + return await _job_get_live(job_id, experiment_id) -async def job_get_cached(job_id: str) -> Optional[Dict[str, Any]]: +async def job_get_cached(job_id: str, experiment_id: Optional[str] = None) -> Optional[Dict[str, Any]]: """ Per-node cached getter for job JSON backed by cashews. @@ -116,7 +170,14 @@ async def job_get_cached(job_id: str) -> Optional[Dict[str, Any]]: 2. Fallback to live Job.get(...).get_json_data(uncached=True). 3. If live status is terminal, persist to cache for future reads. """ - key = _job_cache_key(job_id) + if not experiment_id: + raise ValueError(f"experiment_id is required for job lookup (job_id={job_id})") + + resolved_job_id = await _resolve_full_job_id(str(job_id), str(experiment_id)) + if not resolved_job_id: + return None + + key = _job_cache_key(resolved_job_id) # 1) Try cache first cached = await cache.get(key) @@ -124,7 +185,7 @@ async def job_get_cached(job_id: str) -> Optional[Dict[str, Any]]: return cached # 2) Fallback to live - job_dict = await _job_get_live(job_id) + job_dict = await _job_get_live(resolved_job_id, experiment_id) if not job_dict: return None @@ -132,7 +193,7 @@ async def job_get_cached(job_id: str) -> Optional[Dict[str, Any]]: status = job_dict.get("status") if is_terminal_state(status): try: - await cache.set(key, job_dict, ttl="7d", tags=["jobs", f"job:{job_id}"]) + await cache.set(key, job_dict, ttl="7d", tags=["jobs", f"job:{resolved_job_id}"]) except Exception: # Best-effort – ignore cache errors. pass @@ -152,10 +213,7 @@ async def job_delete_all(experiment_id): async def job_delete(job_id, experiment_id): try: - job = await Job.get(job_id) - exp_id = await job.get_experiment_id() - if experiment_id is not None and exp_id != experiment_id: - return + job = await Job.get(job_id, experiment_id) await job.delete() except Exception as e: print(f"Error deleting job {job_id}: {e}") @@ -163,10 +221,7 @@ async def job_delete(job_id, experiment_id): async def job_update_job_data_insert_key_value(job_id, key, value, experiment_id): try: - job = await Job.get(job_id) - exp_id = await job.get_experiment_id() - if experiment_id is not None and exp_id != experiment_id: - return + job = await Job.get(job_id, experiment_id) await job.update_job_data_field(key, value) except Exception as e: print(f"Error updating job {job_id}: {e}") @@ -180,10 +235,7 @@ async def job_update_job_data_insert_key_values(job_id, updates: Dict[str, Any], if not isinstance(updates, dict): raise TypeError("updates must be a dict") - job = await Job.get(job_id) - exp_id = await job.get_experiment_id() - if experiment_id is not None and exp_id != experiment_id: - return + job = await Job.get(job_id, experiment_id) await job.update_job_data_field(updates, multiple=True) except Exception as e: print(f"Error updating job {job_id}: {e}") @@ -202,10 +254,7 @@ async def job_update_progress(job_id, progress, experiment_id): progress: int representing percent complete """ try: - job = await Job.get(job_id) - exp_id = await job.get_experiment_id() - if experiment_id is not None and exp_id != experiment_id: - return + job = await Job.get(job_id, experiment_id) await job.update_progress(progress) except Exception as e: print(f"Error updating job {job_id}: {e}") @@ -216,10 +265,7 @@ async def job_update_sweep_progress(job_id, value, experiment_id): Update the 'sweep_progress' key in the job_data JSON column for a given job. """ try: - job = await Job.get(job_id) - exp_id = await job.get_experiment_id() - if experiment_id is not None and exp_id != experiment_id: - return + job = await Job.get(job_id, experiment_id) await job.update_sweep_progress(value) except Exception as e: print(f"Error updating sweep job {job_id}: {e}") @@ -254,8 +300,11 @@ async def job_update_launch_progress( Appends a step to the timeline and sets phase, percent, and message. """ try: - existing = await job_get(job_id) - if not existing or (experiment_id is not None and existing.get("experiment_id") != experiment_id): + if not experiment_id: + return + + existing = await job_get(job_id, experiment_id) + if not existing or existing.get("experiment_id") != experiment_id: return job_data = existing.get("job_data") or {} existing_progress = job_data.get("launch_progress") or {} @@ -273,11 +322,10 @@ async def jobs_get_sweep_children(parent_job_id, experiment_id=None): Get all child jobs that belong to a sweep parent job. """ try: - parent_job = await Job.get(parent_job_id) - if experiment_id is not None: - exp_id = await parent_job.get_experiment_id() - if exp_id != experiment_id: - return [] + if experiment_id is None: + return [] + + parent_job = await Job.get(parent_job_id, experiment_id) job_data = await parent_job.get_job_data() if not isinstance(job_data, dict): @@ -291,7 +339,7 @@ async def jobs_get_sweep_children(parent_job_id, experiment_id=None): child_jobs = [] for child_job_id in sweep_job_ids: try: - child_job = await Job.get(child_job_id) + child_job = await Job.get(child_job_id, experiment_id) # Get full job data (including type, status, etc.) job_json = await child_job.get_json_data() child_jobs.append(job_json) @@ -311,11 +359,10 @@ async def job_get_sweep_parent(child_job_id, experiment_id=None): Returns None if the job is not a sweep child. """ try: - child_job = await Job.get(child_job_id) - if experiment_id is not None: - exp_id = await child_job.get_experiment_id() - if exp_id != experiment_id: - return None + if experiment_id is None: + return None + + child_job = await Job.get(child_job_id, experiment_id) job_data = await child_job.get_job_data() if not isinstance(job_data, dict): @@ -325,7 +372,7 @@ async def job_get_sweep_parent(child_job_id, experiment_id=None): if not parent_job_id: return None - parent_job = await Job.get(parent_job_id) + parent_job = await Job.get(parent_job_id, experiment_id) return await parent_job.get_json_data() except Exception as e: print(f"Error getting sweep parent for job {child_job_id}: {e}") @@ -484,12 +531,13 @@ async def job_update_status( except ValueError: raise ValueError(f"Invalid job status: {status!r}. Must be one of: {[s.value for s in JobStatus]}") + if experiment_id is None: + return + + job = None # Get old status before updating for queue management try: - job = await Job.get(job_id) - exp_id = await job.get_experiment_id() - if experiment_id is not None and exp_id != experiment_id: - return + job = await Job.get(job_id, experiment_id) await job.update_status(status) if error_msg: @@ -521,7 +569,7 @@ async def job_update_status( # Populate per-node cache for terminal jobs (best-effort). if is_terminal_state(status): try: - live_dict = await _job_get_live(job_id) + live_dict = await _job_get_live(job_id, experiment_id) if live_dict: key = _job_cache_key(job_id) await cache.set(key, live_dict, ttl="7d", tags=["jobs", f"job:{job_id}"]) @@ -540,10 +588,9 @@ async def job_update(job_id: str, type: str, status: str, experiment_id: Optiona experiment_id: The experiment ID (required for most operations, optional for backward compatibility) """ try: - job = await Job.get(job_id) - exp_id = await job.get_experiment_id() - if experiment_id is not None and exp_id != experiment_id: + if experiment_id is None: return + job = await Job.get(job_id, experiment_id) await job.set_type(type) await job.update_status(status) @@ -563,10 +610,9 @@ def job_update_type_and_status_sync(job_id: str, job_type: str, status: str, exp experiment_id: The experiment ID (required for most operations, optional for backward compatibility) """ try: - job = asyncio.run(Job.get(job_id)) - exp_id = asyncio.run(job.get_experiment_id()) - if experiment_id is not None and exp_id != experiment_id: + if experiment_id is None: return + job = asyncio.run(Job.get(job_id, experiment_id)) asyncio.run(job.set_type(job_type)) asyncio.run(job.update_status(status)) except Exception as e: @@ -588,7 +634,7 @@ async def format_artifact(file_path: str) -> Optional[Dict[str, any]]: return None -async def get_artifacts_from_sdk(job_id: str) -> Optional[List[Dict]]: +async def get_artifacts_from_sdk(job_id: str, experiment_id: str) -> Optional[List[Dict]]: """ Get artifacts using the SDK method. Returns list of artifacts or None if SDK method fails. @@ -596,7 +642,10 @@ async def get_artifacts_from_sdk(job_id: str) -> Optional[List[Dict]]: try: from lab.job import Job - sdk_job = Job(job_id) + if not experiment_id: + return None + + sdk_job = Job(job_id, experiment_id) artifact_paths = sdk_job.get_artifact_paths() if not artifact_paths: @@ -654,27 +703,30 @@ async def get_artifacts_from_directory(artifacts_dir: str) -> List[Dict]: return artifacts -async def get_all_artifact_paths(job_id: str) -> List[str]: +async def get_all_artifact_paths(job_id: str, experiment_id: str, _storage: Any = None) -> List[str]: """ Get all artifact file paths for a job. Uses get_artifacts_from_sdk and get_artifacts_from_directory to retrieve paths. """ # 1. Try SDK method - sdk_artifacts = await get_artifacts_from_sdk(job_id) + sdk_artifacts = await get_artifacts_from_sdk(job_id, experiment_id) if sdk_artifacts: return [a.get("full_path") for a in sdk_artifacts if a.get("full_path")] # 2. Fallback to artifacts directory - job = await job_get(job_id) + job = await job_get(job_id, experiment_id=experiment_id) if job: job_data = job.get("job_data", {}) + job_experiment_id = job.get("experiment_id") artifacts_dir = job_data.get("artifacts_dir") if not artifacts_dir: try: from lab.dirs import get_job_artifacts_dir - artifacts_dir = await get_job_artifacts_dir(job_id) + if not job_experiment_id: + return [] + artifacts_dir = await get_job_artifacts_dir(job_id, job_experiment_id) except Exception: pass diff --git a/api/transformerlab/services/migrate_jobs_to_experiment_dirs.py b/api/transformerlab/services/migrate_jobs_to_experiment_dirs.py new file mode 100644 index 000000000..61380c716 --- /dev/null +++ b/api/transformerlab/services/migrate_jobs_to_experiment_dirs.py @@ -0,0 +1,278 @@ +""" +One-time jobs directory layout migration, executed automatically at API startup. + +Goal: + Move legacy job directories from: + {workspace}/jobs/{job_id}/ + into: + {workspace}/experiments/{exp_id}/jobs/{job_id}/ + +Then delete experiment-scoped `jobs.json` indexes so the filesystem becomes the index. + +This runs per org (team) with org-scoped storage roots. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from typing import Any, Optional + +from lab.dirs import get_workspace_dir, set_organization_id as lab_set_org_id +from lab import storage + +from transformerlab.services import team_service + +logger = logging.getLogger(__name__) + +_jobs_migration_worker_task: Optional[asyncio.Task] = None + + +def _set_org_context(org_id: Optional[str]) -> None: + if lab_set_org_id is not None: + lab_set_org_id(org_id) + + +def _clear_org_context() -> None: + _set_org_context(None) + + +def _basename(path: str) -> str: + return str(path).rstrip("/").split("/")[-1] + + +def _get_experiment_id_from_index(index_data: dict[str, Any]) -> Optional[str]: + # We try multiple key names for backward compatibility. + for key in ( + "experiment_id", + "experimentId", + "experiment_name", + "experimentName", + "experiment", + "exp_id", + "expId", + ): + value = index_data.get(key) + if value is None: + continue + value_str = str(value).strip() + if value_str: + return value_str + return None + + +async def _read_json(path: str) -> dict[str, Any]: + async with await storage.open(path, "r", encoding="utf-8") as f: + return json.loads(await f.read()) + + +async def _iter_dir_names(parent_dir: str) -> list[str]: + """ + List immediate child directory names under `parent_dir`. + """ + try: + entries = await storage.ls(parent_dir, detail=True) + except Exception: + return [] + + dir_names: list[str] = [] + for entry in entries: + full_path = "" + entry_type: Optional[str] = None + + if isinstance(entry, dict): + full_path = entry.get("name") or entry.get("path") or "" + entry_type = entry.get("type") + else: + full_path = str(entry) + + if not full_path: + continue + + try: + is_dir = entry_type == "directory" or await storage.isdir(full_path) + except Exception: + is_dir = entry_type == "directory" + + if is_dir: + dir_names.append(_basename(full_path)) + + return dir_names + + +async def _org_needs_jobs_migration(org_id: str) -> bool: + """ + Determine if migration is needed for this org by checking for + `{workspace}/experiments/*/jobs.json`. + """ + _set_org_context(org_id) + try: + workspace_dir = await get_workspace_dir() + experiments_dir = storage.join(workspace_dir, "experiments") + if not await storage.exists(experiments_dir): + return False + + exp_ids = await _iter_dir_names(experiments_dir) + for exp_id in exp_ids: + jobs_json_path = storage.join(experiments_dir, exp_id, "jobs.json") + if await storage.exists(jobs_json_path): + return True + except Exception as exc: # noqa: BLE001 + logger.warning(f"Jobs migration: org {org_id} pre-check failed: {exc}") + return False + finally: + _clear_org_context() + + return False + + +async def _migrate_org_jobs(org_id: str) -> dict[str, Any]: + """ + Perform migration for a single org. + """ + _set_org_context(org_id) + moved_jobs = 0 + skipped_jobs = 0 + removed_jobs_json = 0 + + try: + workspace_dir = await get_workspace_dir() + old_jobs_dir = storage.join(workspace_dir, "jobs") + experiments_dir = storage.join(workspace_dir, "experiments") + + if not await storage.exists(old_jobs_dir): + return { + "org_id": org_id, + "moved_jobs": 0, + "skipped_jobs": 0, + "removed_jobs_json": 0, + "status": "no_old_jobs_dir", + } + + if not await storage.exists(experiments_dir): + await storage.makedirs(experiments_dir, exist_ok=True) + + job_ids = await _iter_dir_names(old_jobs_dir) + logger.info(f"Jobs migration: org {org_id}: {len(job_ids)} candidate job dir(s)") + + for job_id in sorted(job_ids): + job_dir = storage.join(old_jobs_dir, job_id) + index_path = storage.join(job_dir, "index.json") + + if not await storage.exists(index_path): + logger.warning(f"Jobs migration: org {org_id}: {job_id} missing index.json, skipping") + skipped_jobs += 1 + continue + + try: + index_data = await _read_json(index_path) + except Exception as exc: # noqa: BLE001 + logger.warning(f"Jobs migration: org {org_id}: {job_id} index.json read failed: {exc}") + skipped_jobs += 1 + continue + + experiment_id = _get_experiment_id_from_index(index_data) + if not experiment_id: + logger.warning(f"Jobs migration: org {org_id}: {job_id} missing experiment_id, skipping") + skipped_jobs += 1 + continue + + dest_dir = storage.join(experiments_dir, experiment_id, "jobs", job_id) + dest_exists = await storage.exists(dest_dir) + src_exists = await storage.exists(job_dir) + + if dest_exists and not src_exists: + skipped_jobs += 1 + continue + if dest_exists and src_exists: + logger.warning(f"Jobs migration: org {org_id}: {job_id} src+dest exist; manual resolution required") + skipped_jobs += 1 + continue + + logger.info(f"Jobs migration: org {org_id}: move {job_id} -> {dest_dir}") + await storage.copy_dir(job_dir, dest_dir) + await storage.rm_tree(job_dir) + moved_jobs += 1 + + # Remove jobs.json per experiment + exp_ids = await _iter_dir_names(experiments_dir) + for exp_id in exp_ids: + jobs_json_path = storage.join(experiments_dir, exp_id, "jobs.json") + if await storage.exists(jobs_json_path): + logger.info(f"Jobs migration: org {org_id}: delete {jobs_json_path}") + await storage.rm_tree(jobs_json_path) + removed_jobs_json += 1 + + # Backward-compat: remove legacy top-level workspace/jobs.json if present. + legacy_jobs_json = storage.join(workspace_dir, "jobs.json") + if await storage.exists(legacy_jobs_json): + await storage.rm_tree(legacy_jobs_json) + + return { + "org_id": org_id, + "moved_jobs": moved_jobs, + "skipped_jobs": skipped_jobs, + "removed_jobs_json": removed_jobs_json, + "status": "migrated", + } + + except Exception as exc: # noqa: BLE001 + logger.exception(f"Jobs migration: org {org_id} failed: {exc}") + return { + "org_id": org_id, + "moved_jobs": moved_jobs, + "skipped_jobs": skipped_jobs, + "removed_jobs_json": removed_jobs_json, + "status": "error", + "error": str(exc), + } + finally: + _clear_org_context() + + +async def _jobs_migration_worker() -> None: + try: + org_ids = await team_service.get_all_team_ids() + except Exception as exc: # noqa: BLE001 + logger.warning(f"Jobs migration worker: failed listing org ids: {exc}") + return + + if not org_ids: + logger.info("Jobs migration worker: no orgs found; nothing to migrate") + return + + for org_id in org_ids: + if await _org_needs_jobs_migration(org_id): + result = await _migrate_org_jobs(org_id) + logger.info(f"Jobs migration worker: result for org {org_id}: {result}") + else: + logger.info(f"Jobs migration worker: org {org_id} already migrated (no jobs.json found)") + + +async def start_jobs_migration_worker() -> None: + """ + Start the one-time migration worker in the background. + """ + global _jobs_migration_worker_task + if _jobs_migration_worker_task is not None and not _jobs_migration_worker_task.done(): + return + + _jobs_migration_worker_task = asyncio.create_task(_jobs_migration_worker()) + + +async def stop_jobs_migration_worker() -> None: + global _jobs_migration_worker_task + if _jobs_migration_worker_task is None: + return + + if _jobs_migration_worker_task.done(): + return + + _jobs_migration_worker_task.cancel() + try: + await _jobs_migration_worker_task + except asyncio.CancelledError: + pass + finally: + _jobs_migration_worker_task = None diff --git a/api/transformerlab/services/notification_service.py b/api/transformerlab/services/notification_service.py index 5d2e4e8fe..93a3c2d85 100644 --- a/api/transformerlab/services/notification_service.py +++ b/api/transformerlab/services/notification_service.py @@ -293,7 +293,7 @@ async def process_pending_notifications_once() -> Dict[str, int]: # IMPORTANT: jobs_get_all reads from a cached jobs.json that # does NOT reflect freshly-written job_data fields. # Always re-read uncached so we see the current notification_sent. - job = await job_service.job_get(job_id) + job = await job_service.job_get(job_id, experiment_id=experiment_id) if not job: continue diff --git a/api/transformerlab/services/quota_service.py b/api/transformerlab/services/quota_service.py index 7c80e17af..87b8024da 100644 --- a/api/transformerlab/services/quota_service.py +++ b/api/transformerlab/services/quota_service.py @@ -278,7 +278,7 @@ async def record_quota_usage( async def ensure_quota_recorded_for_completed_job( - session: AsyncSession, job_id: str, team_id: Optional[str] = None + session: AsyncSession, job_id: str, experiment_id: Optional[str] = None, team_id: Optional[str] = None ) -> bool: """ Check if a completed REMOTE job has quota usage recorded. @@ -291,7 +291,11 @@ async def ensure_quota_recorded_for_completed_job( from sqlalchemy import select # Get the job - job = await job_service.job_get(job_id) + if not experiment_id: + # Jobs are now experiment-scoped; caller must provide experiment context. + return False + + job = await job_service.job_get(job_id, experiment_id=experiment_id) if not job: return False diff --git a/api/transformerlab/services/sweep_status_service.py b/api/transformerlab/services/sweep_status_service.py index b8734c0d2..00b6642c7 100644 --- a/api/transformerlab/services/sweep_status_service.py +++ b/api/transformerlab/services/sweep_status_service.py @@ -122,9 +122,9 @@ async def apply_parent_sweep_updates( return job # caller only checks truthiness; avoids a redundant S3 re-fetch -async def _fetch_child_job(child_job_id: str) -> Optional[Dict[str, Any]]: +async def _fetch_child_job(child_job_id: str, experiment_id: str) -> Optional[Dict[str, Any]]: async with _child_fetch_semaphore: - return await job_service.job_get(child_job_id) + return await job_service.job_get(child_job_id, experiment_id=experiment_id) async def refresh_sweep_parent(job: Dict[str, Any], experiment_id: str) -> Optional[Dict[str, Any]]: @@ -140,7 +140,7 @@ async def refresh_sweep_parent(job: Dict[str, Any], experiment_id: str) -> Optio # Fetch all child jobs concurrently instead of serially. return_exceptions=True # means one failed fetch doesn't abort the rest; non-dict results are filtered out. results = await asyncio.gather( - *[_fetch_child_job(str(cid)) for cid in sweep_job_ids], + *[_fetch_child_job(str(cid), experiment_id) for cid in sweep_job_ids], return_exceptions=True, ) child_jobs: List[Dict[str, Any]] = [r for r in results if isinstance(r, dict)] diff --git a/api/transformerlab/services/trackio_service.py b/api/transformerlab/services/trackio_service.py index 137122ff9..1ff43dd6d 100644 --- a/api/transformerlab/services/trackio_service.py +++ b/api/transformerlab/services/trackio_service.py @@ -10,6 +10,7 @@ from lab import HOME_DIR, storage from lab.dirs import get_workspace_dir from lab.job import Job +from transformerlab.services.job_service import job_get_cached from werkzeug.utils import secure_filename @@ -32,8 +33,15 @@ async def start_trackio_for_job(job_id: str, org_id: str | None, experiment_id: safe_org_id = secure_filename(org_id) if org_id else "" safe_experiment_id = secure_filename(experiment_id) if experiment_id else "" + if not experiment_id: + raise HTTPException(status_code=400, detail="Missing experiment_id in request context") + try: - job = await Job.get(job_id) + job_dict = await job_get_cached(job_id, experiment_id=experiment_id) + if not job_dict: + job = None + else: + job = await Job.get(job_id, experiment_id) except Exception: job = None diff --git a/api/transformerlab/shared/shared.py b/api/transformerlab/shared/shared.py index 6558dbeae..add1f3610 100644 --- a/api/transformerlab/shared/shared.py +++ b/api/transformerlab/shared/shared.py @@ -2,20 +2,13 @@ import os import re import subprocess -import sys import threading import time import unicodedata import math -from anyio import open_process -from anyio.streams.text import TextReceiveStream - -from transformerlab.services.job_service import job_update_status -import transformerlab.services.job_service as job_service from lab import Job from lab import storage -from lab.job_status import JobStatus def popen_and_call(onExit, input="", output_file=None, *popenArgs, **popenKWArgs): @@ -96,112 +89,14 @@ def slugify(value, allow_unicode=False): return re.sub(r"[-\s]+", "-", value).strip("-_") -async def async_run_python_script_and_update_status( - python_script: list[str], job_id: str, begin_string: str, env: dict | None = None -): - """ - Use this script for one time, long running scripts that have a definite end. For example - downloading a model. - - This function runs a python script and updates the status of the job in the database - to RUNNING when the python script prints begin_string to stderr - - The FastAPI worker uses stderr, not stdout - - Args: - python_script: List of command-line arguments for the Python script - job_id: Job ID for status updates - begin_string: String to look for in output to mark job as RUNNING - env: Optional dictionary of environment variables to pass to subprocess. - These are merged with the current environment and are process-local (won't leak to API). - """ - - print(f"Job {job_id} Running async python script: " + str(python_script)) - # Extract plugin location from the python_script list - plugin_location = None - if "--plugin_dir" in python_script: - for i, arg in enumerate(python_script): - if arg == "--plugin_dir" and i + 1 < len(python_script): - plugin_location = python_script[i + 1] - break - - # Check if plugin has a venv directory - if plugin_location: - plugin_location = os.path.normpath(plugin_location) - from lab.dirs import get_plugin_dir - - plugin_dir_root = await get_plugin_dir() - if not plugin_location.startswith(plugin_dir_root): - print(f"Plugin location {plugin_location} is not in {plugin_dir_root}") - raise Exception(f"Plugin location {plugin_location} is not in {plugin_dir_root}") - if os.path.exists(os.path.join(plugin_location, "venv")) and os.path.isdir( - os.path.join(plugin_location, "venv") - ): - venv_path = os.path.join(plugin_location, "venv") - print(f">Plugin has virtual environment, activating venv from {venv_path}") - venv_python = os.path.join(venv_path, "bin", "python") - command = [venv_python, *python_script] - else: - print(">Using system Python interpreter") - command = [sys.executable, *python_script] - - else: - print(">Using system Python interpreter") - command = [sys.executable, *python_script] # Skip the original Python interpreter - - # Prepare environment variables for subprocess - # Start with current environment and merge any provided env vars - process_env = os.environ.copy() - if env: - process_env.update(env) - - process = await open_process(command=command, stderr=subprocess.STDOUT, stdout=subprocess.PIPE, env=process_env) - - # read stderr and print: - if process.stdout: - async for text in TextReceiveStream(process.stdout): - print(">> " + text) - if begin_string in text: - print(f"Job {job_id} now in progress!") - job = await job_service.job_get(job_id) - experiment_id = job.get("experiment_id") if job else None - await job_update_status(job_id=job_id, status=JobStatus.RUNNING, experiment_id=experiment_id) - - # Check the job_data column for the stop flag: - job_row = await job_service.job_get(job_id) - job_data = job_row.get("job_data", None) - if job_data and job_data.get("stop", False): - print(f"Job {job_id}: 'stop' flag detected. Cancelling job.") - raise asyncio.CancelledError() - +async def get_job_output_file_name(job_id: str, plugin_name: str = None, experiment_name: str = None): try: - await process.wait() - - if process.returncode == 0: - print(f"Job {job_id} completed successfully") - job = await job_service.job_get(job_id) - experiment_id = job.get("experiment_id") if job else None - await job_update_status(job_id=job_id, status=JobStatus.COMPLETE, experiment_id=experiment_id) - else: - print(f"ERROR: Job {job_id} failed with exit code {process.returncode}.") - job = await job_service.job_get(job_id) - experiment_id = job.get("experiment_id") if job else None - await job_update_status(job_id=job_id, status=JobStatus.FAILED, experiment_id=experiment_id) - - return process + experiment_id = experiment_name or os.environ.get("_TFL_EXPERIMENT_ID") - except asyncio.CancelledError: - process.kill() - await process.wait() + if not experiment_id: + raise FileNotFoundError(f"Job '{job_id}' not found in any experiment directory") - print(f"Job {job_id} cancelled.") - - raise asyncio.CancelledError() - - -async def get_job_output_file_name(job_id: str, plugin_name: str = None, experiment_name: str = None): - try: - job_obj = await Job.get(job_id) + job_obj = await Job.get(job_id, experiment_id) output_file = await job_obj.get_log_path() return output_file except Exception as e: diff --git a/lab-sdk/pyproject.toml b/lab-sdk/pyproject.toml index 890d68080..66219a71c 100644 --- a/lab-sdk/pyproject.toml +++ b/lab-sdk/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "transformerlab" -version = "0.1.5" +version = "0.1.6" description = "Python SDK for Transformer Lab" readme = "README.md" requires-python = ">=3.10" diff --git a/lab-sdk/src/lab/dataset.py b/lab-sdk/src/lab/dataset.py index 7ff09caca..1063e3089 100644 --- a/lab-sdk/src/lab/dataset.py +++ b/lab-sdk/src/lab/dataset.py @@ -1,7 +1,7 @@ from werkzeug.utils import secure_filename from typing import Optional -from .dirs import get_datasets_dir, get_job_datasets_dir +from .dirs import get_datasets_dir, get_experiments_dir, get_job_datasets_dir from .labresource import BaseLabResource from . import storage @@ -44,13 +44,37 @@ async def get_dir(self): """Abstract method on BaseLabResource""" dataset_id_safe = secure_filename(str(self.id)) if self.job_id: - # Use job-specific directory - return await get_job_datasets_dir(self.job_id) + # Jobs are now stored under experiments/{exp_id}/jobs/{job_id}/. + # Since callers provide only job_id, infer exp_id by scanning experiments. + exp_id = await self._find_job_experiment_id() + if not exp_id: + raise FileNotFoundError(f"Job with id '{self.job_id}' not found") + return await get_job_datasets_dir(self.job_id, exp_id) else: # Use global datasets directory datasets_dir = await get_datasets_dir() return storage.join(datasets_dir, dataset_id_safe) + async def _find_job_experiment_id(self) -> Optional[str]: + job_id_safe = secure_filename(str(self.job_id)) + experiments_dir = await get_experiments_dir() + + try: + exp_entries = await storage.ls(experiments_dir, detail=False) + except Exception: + return None + + for exp_path in exp_entries: + if not await storage.isdir(exp_path): + continue + + exp_id = exp_path.rstrip("/").split("/")[-1] + job_dir = storage.join(exp_path, "jobs", job_id_safe) + if await storage.isdir(job_dir): + return exp_id + + return None + def _default_json(self): # Default metadata modeled after API dataset table fields return { diff --git a/lab-sdk/src/lab/dirs.py b/lab-sdk/src/lab/dirs.py index 0744b6dc5..14ae1b4f8 100644 --- a/lab-sdk/src/lab/dirs.py +++ b/lab-sdk/src/lab/dirs.py @@ -134,9 +134,16 @@ async def get_experiments_dir() -> str: return path -async def get_jobs_dir() -> str: - workspace_dir = await get_workspace_dir() - path = storage.join(workspace_dir, "jobs") +async def get_jobs_dir(experiment_id: str) -> str: + """ + Return the filesystem directory for all jobs in an experiment. + + Layout: + {workspace}/experiments/{experiment_id}/jobs/ + """ + experiments_dir = await get_experiments_dir() + experiment_id_safe = secure_filename(str(experiment_id)) + path = storage.join(experiments_dir, experiment_id_safe, "jobs") await storage.makedirs(path, exist_ok=True) return path @@ -245,78 +252,89 @@ def get_galleries_cache_dir() -> str: return path -async def get_job_dir(job_id: str | int) -> str: +async def get_job_dir(job_id: str | int, experiment_id: str) -> str: """ - Return the filesystem directory for a specific job id under the jobs root. - Mirrors `Job.get_dir()` but provided here for convenience where a `Job` - instance is not readily available. + Return the filesystem directory for a specific job inside an experiment. + + Layout: + {workspace}/experiments/{experiment_id}/jobs/{job_id}/ """ job_id_safe = secure_filename(str(job_id)) - jobs_dir = await get_jobs_dir() + jobs_dir = await get_jobs_dir(experiment_id) return storage.join(jobs_dir, job_id_safe) -async def get_job_artifacts_dir(job_id: str | int) -> str: +async def get_job_artifacts_dir(job_id: str | int, experiment_id: str) -> str: """ Return the artifacts directory for a specific job, creating it if needed. - Example: ~/.transformerlab/workspace/jobs//artifacts + Example: ~/.transformerlab/workspace/experiments//jobs//artifacts """ - job_dir = await get_job_dir(job_id) + job_dir = await get_job_dir(job_id, experiment_id) path = storage.join(job_dir, "artifacts") await storage.makedirs(path, exist_ok=True) return path -async def get_job_profiling_dir(job_id: str | int) -> str: +async def get_job_profiling_dir(job_id: str | int, experiment_id: str | None = None) -> str: """ Return the profiling directory for a specific job, creating it if needed. - Example: ~/.transformerlab/workspace/jobs//profiling + + Layout: + {workspace}/experiments/{experiment_id}/jobs/{job_id}/profiling + + If `experiment_id` is not provided, this function will attempt to resolve it from + the `_TFL_EXPERIMENT_ID` environment variable (used by the remote-trap wrapper). """ - job_dir = await get_job_dir(job_id) + if experiment_id is None: + experiment_id = os.environ.get("_TFL_EXPERIMENT_ID") + if not experiment_id: + raise ValueError(f"experiment_id is required for profiling dir (job_id={job_id})") + + job_dir = await get_job_dir(job_id, experiment_id) path = storage.join(job_dir, "profiling") await storage.makedirs(path, exist_ok=True) return path -async def get_job_checkpoints_dir(job_id: str | int) -> str: +async def get_job_checkpoints_dir(job_id: str | int, experiment_id: str) -> str: """ Return the checkpoints directory for a specific job, creating it if needed. - Example: ~/.transformerlab/workspace/jobs//checkpoints + Example: ~/.transformerlab/workspace/experiments//jobs//checkpoints """ - job_dir = await get_job_dir(job_id) + job_dir = await get_job_dir(job_id, experiment_id) path = storage.join(job_dir, "checkpoints") await storage.makedirs(path, exist_ok=True) return path -async def get_job_eval_results_dir(job_id: str | int) -> str: +async def get_job_eval_results_dir(job_id: str | int, experiment_id: str) -> str: """ Return the eval_results directory for a specific job, creating it if needed. - Example: ~/.transformerlab/workspace/jobs//eval_results + Example: ~/.transformerlab/workspace/experiments//jobs//eval_results """ - job_dir = await get_job_dir(job_id) + job_dir = await get_job_dir(job_id, experiment_id) path = storage.join(job_dir, "eval_results") await storage.makedirs(path, exist_ok=True) return path -async def get_job_models_dir(job_id: str | int) -> str: +async def get_job_models_dir(job_id: str | int, experiment_id: str) -> str: """ Return the models directory for a specific job, creating it if needed. - Example: ~/.transformerlab/workspace/jobs//models + Example: ~/.transformerlab/workspace/experiments//jobs//models """ - job_dir = await get_job_dir(job_id) + job_dir = await get_job_dir(job_id, experiment_id) path = storage.join(job_dir, "models") await storage.makedirs(path, exist_ok=True) return path -async def get_job_datasets_dir(job_id: str | int) -> str: +async def get_job_datasets_dir(job_id: str | int, experiment_id: str) -> str: """ Return the datasets directory for a specific job, creating it if needed. - Example: ~/.transformerlab/workspace/jobs//datasets + Example: ~/.transformerlab/workspace/experiments//jobs//datasets """ - job_dir = await get_job_dir(job_id) + job_dir = await get_job_dir(job_id, experiment_id) path = storage.join(job_dir, "datasets") await storage.makedirs(path, exist_ok=True) return path diff --git a/lab-sdk/src/lab/experiment.py b/lab-sdk/src/lab/experiment.py index fc2976391..431cbcb70 100644 --- a/lab-sdk/src/lab/experiment.py +++ b/lab-sdk/src/lab/experiment.py @@ -1,6 +1,8 @@ import asyncio import contextlib import time +import uuid +from datetime import datetime, timezone from werkzeug.utils import secure_filename from .dirs import get_experiments_dir, get_jobs_dir @@ -15,6 +17,56 @@ logger = logging.getLogger(__name__) +def _timestamp_sort_value(ts): + """ + Comparable value for sorting (higher = newer when using reverse=True). + + Parses common string shapes (ISO / space-separated naive datetimes) so order + matches real time; falls back to string compare only if parsing fails. + Missing/invalid sorts last when using reverse=True (via float('-inf')). + """ + if ts is None or ts == "": + return float("-inf") + if isinstance(ts, datetime): + if ts.tzinfo is None: + ts = ts.replace(tzinfo=timezone.utc) + return ts.timestamp() + if isinstance(ts, (int, float)): + return float(ts) + if isinstance(ts, str): + raw = ts.strip() + if not raw: + return float("-inf") + # fromisoformat accepts "2026-03-24T17:04:14+00:00" and "2026-03-24 17:04:14" + try: + dt = datetime.fromisoformat(raw.replace("Z", "+00:00")) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt.timestamp() + except ValueError: + return float("-inf") + return float("-inf") + + +def _sort_key_job_recency(job: dict): + """ + Newest-first sort key. Prefer job_data.start_time (matches UI "Started") over + top-level created_at so order aligns with when the run actually started. + """ + jd = job.get("job_data") + if isinstance(jd, str): + try: + jd = json.loads(jd) + except Exception: + jd = {} + if not isinstance(jd, dict): + jd = {} + start_time = jd.get("start_time") + if start_time: + return _timestamp_sort_value(start_time) + return _timestamp_sort_value(job.get("created_at")) + + class Experiment(BaseLabResource): """ Base object for managing all config associated with an experiment @@ -63,11 +115,7 @@ def _default_json(self): async def _initialize(self): await super()._initialize() - # Create a empty jobs index and write - jobs_json_path = await self._jobs_json_file() - empty_jobs_data = {"index": self.DEFAULT_JOBS_INDEX, "cached_jobs": {}} - async with await storage.open(jobs_json_path, "w") as f: - await f.write(json.dumps(empty_jobs_data, indent=4)) + # jobs.json was eliminated in favor of direct filesystem indexing. async def update_config_field(self, key, value): """Update a single key in config.""" @@ -172,94 +220,52 @@ async def create_job(self, type: str = "REMOTE"): """ Creates a new job with a blank template and returns a Job object. """ + # New job IDs are UUIDs so they are globally unique across experiments. + new_job_id = str(uuid.uuid4()) - # Choose an ID for the new job - # Scan the jobs directory for subdirectories with numberic names - # Find the largest number and increment to get the new job ID - largest_numeric_subdir = 0 - jobs_dir = await get_jobs_dir() - try: - entries = await storage.ls(jobs_dir, detail=False) - except Exception: - entries = [] - for full_path in entries: - entry = full_path.rstrip("/").split("/")[-1] - if entry.isdigit() and await storage.isdir(full_path): - job_id = int(entry) - if job_id > largest_numeric_subdir: - largest_numeric_subdir = job_id - - new_job_id = largest_numeric_subdir + 1 - - # Create job and set its experiment fields directly (no rebuild needed) - new_job = await Job.create(new_job_id) - await new_job._update_json_data_field("experiment_id", self.id) + new_job = await Job.create(new_job_id, self.id) + await new_job._update_json_data_field("type", type) await new_job.update_job_data_field("experiment_name", self.id) - - # Add to index incrementally — no filesystem scan - await self._add_job(str(new_job.id), type) - return new_job - async def get_jobs(self, type: str = "", status: str = ""): - """ - Get a list of jobs stored in this experiment. - Uses cached data from jobs.json for completed jobs, only reads individual files for RUNNING jobs. - type: If not blank, filter by jobs with this type. - status: If not blank, filter by jobs with this status. + async def get_jobs(self, type: str = "", status: str = "") -> list[dict]: """ + Get all jobs for this experiment by scanning the filesystem directory: + experiments/{experiment_id}/jobs/{job_id}/index.json - # First get jobs of the passed type - job_list = [] - if type: - job_list = await self._get_jobs_of_type(type) - else: - job_list = await self._get_all_jobs() - - # Get cached job data from jobs.json - cached_jobs = await self._get_cached_jobs_data() + If `type` is provided, filter by `job_data["type"]`. + If `status` is provided, filter by `job_data["status"]`. + By default, DELETED jobs are excluded. + """ + jobs_dir = await get_jobs_dir(self.id) + try: + entries = await storage.ls(jobs_dir, detail=False) + except Exception: + return [] - # Iterate through the job list to return Job objects for valid jobs. - # Also filter for status if that parameter was passed. - results = [] - for job_id in job_list: + results: list[dict] = [] + for job_path in entries: + if not await storage.isdir(job_path): + continue + job_id = job_path.rstrip("/").split("/")[-1] try: - # Check if job is in cache (non-RUNNING jobs are cached) - if job_id in cached_jobs: - # Use cached data for completed jobs - job_json = cached_jobs[job_id] - # Check status of job if not RUNNING, LAUNCHING, INTERACTIVE or NOT_STARTED, then remove from cache - if job_json.get("status", "") in [ - JobStatus.RUNNING, - JobStatus.LAUNCHING, - JobStatus.INTERACTIVE, - JobStatus.NOT_STARTED, - ]: - del cached_jobs[job_id] - job = await Job.get(job_id) - job_json = await job.get_json_data(uncached=True) - cached_jobs[job_id] = job_json - - else: - # Job not in cache - job = await Job.get(job_id) - job_json = await job.get_json_data(uncached=True) + job = await Job.get(job_id, self.id) + job_data = await job.get_json_data(uncached=True) except Exception: - logger.warning("ERROR getting job %s", job_id, exc_info=True) + logger.warning("ERROR getting job %s in experiment %s", job_id, self.id, exc_info=True) continue - # Filter for status - if status and (job_json.get("status", "") != status): + if type and job_data.get("type") != type: continue - - # Exclude DELETED jobs by default (unless explicitly requested) - if not status and job_json.get("status", "") == JobStatus.DELETED: + if status and job_data.get("status") != status: + continue + if not status and job_data.get("status") == JobStatus.DELETED: continue - # If it passed filters then add as long as it has job_data - if "job_data" in job_json: - results.append(job_json) + if "job_data" in job_data: + results.append(job_data) + results.sort(key=_sort_key_job_recency, reverse=True) return results ############################### @@ -748,10 +754,19 @@ async def delete(self): async def delete_all_jobs(self): """Delete all jobs associated with this experiment.""" - all_jobs = await self._get_all_jobs() - for job_id in all_jobs: + jobs_dir = await get_jobs_dir(self.id) + try: + entries = await storage.ls(jobs_dir, detail=False) + except Exception: + return + + for job_path in entries: + if not await storage.isdir(job_path): + continue + + job_id = job_path.rstrip("/").split("/")[-1] try: - job = await Job.get(job_id) + job = await Job.get(job_id, self.id) await job.delete() except Exception: pass # Job might not exist diff --git a/lab-sdk/src/lab/job.py b/lab-sdk/src/lab/job.py index 3a50c83f3..0fb746798 100644 --- a/lab-sdk/src/lab/job.py +++ b/lab-sdk/src/lab/job.py @@ -1,5 +1,6 @@ +import json +from datetime import datetime, timezone import posixpath -from werkzeug.utils import secure_filename from . import dirs from .labresource import BaseLabResource @@ -10,21 +11,72 @@ logger = logging.getLogger(__name__) +async def _iter_all_experiment_jobs() -> list[tuple[str, str]]: + """ + Yield (job_id, experiment_id) pairs for every job directory across all experiments. + + This uses the filesystem as the index: + {workspace}/experiments/{experiment_id}/jobs/{job_id}/index.json + """ + pairs: list[tuple[str, str]] = [] + + experiments_dir = await dirs.get_experiments_dir() + try: + exp_entries = await storage.ls(experiments_dir, detail=False) + except Exception: + return pairs + + for exp_path in exp_entries: + if not await storage.isdir(exp_path): + continue + + exp_id = exp_path.rstrip("/").split("/")[-1] + jobs_path = storage.join(exp_path, "jobs") + + try: + job_entries = await storage.ls(jobs_path, detail=False) + except Exception: + continue + + for job_path in job_entries: + if await storage.isdir(job_path): + job_id = job_path.rstrip("/").split("/")[-1] + pairs.append((job_id, exp_id)) + + return pairs + + class Job(BaseLabResource): """ Used to update status and info of long-running jobs. """ - def __init__(self, job_id): + def __init__(self, job_id, experiment_id: str): self.id = job_id + self.experiment_id = experiment_id self.should_stop = False + @classmethod + async def create(cls, job_id: str, experiment_id: str): + newobj = cls(job_id, experiment_id) + await newobj._initialize() + return newobj + + @classmethod + async def get(cls, job_id: str, experiment_id: str): + newobj = cls(job_id, experiment_id) + resource_dir = await newobj.get_dir() + if not await storage.isdir(resource_dir): + raise FileNotFoundError(f"Directory for Job with id '{job_id}' not found in experiment '{experiment_id}'") + json_file = await newobj._get_json_file() + if not await storage.exists(json_file): + async with await storage.open(json_file, "w", encoding="utf-8") as f: + await f.write(json.dumps(newobj._default_json())) + return newobj + async def get_dir(self): """Abstract method on BaseLabResource""" - job_id_safe = secure_filename(str(self.id)) - jobs_dir = await dirs.get_jobs_dir() - job_dir = storage.join(jobs_dir, job_id_safe) - return job_dir + return await dirs.get_job_dir(self.id, self.experiment_id) async def get_log_path(self): """ @@ -61,7 +113,8 @@ def _default_json(self): } return { "id": self.id, - "experiment_id": "", + "experiment_id": self.experiment_id, + "created_at": datetime.now(timezone.utc).isoformat(), "job_data": default_job_data, "status": JobStatus.NOT_STARTED, "type": "REMOTE", @@ -77,33 +130,8 @@ async def update_progress(self, progress: int): await self._update_json_data_field("progress", progress) async def update_status(self, status: str): - """ - Update the status of this job. - - status: str representing the status of the job - """ await self._update_json_data_field("status", status) - # Update jobs.json incrementally — no full filesystem scan - try: - from .experiment import Experiment - - experiment_id = await self.get_experiment_id() - if experiment_id: - exp = Experiment(experiment_id) - if status in (JobStatus.COMPLETE, JobStatus.STOPPED, JobStatus.FAILED): - # Write current snapshot to cached_jobs so get_jobs() skips the live read - job_data = await self.get_json_data(uncached=True) - await exp._update_cached_job(str(self.id), job_data) - elif status == JobStatus.DELETED: - # Remove from index and cached_jobs - await exp._remove_job_from_index(str(self.id)) - # For non-terminal statuses (RUNNING, LAUNCHING, etc.) nothing to do: - # get_jobs() always does a live read for those. - except Exception: - # Never let index updates break status writes - pass - async def get_status(self): """ Get the status of this job. @@ -255,21 +283,14 @@ async def count_running_jobs(cls): Count how many jobs are currently running. """ count = 0 - jobs_dir = await dirs.get_jobs_dir() - try: - entries = await storage.ls(jobs_dir, detail=False) - except Exception: - entries = [] - for job_path in entries: - if await storage.isdir(job_path): - entry = job_path.rstrip("/").split("/")[-1] - try: - job = await cls.get(entry) - job_data = await job.get_json_data(uncached=True) - if job_data.get("status") == JobStatus.RUNNING: - count += 1 - except Exception: - pass + for job_id, exp_id in await _iter_all_experiment_jobs(): + try: + job = await cls.get(job_id, exp_id) + job_data = await job.get_json_data(uncached=True) + if job_data.get("status") == JobStatus.RUNNING: + count += 1 + except Exception: + pass return count @classmethod @@ -278,46 +299,42 @@ async def get_next_queued_job(cls): Get the next queued job (oldest first based on directory creation time). Returns Job data dict or None if no queued jobs. """ - queued_jobs = [] - jobs_dir = await dirs.get_jobs_dir() - try: - entries = await storage.ls(jobs_dir, detail=False) - except Exception: - entries = [] - for job_path in entries: - if await storage.isdir(job_path): - entry = job_path.rstrip("/").split("/")[-1] - try: - job = await cls.get(entry) - job_data = await job.get_json_data(uncached=True) - if job_data.get("status") == JobStatus.QUEUED: - # Without ctime in object stores, sort lexicographically by job id - queued_jobs.append((int(entry) if entry.isdigit() else 0, job_data)) - except Exception: - pass + queued_jobs: list[tuple[str, dict]] = [] + + for job_id, exp_id in await _iter_all_experiment_jobs(): + try: + job = await cls.get(job_id, exp_id) + job_data = await job.get_json_data(uncached=True) + if job_data.get("status") == JobStatus.QUEUED: + created_at = job_data.get("created_at", "") + queued_jobs.append((created_at, job_data)) + except Exception: + pass + + if not queued_jobs: + return None - if queued_jobs: - queued_jobs.sort(key=lambda x: x[0]) - return queued_jobs[0][1] - return None + # Sort oldest-first by ISO timestamp. Missing created_at sorts last. + queued_jobs.sort(key=lambda x: (x[0] == "", x[0])) + return queued_jobs[0][1] async def get_checkpoints_dir(self): """ Get the checkpoints directory path for this job. """ - return await dirs.get_job_checkpoints_dir(self.id) + return await dirs.get_job_checkpoints_dir(self.id, self.experiment_id) async def get_artifacts_dir(self): """ Get the artifacts directory path for this job. """ - return await dirs.get_job_artifacts_dir(self.id) + return await dirs.get_job_artifacts_dir(self.id, self.experiment_id) async def get_profiling_dir(self): """ Get the profiling directory path for this job. """ - return await dirs.get_job_profiling_dir(self.id) + return await dirs.get_job_profiling_dir(self.id, self.experiment_id) async def get_checkpoint_paths(self): """ diff --git a/lab-sdk/src/lab/lab_facade.py b/lab-sdk/src/lab/lab_facade.py index e3d6c3d32..011eadaf6 100644 --- a/lab-sdk/src/lab/lab_facade.py +++ b/lab-sdk/src/lab/lab_facade.py @@ -101,7 +101,7 @@ def init(self, experiment_id: str | None = None, config: Optional[Dict[str, Any] # Use existing job from environment variable # This will raise an error if the job doesn't exist self._experiment = _run_async(Experiment.create_or_get(experiment_id, create_new=False)) - self._job = _run_async(Job.get(existing_job_id)) + self._job = _run_async(Job.get(existing_job_id, self._experiment.id)) if self._job is None: raise RuntimeError(f"Job with ID {existing_job_id} not found. Check _TFL_JOB_ID environment variable.") logger.info(f"Using existing job ID: {existing_job_id}") @@ -337,7 +337,7 @@ def copy_file_mounts(self) -> None: async def _copy_file_mounts_async(self, job_id: str) -> None: """Async implementation of copy_file_mounts.""" - job = await Job.get(job_id) + job = await Job.get(job_id, self._experiment.id) if job is None: return job_data = await job.get_job_data() @@ -515,7 +515,7 @@ async def async_get_parent_job_checkpoint_path(self, parent_job_id: str, checkpo Optional[str]: The full path to the checkpoint, or None if it doesn't exist """ try: - checkpoints_dir = await dirs.get_job_checkpoints_dir(parent_job_id) + checkpoints_dir = await dirs.get_job_checkpoints_dir(parent_job_id, self._experiment.id) checkpoint_path = storage.join(checkpoints_dir, checkpoint_name) # Security check: ensure the checkpoint path is within the checkpoints directory @@ -564,7 +564,13 @@ def finish( if profiling_temp and self._job: from lab.profiling import copy_profiling_to_job - _run_async(copy_profiling_to_job(profiling_temp, str(self._job.id))) # type: ignore[union-attr] + _run_async( + copy_profiling_to_job( + profiling_temp, + str(self._job.id), + experiment_id=str(self._experiment.id), + ), # type: ignore[union-attr] + ) except Exception: pass _run_async(self._job.update_progress(100)) # type: ignore[union-attr] @@ -765,7 +771,7 @@ async def async_save_artifact( base_name = f"{job_id}_{base_name_original}" # Save to job-specific datasets directory (not a subdirectory per dataset) - datasets_dir = await dirs.get_job_datasets_dir(job_id) + datasets_dir = await dirs.get_job_datasets_dir(job_id, self._experiment.id) dest = storage.join(datasets_dir, base_name) # Create parent directories @@ -856,7 +862,7 @@ async def async_save_artifact( raise ValueError(f"Missing required columns in DataFrame: {missing_columns}") # Determine destination directory and filename - dest_dir = await dirs.get_job_eval_results_dir(job_id) + dest_dir = await dirs.get_job_eval_results_dir(job_id, self._experiment.id) if name is None or (isinstance(name, str) and name.strip() == ""): import time @@ -963,7 +969,7 @@ async def async_save_artifact( base_name = f"{job_id}_{base_name_without_ext}" # Save to job-specific models directory - models_dir = await dirs.get_job_models_dir(job_id) + models_dir = await dirs.get_job_models_dir(job_id, self._experiment.id) dest = storage.join(models_dir, base_name) # Create model directory @@ -1101,9 +1107,9 @@ async def async_save_artifact( # Determine destination directory based on type if type == "evals": - dest_dir = await dirs.get_job_eval_results_dir(job_id) + dest_dir = await dirs.get_job_eval_results_dir(job_id, self._experiment.id) else: - dest_dir = await dirs.get_job_artifacts_dir(job_id) + dest_dir = await dirs.get_job_artifacts_dir(job_id, self._experiment.id) base_name = name if (isinstance(name, str) and name.strip() != "") else posixpath.basename(src) dest = storage.join(dest_dir, base_name) @@ -1215,7 +1221,7 @@ async def async_save_dataset( # Add job_id prefix to dataset_id to avoid conflicts between jobs dataset_id_with_prefix = f"{job_id}_{dataset_id_safe}" - dataset_dir = await dirs.get_job_datasets_dir(job_id) + dataset_dir = await dirs.get_job_datasets_dir(job_id, self._experiment.id) await storage.makedirs(dataset_dir, exist_ok=True) # Determine output location and filename @@ -1367,7 +1373,7 @@ async def async_save_checkpoint(self, source_path: str, name: Optional[str] = No raise FileNotFoundError(f"Checkpoint source does not exist: {src}") job_id = self._job.id # type: ignore[union-attr] - ckpts_dir = await dirs.get_job_checkpoints_dir(job_id) + ckpts_dir = await dirs.get_job_checkpoints_dir(job_id, self._experiment.id) base_name = name if (isinstance(name, str) and name.strip() != "") else posixpath.basename(src) dest = storage.join(ckpts_dir, base_name) @@ -1478,7 +1484,13 @@ def error( if profiling_temp and self._job: from lab.profiling import copy_profiling_to_job - _run_async(copy_profiling_to_job(profiling_temp, str(self._job.id))) # type: ignore[union-attr] + _run_async( + copy_profiling_to_job( + profiling_temp, + str(self._job.id), + experiment_id=str(self._experiment.id), + ), # type: ignore[union-attr] + ) except Exception: pass _run_async( @@ -1509,6 +1521,21 @@ def _detect_and_capture_wandb_url(self) -> None: logger.info(f"📊 Detected wandb run URL: {wandb_url}") return + # Avoid importing wandb unless we have strong signals it is configured. + # Some native dependencies of wandb can hard-crash in constrained environments. + if not any( + os.environ.get(k) + for k in ( + "WANDB_API_KEY", + "WANDB_PROJECT", + "WANDB_RUN_ID", + "WANDB_ENTITY", + "WANDB_USER", + "WANDB_MODE", + ) + ): + return + # Method 2: Check for active wandb run in current process try: import wandb @@ -1564,6 +1591,19 @@ def _check_and_capture_wandb_url(self) -> None: # Method 2: Check active wandb run try: + if not any( + os.environ.get(k) + for k in ( + "WANDB_API_KEY", + "WANDB_PROJECT", + "WANDB_RUN_ID", + "WANDB_ENTITY", + "WANDB_USER", + "WANDB_MODE", + ) + ): + return + import wandb if wandb.run is not None and hasattr(wandb.run, "url"): diff --git a/lab-sdk/src/lab/profiling.py b/lab-sdk/src/lab/profiling.py index 517055e49..2382734d9 100644 --- a/lab-sdk/src/lab/profiling.py +++ b/lab-sdk/src/lab/profiling.py @@ -323,7 +323,11 @@ def finalize_profiling( pass -async def copy_profiling_to_job(profiling_temp_dir: str, job_id: str) -> None: +async def copy_profiling_to_job( + profiling_temp_dir: str, + job_id: str, + experiment_id: str | None = None, +) -> None: """ Copy profiling output from a temp directory into the job's profiling folder. @@ -334,15 +338,21 @@ async def copy_profiling_to_job(profiling_temp_dir: str, job_id: str) -> None: if not profiling_temp_dir or not os.path.isdir(profiling_temp_dir): return try: + if experiment_id is None: + experiment_id = os.environ.get("_TFL_EXPERIMENT_ID") + from lab.dirs import get_job_profiling_dir from lab import storage - dest_dir = await get_job_profiling_dir(job_id) + dest_dir = await get_job_profiling_dir(job_id, experiment_id) await storage.copy_dir(profiling_temp_dir, dest_dir) try: from lab.job import Job - job = await Job.get(job_id) + if experiment_id: + job = await Job.get(job_id, experiment_id) + else: + job = None if job is not None: await job.update_job_data_field("has_profiling", True) except Exception: diff --git a/lab-sdk/src/lab/remote_trap.py b/lab-sdk/src/lab/remote_trap.py index 131e8cf68..535d27053 100644 --- a/lab-sdk/src/lab/remote_trap.py +++ b/lab-sdk/src/lab/remote_trap.py @@ -16,7 +16,11 @@ async def _set_live_status_async(job_id: str, status: str) -> None: """Async helper to set live_status on a job and mirror failures to job status.""" try: - job = await Job.get(job_id) + experiment_id = os.environ.get("_TFL_EXPERIMENT_ID") + if not experiment_id: + return + + job = await Job.get(job_id, experiment_id) if job is None: return await job.update_job_data_field("live_status", status) @@ -32,7 +36,11 @@ async def _set_live_status_async(job_id: str, status: str) -> None: async def _set_status_async(job_id: str, status: str) -> None: """Async helper to set the high-level job status.""" try: - job = await Job.get(job_id) + experiment_id = os.environ.get("_TFL_EXPERIMENT_ID") + if not experiment_id: + return + + job = await Job.get(job_id, experiment_id) if job is None: return @@ -106,7 +114,11 @@ async def _write_provider_logs_async(job_id: str, logs_text: str) -> None: # Import inside helper to avoid circular imports at module load time. from lab.dirs import get_job_dir - job_dir = await get_job_dir(job_id) + experiment_id = os.environ.get("_TFL_EXPERIMENT_ID") + if not experiment_id: + return + + job_dir = await get_job_dir(job_id, experiment_id) log_path = storage.join(job_dir, "provider_logs.txt") # Ensure the directory exists (no-op for remote storage that doesn't require mkdirs). @@ -230,7 +242,8 @@ def main(argv: List[str] | None = None) -> int: # Copy profiling output from temp dir into job's profiling folder (same as lab.finish/error). if profiling_temp_dir and job_id: try: - asyncio.run(copy_profiling_to_job(profiling_temp_dir, job_id)) + experiment_id = os.environ.get("_TFL_EXPERIMENT_ID") + asyncio.run(copy_profiling_to_job(profiling_temp_dir, job_id, experiment_id=experiment_id)) except Exception: pass try: diff --git a/lab-sdk/tests/test_dirs.py b/lab-sdk/tests/test_dirs.py index d3996071a..cb2d67d4c 100644 --- a/lab-sdk/tests/test_dirs.py +++ b/lab-sdk/tests/test_dirs.py @@ -25,7 +25,7 @@ async def test_dirs_structure_created(monkeypatch, tmp_path): # Key directories exist assert os.path.isdir(await dirs.get_experiments_dir()) - assert os.path.isdir(await dirs.get_jobs_dir()) + assert os.path.isdir(await dirs.get_jobs_dir("exp1")) assert os.path.isdir(await dirs.get_models_dir()) assert os.path.isdir(await dirs.get_datasets_dir()) assert os.path.isdir(await dirs.get_temp_dir()) diff --git a/lab-sdk/tests/test_dirs_job_paths.py b/lab-sdk/tests/test_dirs_job_paths.py new file mode 100644 index 000000000..aac3c626b --- /dev/null +++ b/lab-sdk/tests/test_dirs_job_paths.py @@ -0,0 +1,52 @@ +import importlib +import os + +import pytest + + +@pytest.mark.asyncio +async def test_get_jobs_dir_uses_experiment(tmp_path, monkeypatch): + ws = tmp_path / "ws" + ws.mkdir() + monkeypatch.setenv("TFL_WORKSPACE_DIR", str(ws)) + + # Reload module so env var changes take effect + import lab.dirs + + importlib.reload(lab.dirs) + from lab.dirs import get_jobs_dir + + result = await get_jobs_dir("my_exp") + assert result.endswith("experiments/my_exp/jobs") + assert os.path.isdir(result) + + +@pytest.mark.asyncio +async def test_get_job_dir_uses_experiment(tmp_path, monkeypatch): + ws = tmp_path / "ws" + ws.mkdir() + monkeypatch.setenv("TFL_WORKSPACE_DIR", str(ws)) + + import lab.dirs + + importlib.reload(lab.dirs) + from lab.dirs import get_job_dir + + result = await get_job_dir("abc-123", "my_exp") + assert result.endswith("experiments/my_exp/jobs/abc-123") + + +@pytest.mark.asyncio +async def test_get_job_checkpoints_dir(tmp_path, monkeypatch): + ws = tmp_path / "ws" + ws.mkdir() + monkeypatch.setenv("TFL_WORKSPACE_DIR", str(ws)) + + import lab.dirs + + importlib.reload(lab.dirs) + from lab.dirs import get_job_checkpoints_dir + + result = await get_job_checkpoints_dir("abc-123", "my_exp") + assert result.endswith("experiments/my_exp/jobs/abc-123/checkpoints") + assert os.path.isdir(result) diff --git a/lab-sdk/tests/test_experiment.py b/lab-sdk/tests/test_experiment.py index ac52050f4..2d72a5c45 100644 --- a/lab-sdk/tests/test_experiment.py +++ b/lab-sdk/tests/test_experiment.py @@ -28,20 +28,12 @@ async def test_experiment_dir_and_jobs_index(tmp_path, monkeypatch): assert exp_dir.endswith(os.path.join("experiments", "exp1")) assert os.path.isdir(exp_dir) - # jobs.json created with default - jobs_index_file = os.path.join(exp_dir, "jobs.json") - assert os.path.isfile(jobs_index_file) - with open(jobs_index_file) as f: - data = json.load(f) - assert "index" in data - assert "TRAIN" in data["index"] - - # Create two jobs via experiment API and ensure they are indexed immediately. + # Create two jobs via experiment API and ensure they are discoverable. j1 = await exp.create_job() j2 = await exp.create_job() - all_jobs = await exp._get_all_jobs() - job_ids = set(all_jobs) + all_jobs = await exp.get_jobs() + job_ids = {j.get("id") for j in all_jobs} assert str(j1.id) in job_ids assert str(j2.id) in job_ids @@ -171,3 +163,165 @@ async def delayed_read(self, jobs_json_path, max_retries=5): jobs_data = json.load(f) cached_jobs = jobs_data.get("cached_jobs", {}) assert set(cached_jobs.keys()) >= {"job-1", "job-2"} + + +@pytest.mark.asyncio +async def test_create_job_uses_uuid_and_experiment_dir(tmp_path, monkeypatch): + for mod in list(importlib.sys.modules.keys()): + if mod.startswith("lab."): + importlib.sys.modules.pop(mod) + + ws = tmp_path / "ws" + ws.mkdir() + monkeypatch.setenv("TFL_WORKSPACE_DIR", str(ws)) + + from lab.experiment import Experiment + + exp = Experiment("alpha") + job = await exp.create_job("TRAIN") + + # Job ID should be a UUID string + import uuid + + uuid.UUID(str(job.id)) # raises ValueError if not valid UUID + + # Job directory should be under the experiment + job_dir = await job.get_dir() + assert "experiments/alpha/jobs/" in job_dir + + # index.json should have experiment_id and created_at + import os + + with open(os.path.join(job_dir, "index.json")) as f: + data = json.load(f) + assert data["experiment_id"] == "alpha" + assert "created_at" in data + + +@pytest.mark.asyncio +async def test_get_jobs_lists_from_directory(tmp_path, monkeypatch): + for mod in list(importlib.sys.modules.keys()): + if mod.startswith("lab."): + importlib.sys.modules.pop(mod) + + ws = tmp_path / "ws" + ws.mkdir() + monkeypatch.setenv("TFL_WORKSPACE_DIR", str(ws)) + + from lab.experiment import Experiment + + exp = Experiment("alpha") + j1 = await exp.create_job("TRAIN") + await j1.update_status("COMPLETE") + j2 = await exp.create_job("DOWNLOAD_MODEL") + await j2.update_status("RUNNING") + + all_jobs = await exp.get_jobs() + assert len(all_jobs) == 2 + + train_jobs = await exp.get_jobs(type="TRAIN") + assert len(train_jobs) == 1 + assert train_jobs[0]["type"] == "TRAIN" + + running_jobs = await exp.get_jobs(status="RUNNING") + assert len(running_jobs) == 1 + + +@pytest.mark.asyncio +async def test_get_jobs_sorted_newest_created_first_when_no_start_time(tmp_path, monkeypatch): + for mod in list(importlib.sys.modules.keys()): + if mod.startswith("lab."): + importlib.sys.modules.pop(mod) + + ws = tmp_path / "ws" + ws.mkdir() + monkeypatch.setenv("TFL_WORKSPACE_DIR", str(ws)) + + from lab.experiment import Experiment + + exp = Experiment("alpha") + j_old = await exp.create_job("TRAIN") + await j_old._update_json_data_field("created_at", "2020-01-01T00:00:00+00:00") + j_new = await exp.create_job("TRAIN") + await j_new._update_json_data_field("created_at", "2025-01-01T00:00:00+00:00") + + jobs = await exp.get_jobs() + assert len(jobs) == 2 + assert [j.get("id") for j in jobs] == [str(j_new.id), str(j_old.id)] + + +@pytest.mark.asyncio +async def test_get_jobs_sorted_by_start_time_over_created_at(tmp_path, monkeypatch): + for mod in list(importlib.sys.modules.keys()): + if mod.startswith("lab."): + importlib.sys.modules.pop(mod) + + ws = tmp_path / "ws" + ws.mkdir() + monkeypatch.setenv("TFL_WORKSPACE_DIR", str(ws)) + + from lab.experiment import Experiment + + exp = Experiment("alpha") + j_earlier_start = await exp.create_job("TRAIN") + await j_earlier_start._update_json_data_field("created_at", "2025-06-01T00:00:00+00:00") + await j_earlier_start.update_job_data_field("start_time", "2025-01-01T10:00:00+00:00") + + j_later_start = await exp.create_job("TRAIN") + await j_later_start._update_json_data_field("created_at", "2020-01-01T00:00:00+00:00") + await j_later_start.update_job_data_field("start_time", "2025-03-01T10:00:00+00:00") + + jobs = await exp.get_jobs() + assert len(jobs) == 2 + assert [j.get("id") for j in jobs] == [str(j_later_start.id), str(j_earlier_start.id)] + + +@pytest.mark.asyncio +async def test_get_jobs_sorted_by_space_separated_start_time(tmp_path, monkeypatch): + """start_time often stored as 'YYYY-MM-DD HH:MM:SS' from providers (not strictly ISO-T).""" + for mod in list(importlib.sys.modules.keys()): + if mod.startswith("lab."): + importlib.sys.modules.pop(mod) + + ws = tmp_path / "ws" + ws.mkdir() + monkeypatch.setenv("TFL_WORKSPACE_DIR", str(ws)) + + from lab.experiment import Experiment + + exp = Experiment("alpha") + j_first = await exp.create_job("TRAIN") + await j_first.update_job_data_field("start_time", "2026-03-24 17:04:14") + j_mid = await exp.create_job("TRAIN") + await j_mid.update_job_data_field("start_time", "2026-03-24 18:32:25") + j_last = await exp.create_job("TRAIN") + await j_last.update_job_data_field("start_time", "2026-03-24 19:57:26") + + jobs = await exp.get_jobs() + assert len(jobs) == 3 + assert [j.get("id") for j in jobs] == [str(j_last.id), str(j_mid.id), str(j_first.id)] + + +@pytest.mark.asyncio +async def test_delete_all_jobs(tmp_path, monkeypatch): + for mod in list(importlib.sys.modules.keys()): + if mod.startswith("lab."): + importlib.sys.modules.pop(mod) + + ws = tmp_path / "ws" + ws.mkdir() + monkeypatch.setenv("TFL_WORKSPACE_DIR", str(ws)) + + from lab.experiment import Experiment + + exp = Experiment("alpha") + await exp.create_job("TRAIN") + await exp.create_job("TRAIN") + + jobs_before = await exp.get_jobs() + assert len(jobs_before) == 2 + + await exp.delete_all_jobs() + + jobs_after = await exp.get_jobs() + assert len(jobs_after) == 0 diff --git a/lab-sdk/tests/test_job_and_labresource.py b/lab-sdk/tests/test_job_and_labresource.py index 7eed0bb67..1e177e67d 100644 --- a/lab-sdk/tests/test_job_and_labresource.py +++ b/lab-sdk/tests/test_job_and_labresource.py @@ -20,13 +20,14 @@ async def test_baselabresource_create_get(tmp_path, monkeypatch): from lab.job import Job - job = await Job.create("123") + job = await Job.create("123", "test_exp") job_dir = await job.get_dir() assert os.path.isdir(job_dir) + assert "experiments/test_exp/jobs/123" in job_dir index_file = os.path.join(job_dir, "index.json") assert os.path.isfile(index_file) - job2 = await Job.get("123") + job2 = await Job.get("123", "test_exp") assert isinstance(job2, Job) @@ -45,7 +46,7 @@ async def test_job_default_json_and_updates(tmp_path, monkeypatch): from lab.job import Job - job = await Job.create("1") + job = await Job.create("1", "test_exp") # On create, defaults are written to index.json job_dir = await job.get_dir() data_path = os.path.join(job_dir, "index.json") @@ -80,7 +81,7 @@ async def test_job_data_field_updates(tmp_path, monkeypatch): from lab.job import Job - job = await Job.create("2") + job = await Job.create("2", "test_exp") # Test updating job data fields directly await job.update_job_data_field("completion_status", "success") @@ -91,3 +92,97 @@ async def test_job_data_field_updates(tmp_path, monkeypatch): assert data["job_data"]["completion_status"] == "success" assert data["job_data"]["completion_details"] == "ok" assert data["job_data"]["score"] == {"acc": 1} + + +@pytest.mark.asyncio +async def test_job_create_and_get_with_experiment(tmp_path, monkeypatch): + for mod in list(importlib.sys.modules.keys()): + if mod.startswith("lab."): + importlib.sys.modules.pop(mod) + + ws = tmp_path / "ws" + ws.mkdir() + monkeypatch.setenv("TFL_WORKSPACE_DIR", str(ws)) + + from lab.job import Job + + job = await Job.create("job-uuid-1", "exp1") + job_dir = await job.get_dir() + assert "experiments/exp1/jobs/job-uuid-1" in job_dir + assert os.path.isdir(job_dir) + + # index.json must have experiment_id and created_at + index_file = os.path.join(job_dir, "index.json") + with open(index_file) as f: + data = json.load(f) + assert data["experiment_id"] == "exp1" + assert "created_at" in data + + # Job.get also requires experiment_id + job2 = await Job.get("job-uuid-1", "exp1") + assert isinstance(job2, Job) + assert job2.experiment_id == "exp1" + + +@pytest.mark.asyncio +async def test_job_get_missing_raises(tmp_path, monkeypatch): + for mod in list(importlib.sys.modules.keys()): + if mod.startswith("lab."): + importlib.sys.modules.pop(mod) + + ws = tmp_path / "ws" + ws.mkdir() + monkeypatch.setenv("TFL_WORKSPACE_DIR", str(ws)) + + from lab.job import Job + + with pytest.raises(FileNotFoundError): + await Job.get("nonexistent", "exp1") + + +@pytest.mark.asyncio +async def test_count_running_jobs_cross_experiment(tmp_path, monkeypatch): + for mod in list(importlib.sys.modules.keys()): + if mod.startswith("lab."): + importlib.sys.modules.pop(mod) + + ws = tmp_path / "ws" + ws.mkdir() + monkeypatch.setenv("TFL_WORKSPACE_DIR", str(ws)) + + from lab.job import Job + + j1 = await Job.create("job-a", "exp1") + await j1.update_status("RUNNING") + j2 = await Job.create("job-b", "exp2") + await j2.update_status("RUNNING") + j3 = await Job.create("job-c", "exp1") + await j3.update_status("COMPLETE") + + count = await Job.count_running_jobs() + assert count == 2 + + +@pytest.mark.asyncio +async def test_get_next_queued_job_sorted_by_created_at(tmp_path, monkeypatch): + for mod in list(importlib.sys.modules.keys()): + if mod.startswith("lab."): + importlib.sys.modules.pop(mod) + + ws = tmp_path / "ws" + ws.mkdir() + monkeypatch.setenv("TFL_WORKSPACE_DIR", str(ws)) + + import asyncio + + from lab.job import Job + + j1 = await Job.create("job-first", "exp1") + await j1.update_status("QUEUED") + await asyncio.sleep(0.01) # ensure different created_at + j2 = await Job.create("job-second", "exp2") + await j2.update_status("QUEUED") + + result = await Job.get_next_queued_job() + assert result is not None + assert result["id"] == "job-first" diff --git a/src/renderer/components/Experiment/Interactive/InteractIframeModal.tsx b/src/renderer/components/Experiment/Interactive/InteractIframeModal.tsx index d2fb8d8a3..f6be23772 100644 --- a/src/renderer/components/Experiment/Interactive/InteractIframeModal.tsx +++ b/src/renderer/components/Experiment/Interactive/InteractIframeModal.tsx @@ -23,7 +23,7 @@ import { fetcher } from 'renderer/lib/transformerlab-api-sdk'; import { useExperimentInfo } from 'renderer/lib/ExperimentInfoContext'; interface InteractModalProps { - jobId: number; + jobId: string | number | null; open: boolean; onClose: () => void; } @@ -38,7 +38,7 @@ export default function InteractIframeModal({ const { experimentInfo } = useExperimentInfo(); const url = React.useMemo(() => { - if (!open || !experimentInfo?.id) return null; + if (!open || !experimentInfo?.id || jobId === null) return null; return chatAPI.Endpoints.Experiment.GetTunnelInfo( experimentInfo.id, String(jobId), diff --git a/src/renderer/components/Experiment/Interactive/InteractiveJobCard.tsx b/src/renderer/components/Experiment/Interactive/InteractiveJobCard.tsx index 6d55b984b..0b9f86196 100644 --- a/src/renderer/components/Experiment/Interactive/InteractiveJobCard.tsx +++ b/src/renderer/components/Experiment/Interactive/InteractiveJobCard.tsx @@ -173,15 +173,15 @@ export default function InteractiveJobCard({ job }: InteractiveJobCardProps) { jobData.cluster_name || jobData.template_name || (isPlaceholder ? '' : `Job ${job.id}`); - const jobIdNum = parseInt(job.id, 10); + const jobIdValue = job?.id == null ? null : String(job.id); const tunnelInfoUrl = React.useMemo(() => { if (!isInteractive || !experimentInfo?.id) return null; return chatAPI.Endpoints.Experiment.GetTunnelInfo( experimentInfo.id, - String(jobIdNum), + String(jobIdValue), ); - }, [isInteractive, experimentInfo?.id, jobIdNum]); + }, [isInteractive, experimentInfo?.id, jobIdValue]); const { data: tunnelData } = useSWR(tunnelInfoUrl, fetcher, { refreshInterval: 3000, @@ -276,18 +276,18 @@ export default function InteractiveJobCard({ job }: InteractiveJobCardProps) { )} setConnectOpen(false)} embeddedOutput={ } /> setInteractOpen(false)} /> diff --git a/src/renderer/components/Experiment/Tasks/EmbeddableStreamingOutput.tsx b/src/renderer/components/Experiment/Tasks/EmbeddableStreamingOutput.tsx index 02f357d3d..f25439485 100644 --- a/src/renderer/components/Experiment/Tasks/EmbeddableStreamingOutput.tsx +++ b/src/renderer/components/Experiment/Tasks/EmbeddableStreamingOutput.tsx @@ -206,7 +206,7 @@ const TAB_OPTIONS: { value: TabValue; label: string }[] = [ ]; export interface EmbeddableStreamingOutputProps { - jobId: number; + jobId: string | number | null; /** Which tabs to show, in order. e.g. ['output', 'provider'] or ['provider'] for interactive tasks. */ tabs?: TabValue[]; /** Current job status string (e.g. 'RUNNING', 'COMPLETE'). Passed from the parent to avoid extra polling. */ @@ -237,7 +237,15 @@ export default function EmbeddableStreamingOutput({ }, [jobId, tabsKey]); const providerLogsUrl = useMemo(() => { - if (jobId === -1 || !experimentInfo?.id) { + if ( + !experimentInfo?.id || + jobId === null || + jobId === '' || + jobId === -1 || + jobId === '-1' || + jobId === 'NaN' || + (typeof jobId === 'number' && Number.isNaN(jobId)) + ) { return null; } return chatAPI.Endpoints.Experiment.GetProviderLogs( @@ -309,7 +317,7 @@ export default function EmbeddableStreamingOutput({ resetProviderCountdown, ]); - if (jobId === -1 || !experimentInfo) { + if (!jobId || !experimentInfo) { return null; } diff --git a/src/renderer/components/Experiment/Tasks/FileBrowserModal.tsx b/src/renderer/components/Experiment/Tasks/FileBrowserModal.tsx index 716a01529..aa2999374 100644 --- a/src/renderer/components/Experiment/Tasks/FileBrowserModal.tsx +++ b/src/renderer/components/Experiment/Tasks/FileBrowserModal.tsx @@ -31,9 +31,11 @@ interface FileEntry { type JobModeProps = { mode: 'job'; - jobId: number; + jobId: string; taskId?: never; taskName?: never; + galleryId?: never; + galleryTitle?: never; }; type TaskModeProps = { @@ -41,6 +43,8 @@ type TaskModeProps = { taskId: string; taskName?: string | null; jobId?: never; + galleryId?: never; + galleryTitle?: never; }; type TeamGalleryModeProps = { @@ -87,7 +91,7 @@ export default function FileBrowserModal({ setLoading(true); try { if (mode === 'job') { - if (jobId === -1) { + if (!jobId) { setFiles([]); return; } diff --git a/src/renderer/components/Experiment/Tasks/InteractiveModal.tsx b/src/renderer/components/Experiment/Tasks/InteractiveModal.tsx index 99c4f7ca3..e89be05f8 100644 --- a/src/renderer/components/Experiment/Tasks/InteractiveModal.tsx +++ b/src/renderer/components/Experiment/Tasks/InteractiveModal.tsx @@ -51,8 +51,8 @@ interface PortDef { } type InteractiveModalProps = { - jobId: number; - setJobId: (jobId: number) => void; + jobId: string | null; + setJobId: (jobId: string | null) => void; embeddedOutput?: React.ReactNode; }; @@ -482,7 +482,7 @@ export default function InteractiveModal({ const { team } = useAuth(); const url = React.useMemo(() => { - if (jobId === -1 || !experimentInfo?.id) { + if (!jobId || !experimentInfo?.id) { return null; } return chatAPI.Endpoints.Experiment.GetTunnelInfo( @@ -505,10 +505,10 @@ export default function InteractiveModal({ }); const handleClose = () => { - setJobId(-1); + setJobId(null); }; - if (jobId === -1 || !experimentInfo) { + if (!jobId || !experimentInfo) { return null; } @@ -544,7 +544,7 @@ export default function InteractiveModal({ } return ( - + ; onDeleteJob?: (jobId: string) => void; onViewOutput?: (jobId: string) => void; - onViewTensorboard?: (jobId: string) => void; onViewCheckpoints?: (jobId: string) => void; onViewArtifacts?: (jobId: string) => void; onViewEvalImages?: (jobId: string) => void; @@ -60,7 +59,6 @@ const JobsList: React.FC = ({ launchProgressByJobId, onDeleteJob, onViewOutput, - onViewTensorboard, onViewCheckpoints, onViewArtifacts, onViewEvalImages, @@ -202,139 +200,60 @@ const JobsList: React.FC = ({ {jobs?.length > 0 ? ( - jobs?.map((job) => ( - - - - - + + + + - - )) + + + )} + + + + ); + }) ) : (
- {selectMode && - job?.job_data?.eval_results && - Array.isArray(job.job_data.eval_results) && - job.job_data.eval_results.length > 0 && ( - onToggleJobSelected?.(String(job.id))} - sx={{ mr: 1 }} - /> - )} - {job.id} - - {formatJobConfig(job)} - - - - - {job?.placeholder && ( - - )} - {job?.job_data?.tensorboard_output_dir && ( - - )} - - {job?.job_data?.wandb_run_url && ( - - )} - - {(job?.job_data?.trackio_db_artifact_path || - job?.job_data?.trackio_project_name) && ( - - )} - - - {job?.job_data?.eval_images_dir && ( - - )} - {job?.job_data?.eval_results && + jobs?.map((job) => { + const fullJobId = String(job?.id ?? ''); + const displayJobId = + String(job?.short_id ?? '').trim() || fullJobId.slice(0, 8); + return ( +
+ {selectMode && + job?.job_data?.eval_results && Array.isArray(job.job_data.eval_results) && job.job_data.eval_results.length > 0 && ( + onToggleJobSelected?.(String(job.id))} + sx={{ mr: 1 }} + /> + )} + {displayJobId} + + {formatJobConfig(job)} + + + + + {job?.placeholder && ( + + )} + {job?.job_data?.wandb_run_url && ( )} - {(job?.job_data?.artifacts || - job?.job_data?.artifacts_dir || - job?.job_data?.generated_datasets || - job?.job_data?.models || - job?.job_data?.has_profiling) && ( - - } - > - - Artifacts - - - - {(job?.job_data?.artifacts || - job?.job_data?.artifacts_dir) && ( - onViewArtifacts?.(job?.id)}> - View Artifacts - - )} - {job?.job_data?.has_profiling && ( - onViewProfiling?.(job?.id)}> - View Profiling - - )} - {job?.job_data?.generated_datasets && ( - onViewJobDatasets?.(job?.id)} - > - View Datasets - - )} - {job?.job_data?.models && ( - onViewJobModels?.(job?.id)}> - View Models - - )} - - - )} - {(job?.type === 'SWEEP' || job?.job_data?.sweep_parent) && - job?.status === 'COMPLETE' && ( + + {(job?.job_data?.trackio_db_artifact_path || + job?.job_data?.trackio_project_name) && ( )} - {job?.job_data?.sweep_output_file && ( + - )} - {job?.status === 'INTERACTIVE' && - job?.job_data?.subtype === 'interactive' && ( - <> + {job?.job_data?.eval_images_dir && ( + + )} + {job?.job_data?.eval_results && + Array.isArray(job.job_data.eval_results) && + job.job_data.eval_results.length > 0 && ( + )} + {(job?.job_data?.artifacts || + job?.job_data?.artifacts_dir || + job?.job_data?.generated_datasets || + job?.job_data?.models || + job?.job_data?.has_profiling) && ( + + } + > + + Artifacts + + + + {(job?.job_data?.artifacts || + job?.job_data?.artifacts_dir) && ( + onViewArtifacts?.(job?.id)} + > + View Artifacts + + )} + {job?.job_data?.has_profiling && ( + onViewProfiling?.(job?.id)} + > + View Profiling + + )} + {job?.job_data?.generated_datasets && ( + onViewJobDatasets?.(job?.id)} + > + View Datasets + + )} + {job?.job_data?.models && ( + onViewJobModels?.(job?.id)} + > + View Models + + )} + + + )} + {(job?.type === 'SWEEP' || job?.job_data?.sweep_parent) && + job?.status === 'COMPLETE' && ( - + )} + {job?.job_data?.sweep_output_file && ( + )} - {job?.job_data?.checkpoints && ( - + + + )} + {job?.job_data?.checkpoints && ( + - )} - {!job?.placeholder && ( - + )} + {!job?.placeholder && ( + + )} + {!job?.placeholder && ( + { + if (!isTerminalJobStatus(job?.status)) { + return; + } + onDeleteJob?.(job.id); }} > - Files - - - )} - {!job?.placeholder && ( - { - if (!isTerminalJobStatus(job?.status)) { - return; - } - onDeleteJob?.(job.id); - }} - > - - - )} - -
void, wait: number) => { }; interface PollingOutputTerminalProps { - jobId: number; + jobId: string | number; experimentId: string; lineAnimationDelay?: number; initialMessage?: string; @@ -49,7 +49,7 @@ const PollingOutputTerminal: React.FC = ({ // Fetch the output file content directly using the Tasks-specific endpoint const outputEndpoint = chatAPI.Endpoints.Experiment.GetTasksOutputFromJob( experimentId, - jobId.toString(), + String(jobId), ); const { diff --git a/src/renderer/components/Experiment/Tasks/ProfilingReport.tsx b/src/renderer/components/Experiment/Tasks/ProfilingReport.tsx index eb2fc751f..65e0c3a8b 100644 --- a/src/renderer/components/Experiment/Tasks/ProfilingReport.tsx +++ b/src/renderer/components/Experiment/Tasks/ProfilingReport.tsx @@ -78,14 +78,14 @@ function StatCard({ label, peak, avg }: StatCardProps) { } interface ProfilingReportProps { - jobId: number; + jobId: string | null; } export default function ProfilingReport({ jobId }: ProfilingReportProps) { const { experimentInfo } = useExperimentInfo(); const url = - jobId !== -1 && experimentInfo?.id + jobId && experimentInfo?.id ? chatAPI.Endpoints.Experiment.GetProfilingReport( experimentInfo.id, String(jobId), diff --git a/src/renderer/components/Experiment/Tasks/Tasks.tsx b/src/renderer/components/Experiment/Tasks/Tasks.tsx index c9e202803..ea6a43d99 100644 --- a/src/renderer/components/Experiment/Tasks/Tasks.tsx +++ b/src/renderer/components/Experiment/Tasks/Tasks.tsx @@ -49,30 +49,50 @@ export default function Tasks({ subtype }: { subtype?: string }) { const [queueModalOpen, setQueueModalOpen] = useState(false); const [taskBeingQueued, setTaskBeingQueued] = useState(null); const [isSubmitting, setIsSubmitting] = useState(false); - const [viewOutputFromJob, setViewOutputFromJob] = useState(-1); - const [currentTensorboardForModal, setCurrentTensorboardForModal] = - useState(-1); - const [viewCheckpointsFromJob, setViewCheckpointsFromJob] = useState(-1); - const [viewArtifactsFromJob, setViewArtifactsFromJob] = useState(-1); - const [viewProfilingFromJob, setViewProfilingFromJob] = useState(-1); - const [viewEvalImagesFromJob, setViewEvalImagesFromJob] = useState(-1); + const [viewOutputFromJob, setViewOutputFromJob] = useState( + null, + ); + const [viewCheckpointsFromJob, setViewCheckpointsFromJob] = useState< + string | null + >(null); + const [viewArtifactsFromJob, setViewArtifactsFromJob] = useState< + string | null + >(null); + const [viewProfilingFromJob, setViewProfilingFromJob] = useState< + string | null + >(null); + const [viewEvalImagesFromJob, setViewEvalImagesFromJob] = useState< + string | null + >(null); const [viewOutputFromSweepJob, setViewOutputFromSweepJob] = useState(false); - const [viewSweepResultsFromJob, setViewSweepResultsFromJob] = useState(-1); - const [viewEvalResultsFromJob, setViewEvalResultsFromJob] = useState(-1); - const [interactiveJobForModal, setInteractiveJobForModal] = useState(-1); - const [viewJobDatasetsFromJob, setViewJobDatasetsFromJob] = useState(-1); - const [viewJobModelsFromJob, setViewJobModelsFromJob] = useState(-1); + const [viewSweepResultsFromJob, setViewSweepResultsFromJob] = useState< + string | null + >(null); + const [viewEvalResultsFromJob, setViewEvalResultsFromJob] = useState< + string | null + >(null); + const [interactiveJobForModal, setInteractiveJobForModal] = useState< + string | null + >(null); + const [viewJobDatasetsFromJob, setViewJobDatasetsFromJob] = useState< + string | null + >(null); + const [viewJobModelsFromJob, setViewJobModelsFromJob] = useState< + string | null + >(null); const [previewDatasetModal, setPreviewDatasetModal] = useState<{ open: boolean; datasetId: string | null; }>({ open: false, datasetId: null }); const [trackioJobIdForModal, setTrackioJobIdForModal] = useState< - number | null + string | null >(null); - const [compareEvalJobIds, setCompareEvalJobIds] = useState([]); + const [compareEvalJobIds, setCompareEvalJobIds] = useState([]); const [isCompareSelectMode, setIsCompareSelectMode] = useState(false); const [compareEvalModalOpen, setCompareEvalModalOpen] = useState(false); - const [viewFileBrowserFromJob, setViewFileBrowserFromJob] = useState(-1); + const [viewFileBrowserFromJob, setViewFileBrowserFromJob] = useState< + string | null + >(null); const [viewTaskFilesFromTask, setViewTaskFilesFromTask] = useState<{ id: string | null; name?: string | null; @@ -167,16 +187,18 @@ export default function Tasks({ subtype }: { subtype?: string }) { // Listen for custom event to open job output modal from interactive modals useEffect(() => { const handleOpenJobOutput = (e: Event) => { - const customEvent = e as CustomEvent<{ jobId: number }>; - const jobId = customEvent.detail?.jobId; - if (jobId && jobId !== -1) { + const customEvent = e as CustomEvent<{ jobId?: unknown }>; + const rawJobId = customEvent.detail?.jobId; + const jobIdStr = + rawJobId === null || rawJobId === undefined ? '' : String(rawJobId); + if (jobIdStr && jobIdStr !== '-1' && jobIdStr !== 'NaN') { // Close the interactive modal first - setInteractiveJobForModal(-1); + setInteractiveJobForModal(null); // Wait for the modal to close (MUI modals have transition animations) // Use a longer delay to ensure the interactive modal fully closes // before opening the output modal to avoid z-index/stacking issues setTimeout(() => { - setViewOutputFromJob(jobId); + setViewOutputFromJob(jobIdStr); }, 300); // 300ms should be enough for modal close animation } }; @@ -1296,51 +1318,63 @@ export default function Tasks({ subtype }: { subtype?: string }) { jobs={jobsWithPlaceholders as any} launchProgressByJobId={launchProgressByJobId} onDeleteJob={handleDeleteJob} - onViewOutput={(jobId) => setViewOutputFromJob(parseInt(jobId))} - onViewTensorboard={(jobId) => - setCurrentTensorboardForModal(parseInt(jobId)) - } + onViewOutput={(jobId) => { + const jobIdStr = + jobId === null || jobId === undefined ? '' : String(jobId); + if (!jobIdStr || jobIdStr === '-1' || jobIdStr === 'NaN') return; + setViewOutputFromJob(jobIdStr); + }} onViewCheckpoints={(jobId) => - setViewCheckpointsFromJob(parseInt(jobId)) + setViewCheckpointsFromJob(jobId && jobId !== 'NaN' ? jobId : null) + } + onViewArtifacts={(jobId) => + setViewArtifactsFromJob(jobId && jobId !== 'NaN' ? jobId : null) + } + onViewProfiling={(jobId) => + setViewProfilingFromJob(jobId && jobId !== 'NaN' ? jobId : null) } - onViewArtifacts={(jobId) => setViewArtifactsFromJob(parseInt(jobId))} - onViewProfiling={(jobId) => setViewProfilingFromJob(parseInt(jobId))} onViewEvalImages={(jobId) => - setViewEvalImagesFromJob(parseInt(jobId)) + setViewEvalImagesFromJob(jobId && jobId !== 'NaN' ? jobId : null) } onViewEvalResults={(jobId) => - setViewEvalResultsFromJob(parseInt(jobId)) + setViewEvalResultsFromJob(jobId && jobId !== 'NaN' ? jobId : null) } onViewGeneratedDataset={(jobId, datasetId) => { setPreviewDatasetModal({ open: true, datasetId }); }} onViewJobDatasets={(jobId) => - setViewJobDatasetsFromJob(parseInt(jobId)) + setViewJobDatasetsFromJob(jobId && jobId !== 'NaN' ? jobId : null) } - onViewJobModels={(jobId) => setViewJobModelsFromJob(parseInt(jobId))} - onViewFileBrowser={(jobId) => - setViewFileBrowserFromJob(parseInt(jobId)) + onViewJobModels={(jobId) => + setViewJobModelsFromJob(jobId && jobId !== 'NaN' ? jobId : null) } + onViewFileBrowser={(jobId) => { + if (jobId == null || jobId === '') return; + setViewFileBrowserFromJob(String(jobId)); + }} onViewSweepOutput={(jobId) => { setViewOutputFromSweepJob(true); - setViewOutputFromJob(parseInt(jobId)); + const jobIdStr = + jobId === null || jobId === undefined ? '' : String(jobId); + if (!jobIdStr || jobIdStr === '-1' || jobIdStr === 'NaN') return; + setViewOutputFromJob(jobIdStr); }} onViewSweepResults={(jobId) => { - setViewSweepResultsFromJob(parseInt(jobId)); + setViewSweepResultsFromJob(jobId && jobId !== 'NaN' ? jobId : null); }} onViewInteractive={(jobId) => - setInteractiveJobForModal(parseInt(jobId)) + setInteractiveJobForModal(jobId && jobId !== 'NaN' ? jobId : null) } onViewTrackio={(jobId) => - setTrackioJobIdForModal(parseInt(jobId, 10)) + setTrackioJobIdForModal(jobId && jobId !== 'NaN' ? jobId : null) } loading={jobsIsLoading} selectMode={isCompareSelectMode} selectedJobIds={compareEvalJobIds.map((id) => String(id))} onToggleJobSelected={(jobId) => { setCompareEvalJobIds((prev) => { - const id = parseInt(jobId, 10); - if (Number.isNaN(id)) return prev; + const id = jobId; + if (!id || id === 'NaN') return prev; if (prev.includes(id)) { return prev.filter((existing) => existing !== id); } @@ -1354,33 +1388,34 @@ export default function Tasks({ subtype }: { subtype?: string }) { setViewSweepResultsFromJob(jobId)} + setJobId={(jobId: string | null) => setViewSweepResultsFromJob(jobId)} /> setViewOutputFromJob(jobId)} + setJobId={(jobId: string | null) => setViewOutputFromJob(jobId)} jobStatus={ - jobs?.find((j: any) => j.id === viewOutputFromJob)?.status || '' + jobs?.find((j: any) => String(j.id) === viewOutputFromJob)?.status || + '' } /> setViewArtifactsFromJob(-1)} + open={viewArtifactsFromJob !== null} + onClose={() => setViewArtifactsFromJob(null)} jobId={viewArtifactsFromJob} /> setViewProfilingFromJob(-1)} + open={viewProfilingFromJob !== null} + onClose={() => setViewProfilingFromJob(null)} jobId={viewProfilingFromJob} /> setViewCheckpointsFromJob(-1)} + open={viewCheckpointsFromJob !== null} + onClose={() => setViewCheckpointsFromJob(null)} jobId={viewCheckpointsFromJob} /> setViewEvalResultsFromJob(-1)} + open={viewEvalResultsFromJob !== null} + onClose={() => setViewEvalResultsFromJob(null)} jobId={viewEvalResultsFromJob} /> setInteractiveJobForModal(jobId)} + setJobId={(jobId: string | null) => setInteractiveJobForModal(jobId)} /> - setViewJobDatasetsFromJob(-1)} - jobId={viewJobDatasetsFromJob} - /> - setViewJobModelsFromJob(-1)} - jobId={viewJobModelsFromJob} - /> + {viewJobDatasetsFromJob !== null && ( + setViewJobDatasetsFromJob(null)} + jobId={viewJobDatasetsFromJob} + /> + )} + {viewJobModelsFromJob !== null && ( + setViewJobModelsFromJob(null)} + jobId={viewJobModelsFromJob} + /> + )} setViewFileBrowserFromJob(-1)} - jobId={viewFileBrowserFromJob} + open={viewFileBrowserFromJob !== null} + onClose={() => setViewFileBrowserFromJob(null)} + jobId={viewFileBrowserFromJob ?? ''} /> void; } diff --git a/src/renderer/components/Experiment/Tasks/ViewArtifactsModal.tsx b/src/renderer/components/Experiment/Tasks/ViewArtifactsModal.tsx index 603ea06a0..9ad84cfc8 100644 --- a/src/renderer/components/Experiment/Tasks/ViewArtifactsModal.tsx +++ b/src/renderer/components/Experiment/Tasks/ViewArtifactsModal.tsx @@ -23,7 +23,7 @@ import Model3DViewer from 'renderer/components/Shared/Model3DViewer'; interface ViewArtifactsModalProps { open: boolean; onClose: () => void; - jobId: number | string; + jobId: number | string | null; } interface Artifact { @@ -110,6 +110,7 @@ export default function ViewArtifactsModal({ }; const handleViewArtifact = async (artifact: Artifact) => { + if (!jobId) return; setSelectedArtifact(artifact); setPreviewLoading(true); setPreviewError(null); @@ -191,6 +192,7 @@ export default function ViewArtifactsModal({ }; const handleDownloadArtifact = async (artifact: Artifact) => { + if (!jobId) return; try { const downloadUrl = getAPIFullPath('jobs', ['getArtifact'], { experimentId: experimentInfo?.id, @@ -224,6 +226,7 @@ export default function ViewArtifactsModal({ }; const handleDownloadAllArtifacts = async () => { + if (!jobId) return; try { setIsDownloading(true); const downloadUrl = getAPIFullPath('jobs', ['downloadAllArtifacts'], { diff --git a/src/renderer/components/Experiment/Tasks/ViewCheckpointsModal.tsx b/src/renderer/components/Experiment/Tasks/ViewCheckpointsModal.tsx index a02640d81..59ab9f7b8 100644 --- a/src/renderer/components/Experiment/Tasks/ViewCheckpointsModal.tsx +++ b/src/renderer/components/Experiment/Tasks/ViewCheckpointsModal.tsx @@ -16,7 +16,17 @@ import { useExperimentInfo } from 'renderer/lib/ExperimentInfoContext'; import { useNotification } from 'renderer/components/Shared/NotificationSystem'; import { formatBytes } from 'renderer/lib/utils'; -export default function ViewCheckpointsModal({ open, onClose, jobId }) { +interface ViewCheckpointsModalProps { + open: boolean; + onClose: () => void; + jobId: string | null; +} + +export default function ViewCheckpointsModal({ + open, + onClose, + jobId, +}: ViewCheckpointsModalProps) { const { experimentInfo } = useExperimentInfo(); const { addNotification } = useNotification(); const [resumingCheckpoint, setResumingCheckpoint] = useState( @@ -36,6 +46,7 @@ export default function ViewCheckpointsModal({ open, onClose, jobId }) { }); return; } + if (!jobId) return; setResumingCheckpoint(checkpoint.filename); try { diff --git a/src/renderer/components/Experiment/Tasks/ViewEvalResultsModal.tsx b/src/renderer/components/Experiment/Tasks/ViewEvalResultsModal.tsx index 7dd511caf..cb870fd5b 100644 --- a/src/renderer/components/Experiment/Tasks/ViewEvalResultsModal.tsx +++ b/src/renderer/components/Experiment/Tasks/ViewEvalResultsModal.tsx @@ -51,7 +51,7 @@ const ViewEvalResultsModal = ({ }: { open: boolean; onClose: () => void; - jobId: number | string; + jobId: number | string | null; }) => { const { experimentInfo } = useExperimentInfo(); const [selectedFileIndex, setSelectedFileIndex] = useState(0); diff --git a/src/renderer/components/Experiment/Tasks/ViewJobDatasetsModal.tsx b/src/renderer/components/Experiment/Tasks/ViewJobDatasetsModal.tsx index 268eaad81..6e3419531 100644 --- a/src/renderer/components/Experiment/Tasks/ViewJobDatasetsModal.tsx +++ b/src/renderer/components/Experiment/Tasks/ViewJobDatasetsModal.tsx @@ -26,7 +26,7 @@ import SaveToRegistryDialog, { SaveVersionInfo } from './SaveToRegistryDialog'; interface ViewJobDatasetsModalProps { open: boolean; onClose: () => void; - jobId: number | string; + jobId: number | string | null; } interface Dataset { diff --git a/src/renderer/components/Experiment/Tasks/ViewJobModelsModal.tsx b/src/renderer/components/Experiment/Tasks/ViewJobModelsModal.tsx index 1f7773077..8ae8dcb81 100644 --- a/src/renderer/components/Experiment/Tasks/ViewJobModelsModal.tsx +++ b/src/renderer/components/Experiment/Tasks/ViewJobModelsModal.tsx @@ -28,7 +28,7 @@ import SaveToRegistryDialog, { SaveVersionInfo } from './SaveToRegistryDialog'; interface ViewJobModelsModalProps { open: boolean; onClose: () => void; - jobId: number | string; + jobId: number | string | null; } interface Model { diff --git a/src/renderer/components/Experiment/Tasks/ViewOutputModalStreaming.tsx b/src/renderer/components/Experiment/Tasks/ViewOutputModalStreaming.tsx index 12a30c545..df01f5f46 100644 --- a/src/renderer/components/Experiment/Tasks/ViewOutputModalStreaming.tsx +++ b/src/renderer/components/Experiment/Tasks/ViewOutputModalStreaming.tsx @@ -8,8 +8,8 @@ const TAB_LABELS: Record = { }; interface ViewOutputModalStreamingProps { - jobId: number; - setJobId: (jobId: number) => void; + jobId: string | null; + setJobId: (jobId: string | null) => void; /** Which tabs to show, in order. e.g. ['output', 'provider'] or ['provider'] for interactive tasks. */ tabs?: ('output' | 'provider')[]; /** Current job status string (e.g. 'RUNNING', 'COMPLETE'). */ @@ -22,9 +22,7 @@ function ViewOutputModalStreaming({ tabs = ['output', 'provider'], jobStatus = '', }: ViewOutputModalStreamingProps) { - if (jobId === -1) { - return null; - } + if (!jobId) return null; const showTabList = tabs.length > 1; const title = showTabList @@ -33,9 +31,9 @@ function ViewOutputModalStreaming({ return ( { - setJobId(-1); + setJobId(null); }} > void; - jobId: number; + jobId: string | null; } export default function ViewProfilingModal({ @@ -30,7 +30,7 @@ export default function ViewProfilingModal({ Profiling – Job {jobId} - {jobId !== -1 && } + {jobId && } diff --git a/src/renderer/components/Experiment/Tasks/ViewSweepResultsModal.tsx b/src/renderer/components/Experiment/Tasks/ViewSweepResultsModal.tsx index 9bc6dbb13..ec0237b15 100644 --- a/src/renderer/components/Experiment/Tasks/ViewSweepResultsModal.tsx +++ b/src/renderer/components/Experiment/Tasks/ViewSweepResultsModal.tsx @@ -17,8 +17,8 @@ import { useExperimentInfo } from 'renderer/lib/ExperimentInfoContext'; import { fetcher } from 'renderer/lib/transformerlab-api-sdk'; interface ViewSweepResultsModalProps { - jobId: number; - setJobId: (jobId: number) => void; + jobId: string | null; + setJobId: (jobId: string | null) => void; } export default function ViewSweepResultsModal({ @@ -31,7 +31,7 @@ export default function ViewSweepResultsModal({ // Always call hooks at the top level - conditionally enable the query instead const { data, error, isLoading } = useSWR( - jobId !== -1 && experimentInfo + jobId && experimentInfo ? chatAPI.Endpoints.ComputeProvider.GetSweepResults(String(jobId)) : null, fetcher, @@ -40,7 +40,7 @@ export default function ViewSweepResultsModal({ }, ); - if (jobId === -1 || !experimentInfo) { + if (!jobId || !experimentInfo) { return null; } @@ -260,7 +260,7 @@ export default function ViewSweepResultsModal({ { - setJobId(-1); + setJobId(null); }} >