diff --git a/.github/workflows/pytest-sdk.yml b/.github/workflows/pytest-sdk.yml index 178c0505a..b4e7464f3 100644 --- a/.github/workflows/pytest-sdk.yml +++ b/.github/workflows/pytest-sdk.yml @@ -41,7 +41,5 @@ jobs: uv pip install -e . - name: Run SDK tests (uv) - env: - PYTEST_DISABLE_PLUGIN_AUTOLOAD: "1" run: | uv run pytest diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index d04b7a472..09221611e 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -51,7 +51,6 @@ jobs: - name: Test with pytest run: | pytest --cov=transformerlab --cov-branch --cov-report=xml -k 'not test_teams' - - name: Upload results to Codecov uses: codecov/codecov-action@v5 with: diff --git a/api/api.py b/api/api.py index 266b8cd2a..bb81896db 100644 --- a/api/api.py +++ b/api/api.py @@ -104,9 +104,6 @@ # used internally to set constants that are shared between separate processes. They are not meant to be # to be overriden by the user. os.environ["_TFL_SOURCE_CODE_DIR"] = dirs.TFL_SOURCE_CODE_DIR -# The temporary image directory for transformerlab (default; per-request overrides computed in routes) -temp_image_dir = storage.join(get_workspace_dir(), "temp", "images") -os.environ["TLAB_TEMP_IMAGE_DIR"] = str(temp_image_dir) @asynccontextmanager @@ -114,18 +111,26 @@ async def lifespan(app: FastAPI): """Docs on lifespan events: https://fastapi.tiangolo.com/advanced/events/""" # Do the following at API Startup: print_launch_message() + # Initialize directories early + from transformerlab.shared import dirs as shared_dirs + + await shared_dirs.initialize_dirs() + + # Set the temporary image directory for transformerlab (computed async) + temp_image_dir = storage.join(await get_workspace_dir(), "temp", "images") + os.environ["TLAB_TEMP_IMAGE_DIR"] = str(temp_image_dir) # Validate cloud credentials early - fail fast if missing validate_cloud_credentials() - galleries.update_gallery_cache() + await galleries.update_gallery_cache() spawn_fastchat_controller_subprocess() await db.init() # This now runs Alembic migrations internally print("✅ SEED DATA") # Initialize experiments - seed_default_experiments() + await seed_default_experiments() # Seed default admin user await seed_default_admin_user() # Cancel any running jobs - cancel_in_progress_jobs() + await cancel_in_progress_jobs() # Create buckets for all existing teams if TFL_API_STORAGE_URI is enabled if os.getenv("TFL_API_STORAGE_URI"): @@ -362,7 +367,7 @@ async def server_worker_start( # then we check to see if we are an experiment elif experiment_id is not None: try: - experiment = experiment_get(experiment_id) + experiment = await experiment_get(experiment_id) experiment_config = ( experiment["config"] if isinstance(experiment["config"], dict) @@ -394,7 +399,7 @@ async def server_worker_start( model_architecture = model_architecture plugin_name = inference_engine - plugin_location = lab_dirs.plugin_dir_by_name(plugin_name) + plugin_location = await lab_dirs.plugin_dir_by_name(plugin_name) model = model_name if model_filename is not None and model_filename != "": @@ -402,7 +407,7 @@ async def server_worker_start( if adaptor != "": # Resolve per-request workspace if multitenant - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() adaptor = f"{workspace_dir}/adaptors/{secure_filename(model)}/{adaptor}" params = [ @@ -419,14 +424,14 @@ async def server_worker_start( json.dumps(inference_params), ] - job_id = job_create(type="LOAD_MODEL", status="STARTED", job_data="{}", 
experiment_id=experiment_id) + job_id = await job_create(type="LOAD_MODEL", status="STARTED", job_data="{}", experiment_id=experiment_id) print("Loading plugin loader instead of default worker") from lab.dirs import get_global_log_path - with storage.open(get_global_log_path(), "a") as global_log: - global_log.write(f"🏃 Loading Inference Server for {model_name} with {inference_params}\n") + async with await storage.open(await get_global_log_path(), "a") as global_log: + await global_log.write(f"🏃 Loading Inference Server for {model_name} with {inference_params}\n") # Pass organization_id as environment variable to subprocess from transformerlab.shared.request_context import get_current_org_id @@ -447,8 +452,8 @@ async def server_worker_start( if exitcode == 99: from lab.dirs import get_global_log_path - with storage.open(get_global_log_path(), "a") as global_log: - global_log.write( + async with await storage.open(await get_global_log_path(), "a") as global_log: + await global_log.write( "GPU (CUDA) Out of Memory: Please try a smaller model or a different inference engine. Restarting the server may free up resources.\n" ) return { @@ -458,20 +463,20 @@ async def server_worker_start( if exitcode is not None and exitcode != 0: from lab.dirs import get_global_log_path - with storage.open(get_global_log_path(), "a") as global_log: - global_log.write(f"Error loading model: {model_name} with exit code {exitcode}\n") - job = job_get(job_id) + async with await storage.open(await get_global_log_path(), "a") as global_log: + await global_log.write(f"Error loading model: {model_name} with exit code {exitcode}\n") + job = await job_get(job_id) error_msg = None if job and job.get("job_data"): error_msg = job["job_data"].get("error_msg") if not error_msg: error_msg = f"Exit code {exitcode}" - job_update_status(job_id, "FAILED", experiment_id=experiment_id, error_msg=error_msg) + await job_update_status(job_id, "FAILED", experiment_id=experiment_id, error_msg=error_msg) return {"status": "error", "message": error_msg} from lab.dirs import get_global_log_path - with storage.open(get_global_log_path(), "a") as global_log: - global_log.write(f"Model loaded successfully: {model_name}\n") + async with await storage.open(await get_global_log_path(), "a") as global_log: + await global_log.write(f"Model loaded successfully: {model_name}\n") return {"status": "success", "job_id": job_id} @@ -630,7 +635,9 @@ def run(): ) if args.https: - cert_path, key_path = ensure_persistent_self_signed_cert() + import asyncio + + cert_path, key_path = asyncio.run(ensure_persistent_self_signed_cert()) uvicorn.run( "api:app", host=args.host, port=args.port, log_level="warning", ssl_certfile=cert_path, ssl_keyfile=key_path ) diff --git a/api/pyproject.toml b/api/pyproject.toml index df2054e61..2ab3c8b27 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -33,8 +33,8 @@ dependencies = [ "soundfile==0.13.1", "tensorboardX==2.6.2.2", "timm==1.0.15", - "transformerlab==0.0.58", - "transformerlab-inference==0.2.51", + "transformerlab==0.0.62", + "transformerlab-inference==0.2.52", "transformers==4.57.1", "wandb==0.19.10", "werkzeug==3.1.3", @@ -109,4 +109,4 @@ cpu = [ "tensorboard==2.18.0", "tiktoken==0.8.0", "watchfiles==1.0.4", -] \ No newline at end of file +] diff --git a/api/test/api/conftest.py b/api/test/api/conftest.py index a52f8d96d..c55ab7402 100644 --- a/api/test/api/conftest.py +++ b/api/test/api/conftest.py @@ -5,6 +5,7 @@ # Create test directories before setting environment variables os.makedirs("test/tmp/", 
exist_ok=True) +os.makedirs("test/tmp/webapp", exist_ok=True) # Create webapp directory for static files os.environ["TFL_HOME_DIR"] = "test/tmp/" # Note: TFL_WORKSPACE_DIR is not set so that get_workspace_dir() will use the org-based diff --git a/api/test/api/test_data.py b/api/test/api/test_data.py index c22ec60f1..1ae4b1625 100644 --- a/api/test/api/test_data.py +++ b/api/test/api/test_data.py @@ -1,6 +1,7 @@ import pytest import os import json +import asyncio from io import BytesIO from PIL import Image from pathlib import Path @@ -13,7 +14,7 @@ def cleanup_dataset(dataset_id, client): from transformerlab.shared.shared import slugify import shutil - dataset_dir = dirs.dataset_dir_by_id(slugify(dataset_id)) + dataset_dir = asyncio.run(dirs.dataset_dir_by_id(slugify(dataset_id))) shutil.rmtree(dataset_dir, ignore_errors=True) client.get(f"/data/delete?dataset_id={dataset_id}") @@ -54,7 +55,7 @@ def test_data_info(client): def test_save_metadata(client): source_dataset_id = "source_dataset" new_dataset_id = "destination_dataset" - dataset_dir = dirs.dataset_dir_by_id(slugify(source_dataset_id)) + dataset_dir = asyncio.run(dirs.dataset_dir_by_id(slugify(source_dataset_id))) os.makedirs(dataset_dir, exist_ok=True) # Create dummy JPEG image @@ -96,7 +97,7 @@ def test_save_metadata(client): data = response.json() assert data["status"] == "success" - new_dataset_dir = Path(dirs.dataset_dir_by_id(slugify(new_dataset_id))) + new_dataset_dir = Path(asyncio.run(dirs.dataset_dir_by_id(slugify(new_dataset_id)))) assert new_dataset_dir.exists() cleanup_dataset(source_dataset_id, client) @@ -106,7 +107,7 @@ def test_save_metadata(client): @pytest.mark.skip(reason="Skipping as it contains application-specific logic") def test_edit_with_template(client): dataset_id = "test_dataset" - dataset_dir = dirs.dataset_dir_by_id(slugify(dataset_id)) + dataset_dir = asyncio.run(dirs.dataset_dir_by_id(slugify(dataset_id))) os.makedirs(dataset_dir, exist_ok=True) image_path = os.path.join(dataset_dir, "image.jpg") diff --git a/api/test/api/test_dataset_service.py b/api/test/api/test_dataset_service.py index 870dcc703..ee8fcf27e 100644 --- a/api/test/api/test_dataset_service.py +++ b/api/test/api/test_dataset_service.py @@ -16,7 +16,8 @@ def tmp_dataset_dir(tmp_path: Path) -> Path: return tmp_path -def test_load_local_dataset_filters_index_and_hidden(tmp_dataset_dir: Path, monkeypatch): +@pytest.mark.asyncio +async def test_load_local_dataset_filters_index_and_hidden(tmp_dataset_dir: Path, monkeypatch): # Import inside test to ensure module path resolution for monkeypatching from transformerlab.services import dataset_service @@ -30,7 +31,7 @@ def fake_load_dataset(path=None, data_files=None, streaming=False): monkeypatch.setattr(dataset_service, "load_dataset", fake_load_dataset) - result = dataset_service.load_local_dataset(str(tmp_dataset_dir)) + result = await dataset_service.load_local_dataset(str(tmp_dataset_dir)) assert result == {"ok": True} assert captured["path"] == str(tmp_dataset_dir) @@ -43,7 +44,8 @@ def fake_load_dataset(path=None, data_files=None, streaming=False): assert captured["streaming"] is False -def test_load_local_dataset_uses_explicit_data_files(tmp_path: Path, monkeypatch): +@pytest.mark.asyncio +async def test_load_local_dataset_uses_explicit_data_files(tmp_path: Path, monkeypatch): from transformerlab.services import dataset_service # Explicit files provided (note: function should not re-filter these) @@ -60,7 +62,9 @@ def fake_load_dataset(path=None, data_files=None, streaming=False): 
monkeypatch.setattr(dataset_service, "load_dataset", fake_load_dataset) - result = dataset_service.load_local_dataset(str(tmp_path), data_files=["keep.me", "index.json"], streaming=True) + result = await dataset_service.load_local_dataset( + str(tmp_path), data_files=["keep.me", "index.json"], streaming=True + ) assert result == {"ok": True} assert captured["path"] == str(tmp_path) @@ -72,7 +76,8 @@ def fake_load_dataset(path=None, data_files=None, streaming=False): assert captured["streaming"] is True -def test_load_local_dataset_fallback_when_no_valid_files(tmp_path: Path, monkeypatch): +@pytest.mark.asyncio +async def test_load_local_dataset_fallback_when_no_valid_files(tmp_path: Path, monkeypatch): from transformerlab.services import dataset_service # Only metadata/hidden files present @@ -89,7 +94,7 @@ def fake_load_dataset(path=None, data_files=None, streaming=False): monkeypatch.setattr(dataset_service, "load_dataset", fake_load_dataset) - result = dataset_service.load_local_dataset(str(tmp_path)) + result = await dataset_service.load_local_dataset(str(tmp_path)) assert result == {"ok": True} assert captured["path"] == str(tmp_path) diff --git a/api/test/api/test_diffusion.py b/api/test/api/test_diffusion.py index 4b4e16a69..81ac282fb 100644 --- a/api/test/api/test_diffusion.py +++ b/api/test/api/test_diffusion.py @@ -1,8 +1,26 @@ import pytest -from unittest.mock import patch, MagicMock, mock_open +from unittest.mock import patch, MagicMock, AsyncMock, mock_open import json +class AsyncContextManagerMock: + """Helper class to create async context manager mocks for storage.open""" + + def __init__(self, file_content=""): + self.file_content = file_content + self.file_obj = MagicMock() + self.file_obj.read = AsyncMock(return_value=file_content) + self.file_obj.write = AsyncMock() + self.file_obj.__aiter__ = AsyncMock(return_value=iter([])) + self.file_obj.__anext__ = AsyncMock(side_effect=StopAsyncIteration) + + async def __aenter__(self): + return self.file_obj + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return False + + def test_diffusion_generate_success(client): try: import transformerlab.plugins.image_diffusion.main as diffusion_main @@ -189,7 +207,7 @@ def test_is_valid_diffusion_model_empty_class_name(client): def test_get_history_success(client): """Test getting diffusion history with default parameters""" with ( - patch("transformerlab.routers.experiment.diffusion.load_history") as mock_load_history, + patch("transformerlab.routers.experiment.diffusion.load_history", new_callable=AsyncMock) as mock_load_history, ): mock_history = MagicMock() mock_history.images = [] @@ -211,7 +229,7 @@ def test_get_history_success(client): def test_get_history_with_pagination(client): """Test getting diffusion history with pagination parameters""" with ( - patch("transformerlab.routers.experiment.diffusion.load_history") as mock_load_history, + patch("transformerlab.routers.experiment.diffusion.load_history", new_callable=AsyncMock) as mock_load_history, ): mock_history = MagicMock() mock_history.images = [] @@ -247,7 +265,9 @@ def test_get_history_invalid_offset(client): def test_get_image_by_id_not_found(client): """Test getting a non-existent image by ID""" with ( - patch("transformerlab.routers.experiment.diffusion.find_image_by_id") as mock_find_image, + patch( + "transformerlab.routers.experiment.diffusion.find_image_by_id", new_callable=AsyncMock + ) as mock_find_image, ): mock_find_image.return_value = None @@ -259,7 +279,9 @@ def 
test_get_image_by_id_not_found(client): def test_get_image_by_id_index_out_of_range(client): """Test getting image with index out of range""" with ( - patch("transformerlab.routers.experiment.diffusion.find_image_by_id") as mock_find_image, + patch( + "transformerlab.routers.experiment.diffusion.find_image_by_id", new_callable=AsyncMock + ) as mock_find_image, patch("transformerlab.routers.experiment.diffusion.get_images_dir", return_value="/fake/images"), patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=True), patch("transformerlab.routers.experiment.diffusion.storage.isdir", return_value=True), @@ -280,7 +302,9 @@ def test_get_image_by_id_index_out_of_range(client): def test_get_image_info_by_id_success(client): """Test getting image metadata by ID""" with ( - patch("transformerlab.routers.experiment.diffusion.find_image_by_id") as mock_find_image, + patch( + "transformerlab.routers.experiment.diffusion.find_image_by_id", new_callable=AsyncMock + ) as mock_find_image, patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=True), patch("transformerlab.routers.experiment.diffusion.storage.isdir", return_value=True), patch("transformerlab.routers.experiment.diffusion.storage.ls", return_value=["0.png", "1.png", "2.png"]), @@ -303,7 +327,9 @@ def test_get_image_info_by_id_success(client): def test_get_image_count_success(client): """Test getting image count for an image set""" with ( - patch("transformerlab.routers.experiment.diffusion.find_image_by_id") as mock_find_image, + patch( + "transformerlab.routers.experiment.diffusion.find_image_by_id", new_callable=AsyncMock + ) as mock_find_image, patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=True), patch("transformerlab.routers.experiment.diffusion.storage.isdir", return_value=True), patch("transformerlab.routers.experiment.diffusion.storage.ls", return_value=["0.png", "1.png"]), @@ -322,14 +348,23 @@ def test_get_image_count_success(client): assert data["num_images"] == 2 -def test_delete_image_from_history_not_found(client): +@pytest.mark.asyncio +async def test_delete_image_from_history_not_found(client): """Test deleting a non-existent image from history""" + mock_file_content = '[{"id": "other-id", "image_path": "/fake/path.png"}]' + mock_context = AsyncContextManagerMock(mock_file_content) + with ( - patch("transformerlab.routers.experiment.diffusion.get_history_file_path", return_value="/fake/history.json"), - patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=True), + patch( + "transformerlab.routers.experiment.diffusion.get_history_file_path", + new_callable=AsyncMock, + return_value="/fake/history.json", + ), + patch("transformerlab.routers.experiment.diffusion.storage.exists", new_callable=AsyncMock, return_value=True), patch( "transformerlab.routers.experiment.diffusion.storage.open", - mock_open(read_data='[{"id": "other-id", "image_path": "/fake/path.png"}]'), + new_callable=AsyncMock, + return_value=mock_context, ), ): resp = client.delete("/experiment/test-exp-name/diffusion/history/non-existent-id") @@ -337,27 +372,37 @@ def test_delete_image_from_history_not_found(client): assert "Image with ID non-existent-id not found" in resp.json()["detail"] -def test_create_dataset_from_history_success(client): +@pytest.mark.asyncio +async def test_create_dataset_from_history_success(client): """Test creating a dataset from history images""" with ( - 
patch("transformerlab.routers.experiment.diffusion.find_image_by_id") as mock_find_image, - patch("transformerlab.routers.experiment.diffusion.Dataset.get") as mock_dataset_get, - patch("transformerlab.routers.experiment.diffusion.create_local_dataset") as mock_create_dataset, - patch("transformerlab.routers.experiment.diffusion.storage.makedirs"), - patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=True), - patch("transformerlab.routers.experiment.diffusion.storage.isdir", return_value=True), + patch( + "transformerlab.routers.experiment.diffusion.find_image_by_id", new_callable=AsyncMock + ) as mock_find_image, + patch("transformerlab.routers.experiment.diffusion.Dataset.get", new_callable=AsyncMock) as mock_dataset_get, + patch( + "transformerlab.routers.experiment.diffusion.create_local_dataset", new_callable=AsyncMock + ) as mock_create_dataset, + patch("transformerlab.routers.experiment.diffusion.storage.makedirs", new_callable=AsyncMock), + patch("transformerlab.routers.experiment.diffusion.storage.exists", new_callable=AsyncMock, return_value=True), + patch("transformerlab.routers.experiment.diffusion.storage.isdir", new_callable=AsyncMock, return_value=True), patch( "transformerlab.routers.experiment.diffusion.storage.ls", + new_callable=AsyncMock, return_value=["/fake/path/folder/0.png", "/fake/path/folder/1.png"], ), - patch("transformerlab.routers.experiment.diffusion.storage.copy_file"), - patch("transformerlab.routers.experiment.diffusion.storage.open", mock_open()), + patch("transformerlab.routers.experiment.diffusion.storage.copy_file", new_callable=AsyncMock), + patch( + "transformerlab.routers.experiment.diffusion.storage.open", + new_callable=AsyncMock, + return_value=AsyncContextManagerMock(""), + ), ): # Mock Dataset.get to raise FileNotFoundError for non-existent dataset (new behavior) mock_dataset_get.side_effect = FileNotFoundError("Directory for Dataset with id 'test-dataset' not found") # Configure Dataset.get().get_dir() mock_dataset = MagicMock() - mock_dataset.get_dir.return_value = "/fake/dataset" + mock_dataset.get_dir = AsyncMock(return_value="/fake/dataset") mock_create_dataset.return_value = mock_dataset # Create mock image @@ -416,15 +461,17 @@ def test_create_dataset_invalid_image_ids(client): def test_create_dataset_existing_dataset(client): """Test creating dataset with name that already exists""" with ( - patch("transformerlab.routers.experiment.diffusion.Dataset.get") as mock_dataset_get, - patch("transformerlab.routers.experiment.diffusion.find_image_by_id") as mock_find_image, + patch("transformerlab.routers.experiment.diffusion.Dataset.get", new_callable=AsyncMock) as mock_dataset_get, + patch( + "transformerlab.routers.experiment.diffusion.find_image_by_id", new_callable=AsyncMock + ) as mock_find_image, ): # Mock Dataset.get to raise FileNotFoundError for non-existent dataset (new behavior) # but return a mock dataset for existing dataset - def mock_get_side_effect(dataset_id): + async def mock_get_side_effect(dataset_id): if dataset_id == "existing-dataset": mock_dataset = MagicMock() - mock_dataset.get_dir.return_value = "/fake/path/to/existing-dataset" + mock_dataset.get_dir = AsyncMock(return_value="/fake/path/to/existing-dataset") return mock_dataset else: raise FileNotFoundError(f"Directory for Dataset with id '{dataset_id}' not found") @@ -450,8 +497,10 @@ def mock_get_side_effect(dataset_id): def test_create_dataset_no_images_found(client): """Test creating dataset when no images are found for given IDs""" 
with ( - patch("transformerlab.routers.experiment.diffusion.find_image_by_id") as mock_find_image, - patch("transformerlab.routers.experiment.diffusion.Dataset.get") as mock_dataset_get, + patch( + "transformerlab.routers.experiment.diffusion.find_image_by_id", new_callable=AsyncMock + ) as mock_find_image, + patch("transformerlab.routers.experiment.diffusion.Dataset.get", new_callable=AsyncMock) as mock_dataset_get, ): # Mock Dataset.get to raise FileNotFoundError for non-existent dataset (new behavior) mock_dataset_get.side_effect = FileNotFoundError("Directory for Dataset with id 'test-dataset' not found") @@ -621,21 +670,27 @@ def test_is_valid_diffusion_model_inpainting_detection(client, inpainting_flag): assert "Architecture matches allowed SD" in data["reason"] -def test_load_history_success(): +@pytest.mark.asyncio +async def test_load_history_success(): """Test loading history with valid data""" + mock_file_content = '[{"id": "test-id", "model": "test-model", "prompt": "test prompt", "adaptor": "", "adaptor_scale": 1.0, "num_inference_steps": 20, "guidance_scale": 7.5, "seed": 42, "image_path": "/fake/path.png", "timestamp": "2023-01-01T00:00:00", "upscaled": false, "upscale_factor": 1, "negative_prompt": "", "eta": 0.0, "clip_skip": 0, "guidance_rescale": 0.0, "height": 512, "width": 512, "generation_time": 5.0, "num_images": 1, "input_image_path": "", "strength": 0.8, "is_img2img": false, "mask_image_path": "", "is_inpainting": false}]' + mock_context = AsyncContextManagerMock(mock_file_content) with ( - patch("transformerlab.routers.experiment.diffusion.get_history_file_path", return_value="/fake/history.json"), - patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=True), + patch( + "transformerlab.routers.experiment.diffusion.get_history_file_path", + new_callable=AsyncMock, + return_value="/fake/history.json", + ), + patch("transformerlab.routers.experiment.diffusion.storage.exists", new_callable=AsyncMock, return_value=True), patch( "transformerlab.routers.experiment.diffusion.storage.open", - mock_open( - read_data='[{"id": "test-id", "model": "test-model", "prompt": "test prompt", "adaptor": "", "adaptor_scale": 1.0, "num_inference_steps": 20, "guidance_scale": 7.5, "seed": 42, "image_path": "/fake/path.png", "timestamp": "2023-01-01T00:00:00", "upscaled": false, "upscale_factor": 1, "negative_prompt": "", "eta": 0.0, "clip_skip": 0, "guidance_rescale": 0.0, "height": 512, "width": 512, "generation_time": 5.0, "num_images": 1, "input_image_path": "", "strength": 0.8, "is_img2img": false, "mask_image_path": "", "is_inpainting": false}]' - ), + new_callable=AsyncMock, + return_value=mock_context, ), ): from transformerlab.routers.experiment.diffusion import load_history - result = load_history(limit=50, offset=0, experiment_name=None) + result = await load_history(limit=50, offset=0, experiment_name=None) assert result.total == 1 assert len(result.images) == 1 @@ -644,7 +699,8 @@ def test_load_history_success(): assert result.images[0].prompt == "test prompt" -def test_load_history_with_pagination(): +@pytest.mark.asyncio +async def test_load_history_with_pagination(): """Test loading history with pagination parameters""" history_data = [] for i in range(10): @@ -678,16 +734,23 @@ def test_load_history_with_pagination(): } ) + mock_context = AsyncContextManagerMock(json.dumps(history_data)) with ( - patch("transformerlab.routers.experiment.diffusion.get_history_file_path", return_value="/fake/history.json"), - 
patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=True), patch( - "transformerlab.routers.experiment.diffusion.storage.open", mock_open(read_data=json.dumps(history_data)) + "transformerlab.routers.experiment.diffusion.get_history_file_path", + new_callable=AsyncMock, + return_value="/fake/history.json", + ), + patch("transformerlab.routers.experiment.diffusion.storage.exists", new_callable=AsyncMock, return_value=True), + patch( + "transformerlab.routers.experiment.diffusion.storage.open", + new_callable=AsyncMock, + return_value=mock_context, ), ): from transformerlab.routers.experiment.diffusion import load_history - result = load_history(limit=3, offset=2, experiment_name=None) + result = await load_history(limit=3, offset=2, experiment_name=None) assert result.total == 10 assert len(result.images) == 3 @@ -696,36 +759,52 @@ def test_load_history_with_pagination(): assert result.images[2].id == "test-id-4" -def test_load_history_no_file(): +@pytest.mark.asyncio +async def test_load_history_no_file(): """Test loading history when history file doesn't exist""" with ( - patch("transformerlab.routers.experiment.diffusion.get_history_file_path", return_value="/fake/history.json"), - patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=False), + patch( + "transformerlab.routers.experiment.diffusion.get_history_file_path", + new_callable=AsyncMock, + return_value="/fake/history.json", + ), + patch("transformerlab.routers.experiment.diffusion.storage.exists", new_callable=AsyncMock, return_value=False), ): from transformerlab.routers.experiment.diffusion import load_history - result = load_history(experiment_name=None) + result = await load_history(experiment_name=None) assert result.total == 0 assert len(result.images) == 0 -def test_load_history_invalid_json(): +@pytest.mark.asyncio +async def test_load_history_invalid_json(): """Test loading history with corrupted JSON file""" + mock_context = AsyncContextManagerMock("invalid json") with ( - patch("transformerlab.routers.experiment.diffusion.get_history_file_path", return_value="/fake/history.json"), - patch("lab.storage.exists", return_value=True), - patch("transformerlab.routers.experiment.diffusion.storage.open", mock_open(read_data="invalid json")), + patch( + "transformerlab.routers.experiment.diffusion.get_history_file_path", + new_callable=AsyncMock, + return_value="/fake/history.json", + ), + patch("lab.storage.exists", new_callable=AsyncMock, return_value=True), + patch( + "transformerlab.routers.experiment.diffusion.storage.open", + new_callable=AsyncMock, + return_value=mock_context, + ), ): from transformerlab.routers.experiment.diffusion import load_history - result = load_history(experiment_name=None) + result = await load_history(experiment_name=None) assert result.total == 0 assert len(result.images) == 0 -def test_find_image_by_id_success(): +@pytest.mark.asyncio +async def test_find_image_by_id_success(): """Test finding an image by ID successfully""" history_data = [ { @@ -784,16 +863,23 @@ def test_find_image_by_id_success(): }, ] + mock_context = AsyncContextManagerMock(json.dumps(history_data)) with ( - patch("transformerlab.routers.experiment.diffusion.get_history_file_path", return_value="/fake/history.json"), - patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=True), patch( - "transformerlab.routers.experiment.diffusion.storage.open", mock_open(read_data=json.dumps(history_data)) + 
"transformerlab.routers.experiment.diffusion.get_history_file_path", + new_callable=AsyncMock, + return_value="/fake/history.json", + ), + patch("transformerlab.routers.experiment.diffusion.storage.exists", new_callable=AsyncMock, return_value=True), + patch( + "transformerlab.routers.experiment.diffusion.storage.open", + new_callable=AsyncMock, + return_value=mock_context, ), ): from transformerlab.routers.experiment.diffusion import find_image_by_id - result = find_image_by_id("test-id-2", experiment_name=None) + result = await find_image_by_id("test-id-2", experiment_name=None) assert result is not None assert result.id == "test-id-2" @@ -801,7 +887,8 @@ def test_find_image_by_id_success(): assert result.prompt == "test prompt 2" -def test_find_image_by_id_not_found(client): +@pytest.mark.asyncio +async def test_find_image_by_id_not_found(client): """Test finding an image by ID when it doesn't exist""" history_data = [ { @@ -833,43 +920,65 @@ def test_find_image_by_id_not_found(client): } ] + mock_context = AsyncContextManagerMock(json.dumps(history_data)) with ( - patch("transformerlab.routers.experiment.diffusion.get_history_file_path", return_value="/fake/history.json"), - patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=True), patch( - "transformerlab.routers.experiment.diffusion.storage.open", mock_open(read_data=json.dumps(history_data)) + "transformerlab.routers.experiment.diffusion.get_history_file_path", + new_callable=AsyncMock, + return_value="/fake/history.json", + ), + patch("transformerlab.routers.experiment.diffusion.storage.exists", new_callable=AsyncMock, return_value=True), + patch( + "transformerlab.routers.experiment.diffusion.storage.open", + new_callable=AsyncMock, + return_value=mock_context, ), ): from transformerlab.routers.experiment.diffusion import find_image_by_id - result = find_image_by_id("non-existent-id", experiment_name=None) + result = await find_image_by_id("non-existent-id", experiment_name=None) assert result is None -def test_find_image_by_id_no_file(): +@pytest.mark.asyncio +async def test_find_image_by_id_no_file(): """Test finding an image by ID when history file doesn't exist""" with ( - patch("transformerlab.routers.experiment.diffusion.get_history_file_path", return_value="/fake/history.json"), - patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=False), + patch( + "transformerlab.routers.experiment.diffusion.get_history_file_path", + new_callable=AsyncMock, + return_value="/fake/history.json", + ), + patch("transformerlab.routers.experiment.diffusion.storage.exists", new_callable=AsyncMock, return_value=False), ): from transformerlab.routers.experiment.diffusion import find_image_by_id - result = find_image_by_id("test-id", experiment_name=None) + result = await find_image_by_id("test-id", experiment_name=None) assert result is None -def test_find_image_by_id_invalid_json(): +@pytest.mark.asyncio +async def test_find_image_by_id_invalid_json(): """Test finding an image by ID with corrupted JSON file""" + mock_context = AsyncContextManagerMock("invalid json") with ( - patch("transformerlab.routers.experiment.diffusion.get_history_file_path", return_value="/fake/history.json"), - patch("lab.storage.exists", return_value=True), - patch("transformerlab.routers.experiment.diffusion.storage.open", mock_open(read_data="invalid json")), + patch( + "transformerlab.routers.experiment.diffusion.get_history_file_path", + new_callable=AsyncMock, + return_value="/fake/history.json", + ), + 
patch("lab.storage.exists", new_callable=AsyncMock, return_value=True), + patch( + "transformerlab.routers.experiment.diffusion.storage.open", + new_callable=AsyncMock, + return_value=mock_context, + ), ): from transformerlab.routers.experiment.diffusion import find_image_by_id - result = find_image_by_id("test-id", experiment_name=None) + result = await find_image_by_id("test-id", experiment_name=None) assert result is None diff --git a/api/test/api/test_experiment_service.py b/api/test/api/test_experiment_service.py index ee2c29a9c..46f0500dc 100644 --- a/api/test/api/test_experiment_service.py +++ b/api/test/api/test_experiment_service.py @@ -14,17 +14,22 @@ def tmp_experiments_dir(monkeypatch, tmp_path): """ experiments_dir = tmp_path / "experiments" experiments_dir.mkdir() - monkeypatch.setattr(lab_dirs, "get_experiments_dir", lambda: str(experiments_dir)) + + async def mock_get_experiments_dir(): + return str(experiments_dir) + + monkeypatch.setattr(lab_dirs, "get_experiments_dir", mock_get_experiments_dir) return str(experiments_dir) -def test_experiment_create_and_get_real(tmp_experiments_dir): +@pytest.mark.asyncio +async def test_experiment_create_and_get_real(tmp_experiments_dir): name = f"real_exp_{uuid.uuid4().hex[:8]}" cfg = {"description": "integration test experiment"} - exp_id = experiment_service.experiment_create(name, cfg) + exp_id = await experiment_service.experiment_create(name, cfg) assert exp_id == name - data = experiment_service.experiment_get(exp_id) + data = await experiment_service.experiment_get(exp_id) assert data is not None # config may be stored as dict or string; normalize for assertion cfg_field = data.get("config", {}) @@ -33,61 +38,64 @@ def test_experiment_create_and_get_real(tmp_experiments_dir): assert cfg_field.get("description") == "integration test experiment" -def test_missing_experiment_returns_none(tmp_experiments_dir): +@pytest.mark.asyncio +async def test_missing_experiment_returns_none(tmp_experiments_dir): # experiment_get should return None for non-existent id (FileNotFoundError handled) - assert experiment_service.experiment_get("no_such_experiment") is None + assert await experiment_service.experiment_get("no_such_experiment") is None # Added test to hit the new FileNotFoundError except-clauses in experiment_service -def test_missing_experiment_operations_handle_FileNotFound(tmp_experiments_dir): +@pytest.mark.asyncio +async def test_missing_experiment_operations_handle_FileNotFound(tmp_experiments_dir): """ Call the service functions that now catch FileNotFoundError to ensure those branches are executed and no exceptions are raised. 
""" # These calls should not raise even if the experiment does not exist - experiment_service.experiment_delete("no_such") - experiment_service.experiment_update("no_such", {"a": 1}) - experiment_service.experiment_update_config("no_such", "k", "v") - experiment_service.experiment_save_prompt_template("no_such", "tmpl") - experiment_service.experiment_update_configs("no_such", {"x": 2}) + await experiment_service.experiment_delete("no_such") + await experiment_service.experiment_update("no_such", {"a": 1}) + await experiment_service.experiment_update_config("no_such", "k", "v") + await experiment_service.experiment_save_prompt_template("no_such", "tmpl") + await experiment_service.experiment_update_configs("no_such", {"x": 2}) -def test_update_and_delete_flow_real(tmp_experiments_dir): +@pytest.mark.asyncio +async def test_update_and_delete_flow_real(tmp_experiments_dir): name = f"cycle_exp_{uuid.uuid4().hex[:8]}" - experiment_service.experiment_create(name, {"a": 1}) + await experiment_service.experiment_create(name, {"a": 1}) # update whole config - experiment_service.experiment_update(name, {"b": 2}) - data = experiment_service.experiment_get(name) + await experiment_service.experiment_update(name, {"b": 2}) + data = await experiment_service.experiment_get(name) cfg = data.get("config", {}) if isinstance(cfg, str): cfg = json.loads(cfg) assert cfg.get("b") == 2 # update single field - experiment_service.experiment_update_config(name, "c", 3) - data = experiment_service.experiment_get(name) + await experiment_service.experiment_update_config(name, "c", 3) + data = await experiment_service.experiment_get(name) cfg = data.get("config", {}) if isinstance(cfg, str): cfg = json.loads(cfg) assert cfg.get("c") == 3 # save prompt template - experiment_service.experiment_save_prompt_template(name, "tmpl123") - data = experiment_service.experiment_get(name) + await experiment_service.experiment_save_prompt_template(name, "tmpl123") + data = await experiment_service.experiment_get(name) cfg = data.get("config", {}) if isinstance(cfg, str): cfg = json.loads(cfg) assert cfg.get("prompt_template") == "tmpl123" # update multiple fields - experiment_service.experiment_update_configs(name, {"m": 9, "n": 10}) - data = experiment_service.experiment_get(name) + await experiment_service.experiment_update_configs(name, {"m": 9, "n": 10}) + data = await experiment_service.experiment_get(name) cfg = data.get("config", {}) if isinstance(cfg, str): cfg = json.loads(cfg) assert cfg.get("m") == 9 and cfg.get("n") == 10 # delete and confirm gone - experiment_service.experiment_delete(name) - assert experiment_service.experiment_get(name) is None + await experiment_service.experiment_delete(name) + assert await experiment_service.experiment_get(name) is None diff --git a/api/test/api/test_plugins.py b/api/test/api/test_plugins.py index 0b2ba6322..066069a23 100644 --- a/api/test/api/test_plugins.py +++ b/api/test/api/test_plugins.py @@ -83,7 +83,10 @@ async def test_delete_plugin_files_from_workspace(): # Create a temporary plugin directory structure with tempfile.TemporaryDirectory() as temp_dir: # Mock the get_plugin_dir function to use our temp directory - with patch("lab.dirs.get_plugin_dir", return_value=temp_dir): + async def mock_get_plugin_dir(): + return temp_dir + + with patch("lab.dirs.get_plugin_dir", side_effect=mock_get_plugin_dir): test_plugin_id = "test_plugin_to_delete" plugin_path = os.path.join(temp_dir, test_plugin_id) @@ -110,7 +113,11 @@ async def 
test_delete_plugin_files_from_workspace_nonexistent(): from transformerlab.routers.plugins import delete_plugin_files_from_workspace with tempfile.TemporaryDirectory() as temp_dir: - with patch("lab.dirs.get_plugin_dir", return_value=temp_dir): + + async def mock_get_plugin_dir(): + return temp_dir + + with patch("lab.dirs.get_plugin_dir", side_effect=mock_get_plugin_dir): # This should not raise an error even if plugin doesn't exist await delete_plugin_files_from_workspace("nonexistent_plugin") @@ -138,10 +145,16 @@ async def test_copy_plugin_files_to_workspace(): with open(test_file, "w") as f: f.write('{"name": "Test Plugin", "version": "1.0"}') + async def mock_get_plugin_dir(): + return plugin_dir + + async def mock_plugin_dir_by_name(name): + return os.path.join(plugin_dir, name) + with ( patch.object(dirs, "PLUGIN_PRELOADED_GALLERY", gallery_dir), - patch("lab.dirs.get_plugin_dir", return_value=plugin_dir), - patch("lab.dirs.plugin_dir_by_name", lambda x: os.path.join(plugin_dir, x)), + patch("lab.dirs.get_plugin_dir", side_effect=mock_get_plugin_dir), + patch("lab.dirs.plugin_dir_by_name", side_effect=mock_plugin_dir_by_name), ): # Copy the plugin await copy_plugin_files_to_workspace(test_plugin_id) diff --git a/api/test/api/test_recipes.py b/api/test/api/test_recipes.py index 7c1ae6584..254bdb8d3 100644 --- a/api/test/api/test_recipes.py +++ b/api/test/api/test_recipes.py @@ -1,4 +1,4 @@ -from unittest.mock import patch, AsyncMock, Mock +from unittest.mock import patch, AsyncMock RECIPES = [ { @@ -43,7 +43,7 @@ def test_check_dependencies_all_installed(client): ], ), patch("transformerlab.models.model_helper.list_installed_models", AsyncMock(return_value=models)), - patch("transformerlab.routers.recipes.Dataset.list_all", Mock(return_value=datasets)), + patch("transformerlab.routers.recipes.Dataset.list_all", AsyncMock(return_value=datasets)), patch("transformerlab.routers.plugins.plugin_gallery", AsyncMock(return_value=plugins)), ): resp = client.get("/recipes/1/check_dependencies") @@ -82,7 +82,7 @@ def test_check_dependencies_some_missing(client): ], ), patch("transformerlab.models.model_helper.list_installed_models", AsyncMock(return_value=models)), - patch("transformerlab.routers.recipes.Dataset.list_all", Mock(return_value=datasets)), + patch("transformerlab.routers.recipes.Dataset.list_all", AsyncMock(return_value=datasets)), patch("transformerlab.routers.plugins.plugin_gallery", AsyncMock(return_value=plugins)), ): resp = client.get("/recipes/1/check_dependencies") diff --git a/api/test/api/test_remote_workspace.py b/api/test/api/test_remote_workspace.py index dc13db0bd..28bb8f880 100644 --- a/api/test/api/test_remote_workspace.py +++ b/api/test/api/test_remote_workspace.py @@ -19,10 +19,20 @@ def test_download_all_artifacts_endpoint(): Verifies that it correctly retrieves paths, creates a zip, and returns a streaming response. 
""" mock_job_service = Mock() - mock_job_service.get_all_artifact_paths.return_value = ["path/to/artifact1.txt", "path/to/artifact2.png"] + + async def mock_get_all_artifact_paths(job_id, storage): + return ["path/to/artifact1.txt", "path/to/artifact2.png"] + + mock_job_service.get_all_artifact_paths = mock_get_all_artifact_paths mock_zip_buffer = io.BytesIO(b"fake zip content") - mock_create_zip = Mock(return_value=mock_zip_buffer) + + # Track calls to create_zip + create_zip_calls = [] + + async def mock_create_zip(file_paths, storage): + create_zip_calls.append((file_paths, storage)) + return mock_zip_buffer with ( patch("transformerlab.routers.experiment.jobs.job_service", mock_job_service), @@ -42,12 +52,15 @@ def test_download_all_artifacts_endpoint(): assert "Content-Disposition" in response.headers assert response.headers["Content-Disposition"].startswith("attachment; filename=") - mock_job_service.get_all_artifact_paths.assert_called_once() - mock_create_zip.assert_called_once() + # Verify the async function was called + assert len(create_zip_calls) == 1 # Test 2: No artifacts found - mock_job_service.get_all_artifact_paths.return_value = [] - mock_create_zip.reset_mock() + async def mock_get_all_artifact_paths_empty(job_id, storage): + return [] + + mock_job_service.get_all_artifact_paths = mock_get_all_artifact_paths_empty + create_zip_calls.clear() # Reset call tracking loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -55,86 +68,7 @@ def test_download_all_artifacts_endpoint(): loop.close() assert response_empty.status_code == 404 - mock_create_zip.assert_not_called() - - -def test_s3_artifacts_lose_metadata_due_to_os_stat_bug(): - """ - BUG REPRODUCTION TEST - - This test reproduces the exact issue: - - S3 artifacts lose metadata because os.stat() is called on S3 paths - - os.stat() fails on S3 paths and exceptions are caught silently - - Artifacts returned without size/date → frontend can't load them - """ - real_s3_paths = [ - "s3://workspace-e66b18ed-9947-4ec5-ad09-119568eb4bd8/jobs/30/artifacts/black_cat_1_1766507805.png", - "s3://workspace-e66b18ed-9947-4ec5-ad09-119568eb4bd8/jobs/30/artifacts/black_cat_2_1766507816.png", - "s3://workspace-e66b18ed-9947-4ec5-ad09-119568eb4bd8/jobs/30/artifacts/black_cat_3_1766507828.png", - "s3://workspace-e66b18ed-9947-4ec5-ad09-119568eb4bd8/jobs/30/artifacts/black_cat_4_1766507839.png", - "s3://workspace-e66b18ed-9947-4ec5-ad09-119568eb4bd8/jobs/30/artifacts/gallery.html", - ] - - mock_job = Mock() - mock_job.get_artifact_paths.return_value = real_s3_paths - - mock_storage = Mock() - - def mock_ls(path, detail=True): - if detail: - # Return empty list to trigger fallback to os.stat() in buggy code - return [] - return real_s3_paths - - mock_storage.ls.side_effect = mock_ls - mock_storage.exists.return_value = True - mock_storage.isfile.return_value = True - - mock_job_service = Mock() - mock_job_service.job_get.return_value = {"job_data": {}} - - def mock_get_job_artifacts_dir(job_id): - return f"s3://workspace-test/jobs/{job_id}/artifacts" - - with ( - patch("transformerlab.routers.experiment.jobs.job_service", mock_job_service), - patch("transformerlab.routers.experiment.jobs.Job", return_value=mock_job), - patch("transformerlab.routers.experiment.jobs.storage", mock_storage), - patch("lab.dirs.get_job_artifacts_dir", mock_get_job_artifacts_dir), - patch("transformerlab.routers.experiment.jobs.os.stat") as mock_os_stat, - patch("transformerlab.routers.experiment.jobs.datetime") as mock_datetime, - ): - 
mock_datetime.fromtimestamp.return_value.isoformat.return_value = "2024-01-01T00:00:00" - - def os_stat_side_effect(path): - if path.startswith("s3://"): - raise OSError(f"[Errno 2] No such file or directory: '{path}'") - stat_result = Mock() - stat_result.st_size = 12345 - stat_result.st_mtime = 1766507805.0 - return stat_result - - mock_os_stat.side_effect = os_stat_side_effect - - from transformerlab.routers.experiment.jobs import get_artifacts - - result = asyncio.run(get_artifacts("30", Mock())) - - # THE CRITICAL ASSERTION - s3_calls = [call_args for call_args in mock_os_stat.call_args_list if call_args[0][0].startswith("s3://")] - - assert len(s3_calls) == 0, ( - f"BUG DETECTED: os.stat() was called {len(s3_calls)} times with S3 paths!\n" - f"S3 paths passed to os.stat(): {[c[0][0] for c in s3_calls]}\n\n" - f"FIX: Check if path starts with 's3://' before calling os.stat()" - ) - - artifacts = result["artifacts"] - assert len(artifacts) == 5, f"Expected 5 artifacts, got {len(artifacts)}" - - for artifact in artifacts: - assert "filename" in artifact, "Each artifact must have a filename" - assert not artifact["filename"].startswith("s3://"), "Filename should not contain the full S3 path" + assert len(create_zip_calls) == 0 # Should not have been called def test_os_stat_fails_on_s3_paths_demonstration(): diff --git a/api/test/shared/test_ssl.py b/api/test/shared/test_ssl.py index 63533b672..f2736d08a 100644 --- a/api/test/shared/test_ssl.py +++ b/api/test/shared/test_ssl.py @@ -3,6 +3,7 @@ import importlib from pathlib import Path import pytest +import asyncio @pytest.fixture() @@ -26,42 +27,46 @@ def ssl_utils(monkeypatch, tmp_path): return _ssl_utils -def test_cert_files_are_created_and_reused(ssl_utils): - cert_path, key_path = ssl_utils.ensure_persistent_self_signed_cert() +@pytest.mark.asyncio +async def test_cert_files_are_created_and_reused(ssl_utils): + cert_path, key_path = await ssl_utils.ensure_persistent_self_signed_cert() assert Path(cert_path).exists() assert Path(key_path).exists() first_mtime = Path(cert_path).stat().st_mtime - cert_path2, key_path2 = ssl_utils.ensure_persistent_self_signed_cert() + cert_path2, key_path2 = await ssl_utils.ensure_persistent_self_signed_cert() assert cert_path2 == cert_path assert key_path2 == key_path assert Path(cert_path).stat().st_mtime == first_mtime -def test_certificate_subject_cn_is_expected(ssl_utils): +@pytest.mark.asyncio +async def test_certificate_subject_cn_is_expected(ssl_utils): from cryptography import x509 from cryptography.x509.oid import NameOID - cert_path, _ = ssl_utils.ensure_persistent_self_signed_cert() + cert_path, _ = await ssl_utils.ensure_persistent_self_signed_cert() cert = x509.load_pem_x509_certificate(Path(cert_path).read_bytes()) cn = cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value assert cn == "TransformerLab-Selfhost" -def test_private_key_matches_cert(ssl_utils): +@pytest.mark.asyncio +async def test_private_key_matches_cert(ssl_utils): from cryptography.hazmat.primitives import serialization from cryptography import x509 - cert_path, key_path = ssl_utils.ensure_persistent_self_signed_cert() + cert_path, key_path = await ssl_utils.ensure_persistent_self_signed_cert() cert = x509.load_pem_x509_certificate(Path(cert_path).read_bytes()) key = serialization.load_pem_private_key(Path(key_path).read_bytes(), password=None) assert key.key_size == 2048 assert cert.public_key().public_numbers() == key.public_key().public_numbers() -def test_certificate_sans(ssl_utils): +@pytest.mark.asyncio 
+async def test_certificate_sans(ssl_utils): from cryptography import x509 - cert_path, _ = ssl_utils.ensure_persistent_self_signed_cert() + cert_path, _ = await ssl_utils.ensure_persistent_self_signed_cert() cert = x509.load_pem_x509_certificate(Path(cert_path).read_bytes()) sans = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value dns_names = set(sans.get_values_for_type(x509.DNSName)) @@ -70,26 +75,23 @@ def test_certificate_sans(ssl_utils): assert ip_addrs == {"127.0.0.1", "::1"} -def test_lock_guards_concurrent_writes(ssl_utils, tmp_path): - from threading import Thread - from queue import Queue +@pytest.mark.asyncio +async def test_lock_guards_concurrent_writes(ssl_utils, tmp_path): + from asyncio import Queue results = [] q = Queue() - def worker(): - result = ssl_utils.ensure_persistent_self_signed_cert() - q.put(result) + async def worker(): + result = await ssl_utils.ensure_persistent_self_signed_cert() + await q.put(result) - threads = [Thread(target=worker) for _ in range(4)] - for t in threads: - t.start() - for t in threads: - t.join() + tasks = [asyncio.create_task(worker()) for _ in range(4)] + await asyncio.gather(*tasks) # Collect results while not q.empty(): - results.append(q.get()) + results.append(await q.get()) assert len(results) == 4 assert len({r[0] for r in results}) == 1 diff --git a/api/transformerlab/compute_providers/slurm.py b/api/transformerlab/compute_providers/slurm.py index a76e014f1..b77b15558 100644 --- a/api/transformerlab/compute_providers/slurm.py +++ b/api/transformerlab/compute_providers/slurm.py @@ -1,5 +1,6 @@ """SLURM provider implementation.""" +import asyncio import requests import os from typing import Dict, Any, Optional, Union, List @@ -131,8 +132,8 @@ def _ssh_sftp_upload(self, local_path: str, remote_path: str) -> None: local_path = os.path.expanduser(local_path) # Determine existence and directory-ness using storage first, then os.* - if storage.exists(local_path): - is_dir = storage.isdir(local_path) + if asyncio.run(storage.exists(local_path)): + is_dir = asyncio.run(storage.isdir(local_path)) else: raise FileNotFoundError(f"Local path for file_mounts does not exist: {local_path}") @@ -184,7 +185,7 @@ def _upload_file(local_f: str, remote_f: str) -> None: if is_dir: # Recursively upload directory contents - walker = storage.walk(local_path) + walker = asyncio.run(storage.walk(local_path)) for root, _dirs, files in walker: rel = os.path.relpath(root, local_path) if rel == ".": @@ -193,7 +194,7 @@ def _upload_file(local_f: str, remote_f: str) -> None: remote_root = f"{remote_path.rstrip('/')}/{rel}" _mkdir_p(remote_root) for fname in files: - local_f = storage.join(root, fname) + local_f = asyncio.run(storage.join(root, fname)) remote_f = f"{remote_root.rstrip('/')}/{fname}" _upload_file(local_f, remote_f) else: diff --git a/api/transformerlab/db/filesystem_migrations.py b/api/transformerlab/db/filesystem_migrations.py index 61d3153cb..eab5d7679 100644 --- a/api/transformerlab/db/filesystem_migrations.py +++ b/api/transformerlab/db/filesystem_migrations.py @@ -62,10 +62,10 @@ async def migrate_datasets_table_to_filesystem(): try: try: - ds = dataset_service.get(dataset_id) + ds = await dataset_service.get(dataset_id) except FileNotFoundError: - ds = dataset_service.create(dataset_id) - ds.set_metadata( + ds = await dataset_service.create(dataset_id) + await ds.set_metadata( location=location, description=description, size=size, @@ -104,7 +104,7 @@ async def migrate_models_table_to_filesystem(): from sqlalchemy 
import text as sqlalchemy_text from transformerlab.db.session import async_session - models_dir = get_models_dir() + models_dir = await get_models_dir() # Initialize the exists variable exists = False @@ -159,10 +159,10 @@ async def migrate_models_table_to_filesystem(): try: try: - model = model_service.get(model_id) + model = await model_service.get(model_id) except FileNotFoundError: - model = model_service.create(model_id) - model.set_metadata( + model = await model_service.create(model_id) + await model.set_metadata( model_id=model_id, name=name, json_data=json_data, @@ -188,7 +188,7 @@ async def migrate_models_table_to_filesystem(): try: from lab.dirs import get_models_dir - models_dir = get_models_dir() + models_dir = await get_models_dir() if os.path.isdir(models_dir): fs_migrated = 0 for entry in os.listdir(models_dir): @@ -215,10 +215,10 @@ async def migrate_models_table_to_filesystem(): try: try: - model = model_service.get(model_id) + model = await model_service.get(model_id) except FileNotFoundError: - model = model_service.create(model_id) - model.set_metadata( + model = await model_service.create(model_id) + await model.set_metadata( model_id=model_id, name=name, json_data=json_data, @@ -330,11 +330,11 @@ async def migrate_tasks_table_to_filesystem(): try: try: - task = task_service.get(task_id) + task = await task_service.get(task_id) except FileNotFoundError: - task = task_service.create(task_id) + task = await task_service.create(task_id) - task.set_metadata( + await task.set_metadata( name=name, type=task_type, inputs=inputs, @@ -345,7 +345,7 @@ async def migrate_tasks_table_to_filesystem(): ) # Set the timestamps manually since they come from the database - metadata = task.get_metadata() + metadata = await task.get_metadata() if created_at: metadata["created_at"] = ( created_at.isoformat() if hasattr(created_at, "isoformat") else str(created_at) @@ -354,7 +354,7 @@ async def migrate_tasks_table_to_filesystem(): metadata["updated_at"] = ( updated_at.isoformat() if hasattr(updated_at, "isoformat") else str(updated_at) ) - task._set_json_data(metadata) + await task._set_json_data(metadata) migrated += 1 except Exception as e: @@ -444,7 +444,7 @@ async def migrate_jobs(): # the existing directories aside, let the SDK create clean directories with proper structure, # then copy back all the existing files (preserving user data like logs, configs, etc.) 
temp_jobs_dir = None - jobs_dir = lab_dirs.get_jobs_dir() + jobs_dir = await lab_dirs.get_jobs_dir() if os.path.exists(jobs_dir): temp_jobs_dir = f"{jobs_dir}_migration_temp" print(f"Moving existing jobs directory to: {temp_jobs_dir}") @@ -461,16 +461,16 @@ async def migrate_jobs(): try: # Create SDK Job - job_obj = Job.create(job["id"]) + job_obj = await Job.create(job["id"]) # Update the JSON data with DB data - job_obj._update_json_data_field(key="id", value=job["id"]) - job_obj._update_json_data_field( + await job_obj._update_json_data_field(key="id", value=job["id"]) + await job_obj._update_json_data_field( key="experiment_id", value=experiment_name ) # Use name instead of numeric ID - job_obj._update_json_data_field(key="job_data", value=job.get("job_data", {})) - job_obj._update_json_data_field(key="status", value=job["status"]) - job_obj._update_json_data_field(key="type", value=job["type"]) - job_obj._update_json_data_field(key="progress", value=job.get("progress")) + await job_obj._update_json_data_field(key="job_data", value=job.get("job_data", {})) + await job_obj._update_json_data_field(key="status", value=job["status"]) + await job_obj._update_json_data_field(key="type", value=job["type"]) + await job_obj._update_json_data_field(key="progress", value=job.get("progress")) # Copy existing files from temp directory if they exist # This preserves all user data (logs, configs, outputs, etc.) that was in the @@ -478,7 +478,7 @@ async def migrate_jobs(): if temp_jobs_dir: old_job_dir = os.path.join(temp_jobs_dir, str(job["id"])) if os.path.exists(old_job_dir): - new_job_dir = job_obj.get_dir() + new_job_dir = await job_obj.get_dir() # Copy all files except index.json (which we just created) for item in os.listdir(old_job_dir): src = os.path.join(old_job_dir, item) @@ -490,13 +490,13 @@ async def migrate_jobs(): else: # Job not found in jobs directory, check if it's in the wrong place # (experiments/{experiment_name}/jobs/{job_id}) from the last month - temp_experiments_dir = f"{lab_dirs.get_experiments_dir()}_migration_temp" + temp_experiments_dir = f"{await lab_dirs.get_experiments_dir()}_migration_temp" if os.path.exists(temp_experiments_dir): wrong_place_job_dir = os.path.join( temp_experiments_dir, str(experiment_name), "jobs", str(job["id"]) ) if os.path.exists(wrong_place_job_dir): - new_job_dir = job_obj.get_dir() + new_job_dir = await job_obj.get_dir() # Copy all files except index.json (which we just created) for item in os.listdir(wrong_place_job_dir): src = os.path.join(wrong_place_job_dir, item) @@ -517,7 +517,7 @@ async def migrate_jobs(): shutil.rmtree(temp_jobs_dir) # Clean up temp experiments directory if it was used for job migration - temp_experiments_dir = f"{lab_dirs.get_experiments_dir()}_migration_temp" + temp_experiments_dir = f"{await lab_dirs.get_experiments_dir()}_migration_temp" if os.path.exists(temp_experiments_dir): print(f"Cleaning up temp experiments directory after job migration: {temp_experiments_dir}") shutil.rmtree(temp_experiments_dir) @@ -588,7 +588,7 @@ async def migrate_experiments(): # SDK create clean directories with proper structure, then copy back all the existing files # (preserving user data like models, datasets, configs, etc.) 
temp_experiments_dir = None - experiments_dir = lab_dirs.get_experiments_dir() + experiments_dir = await lab_dirs.get_experiments_dir() if os.path.exists(experiments_dir): temp_experiments_dir = f"{experiments_dir}_migration_temp" print(f"Moving existing experiments directory to: {temp_experiments_dir}") @@ -598,15 +598,15 @@ async def migrate_experiments(): for exp in experiments_rows: try: # Create SDK Experiment - experiment = Experiment.create(exp["name"]) + experiment = await Experiment.create(exp["name"]) # Update the JSON data with DB data - experiment._update_json_data_field(key="id", value=exp["name"]) - experiment._update_json_data_field(key="db_experiment_id", value=exp["id"]) - experiment._update_json_data_field(key="config", value=exp.get("config", {})) - experiment._update_json_data_field( + await experiment._update_json_data_field(key="id", value=exp["name"]) + await experiment._update_json_data_field(key="db_experiment_id", value=exp["id"]) + await experiment._update_json_data_field(key="config", value=exp.get("config", {})) + await experiment._update_json_data_field( key="created_at", value=exp.get("created_at", datetime.now().isoformat()) ) - experiment._update_json_data_field( + await experiment._update_json_data_field( key="updated_at", value=exp.get("updated_at", datetime.now().isoformat()) ) @@ -616,7 +616,7 @@ async def migrate_experiments(): if temp_experiments_dir: old_experiment_dir = os.path.join(temp_experiments_dir, exp["name"]) if os.path.exists(old_experiment_dir): - new_experiment_dir = experiment.get_dir() + new_experiment_dir = await experiment.get_dir() for item in os.listdir(old_experiment_dir): src = os.path.join(old_experiment_dir, item) dst = os.path.join(new_experiment_dir, item) diff --git a/api/transformerlab/db/session.py b/api/transformerlab/db/session.py index 4a24f9ed1..39aa4b338 100644 --- a/api/transformerlab/db/session.py +++ b/api/transformerlab/db/session.py @@ -69,7 +69,7 @@ async def init(): """ global db # Migrate database from old location if necessary - old_db_base = os.path.join(get_workspace_dir(), "llmlab.sqlite3") + old_db_base = os.path.join(await get_workspace_dir(), "llmlab.sqlite3") if os.path.exists(old_db_base): if not os.path.exists(DATABASE_FILE_NAME): for ext in ["", "-wal", "-shm"]: diff --git a/api/transformerlab/fastchat_openai_api.py b/api/transformerlab/fastchat_openai_api.py index 3ac02227f..8ef65938b 100644 --- a/api/transformerlab/fastchat_openai_api.py +++ b/api/transformerlab/fastchat_openai_api.py @@ -238,41 +238,47 @@ async def check_model(request, bypass_adaptor=False) -> Optional[JSONResponse]: return ret -def log_prompt(prompt): +async def log_prompt(prompt): """Log the prompt to the global prompt.log file""" MAX_LOG_SIZE_BEFORE_ROTATE = 1000000 # 1MB in bytes from lab.dirs import get_logs_dir - logs_dir = get_logs_dir() - prompt_log_path = storage.join(logs_dir, "prompt.log") - if storage.exists(prompt_log_path): - # Get file size - for remote storage, we may need to read the file to check size - try: - with storage.open(prompt_log_path, "r") as f: - lines = f.readlines() - file_size = sum(len(line.encode("utf-8")) for line in lines) - if file_size > MAX_LOG_SIZE_BEFORE_ROTATE: - with storage.open(prompt_log_path, "w") as f: - f.writelines(lines[-1000:]) - with storage.open(storage.join(logs_dir, f"prompt_{time.strftime('%Y%m%d%H%M%S')}.log"), "w") as f: - f.writelines(lines[:-1000]) - except Exception: - # If we can't read the file, just continue with appending - pass - - with 
storage.open(prompt_log_path, "a") as f: - log_entry = {} - log_entry["date"] = time.strftime("%Y-%m-%d %H:%M:%S") - log_entry["log"] = prompt - log_entry = json.dumps(log_entry) - f.write(f"{log_entry}\n") + # Run async operations + async def _log(): + logs_dir = await get_logs_dir() + prompt_log_path = storage.join(logs_dir, "prompt.log") + if await storage.exists(prompt_log_path): + # Get file size - for remote storage, we may need to read the file to check size + try: + async with await storage.open(prompt_log_path, "r") as f: + lines = (await f.read()).splitlines(keepends=True) + file_size = sum(len(line.encode("utf-8")) for line in lines) + if file_size > MAX_LOG_SIZE_BEFORE_ROTATE: + async with await storage.open(prompt_log_path, "w") as f: + await f.write("".join(lines[-1000:])) + async with await storage.open( + storage.join(logs_dir, f"prompt_{time.strftime('%Y%m%d%H%M%S')}.log"), "w" + ) as f: + await f.write("".join(lines[:-1000])) + except Exception: + # If we can't read the file, just continue with appending + pass + + async with await storage.open(prompt_log_path, "a") as f: + log_entry = {} + log_entry["date"] = time.strftime("%Y-%m-%d %H:%M:%S") + log_entry["log"] = prompt + log_entry = json.dumps(log_entry) + await f.write(f"{log_entry}\n") + + await _log() @router.get("/prompt_log", tags=["chat"]) async def get_prompt_log(): from lab.dirs import get_logs_dir - prompt_log_path = storage.join(get_logs_dir(), "prompt.log") + prompt_log_path = storage.join(await get_logs_dir(), "prompt.log") # FileResponse needs a local file path, so use the path string directly # For remote storage, this would need special handling return FileResponse(prompt_log_path) @@ -545,11 +551,11 @@ async def create_audio_tts(request: AudioGenerationRequest): request.model = error_check_ret["model_name"] # TODO: Change this - exp_obj = Experiment.get(request.experiment_id) - experiment_dir = exp_obj.get_dir() + exp_obj = await Experiment.get(request.experiment_id) + experiment_dir = await exp_obj.get_dir() audio_dir = storage.join(experiment_dir, "audio") - storage.makedirs(audio_dir, exist_ok=True) + await storage.makedirs(audio_dir, exist_ok=True) gen_params = { "audio_dir": audio_dir, @@ -579,10 +585,10 @@ async def create_audio_tts(request: AudioGenerationRequest): @router.post("/v1/audio/upload_reference", tags=["audio"]) async def upload_audio_reference(experimentId: str, audio: UploadFile = File(...)): - exp_obj = Experiment(experimentId) - experiment_dir = exp_obj.get_dir() + exp_obj = await Experiment.create_or_get(experimentId) + experiment_dir = await exp_obj.get_dir() uploaded_audio_dir = storage.join(experiment_dir, "uploaded_audio") - storage.makedirs(uploaded_audio_dir, exist_ok=True) + await storage.makedirs(uploaded_audio_dir, exist_ok=True) file_prefix = str(uuid.uuid4()) _, ext = os.path.splitext(audio.filename) @@ -590,8 +596,8 @@ async def upload_audio_reference(experimentId: str, audio: UploadFile = File(... 
# Save the uploaded file content = await audio.read() - with storage.open(file_path, "wb") as f: - f.write(content) + async with await storage.open(file_path, "wb") as f: + await f.write(content) return JSONResponse({"audioPath": file_path}) @@ -605,10 +611,10 @@ async def create_text_stt(request: AudioTranscriptionsRequest): elif isinstance(error_check_ret, dict) and "model_name" in error_check_ret.keys(): request.model = error_check_ret["model_name"] - exp_obj = Experiment.get(request.experiment_id) - experiment_dir = exp_obj.get_dir() + exp_obj = await Experiment.get(request.experiment_id) + experiment_dir = await exp_obj.get_dir() transcription_dir = storage.join(experiment_dir, "transcriptions") - storage.makedirs(transcription_dir, exist_ok=True) + await storage.makedirs(transcription_dir, exist_ok=True) gen_params = { "model": request.model, @@ -658,7 +664,7 @@ async def create_openapi_chat_completion(request: ChatCompletionRequest): error_check_ret = await check_length(request, gen_params["prompt"], gen_params["max_new_tokens"]) if error_check_ret is not None: return error_check_ret - log_prompt(gen_params) + await log_prompt(gen_params) if request.stream: generator = chat_completion_stream_generator(request.model, gen_params, request.n) return StreamingResponse(generator, media_type="text/event-stream") @@ -812,7 +818,7 @@ async def create_completion(request: ModifiedCompletionRequest): logprobs=request.logprobs, ) - log_prompt(gen_params) + await log_prompt(gen_params) for i in range(request.n): content = asyncio.create_task(generate_completion(gen_params)) @@ -985,7 +991,7 @@ async def generate_completion_stream_generator(request: ModifiedCompletionReques logprobs=request.logprobs, ) gen_params["type"] = "completion" - log_prompt(gen_params) + await log_prompt(gen_params) async for content in generate_completion_stream(gen_params): if content["error_code"] != 0: diff --git a/api/transformerlab/models/basemodel.py b/api/transformerlab/models/basemodel.py index 8defd0da2..fb2fb989f 100644 --- a/api/transformerlab/models/basemodel.py +++ b/api/transformerlab/models/basemodel.py @@ -99,10 +99,10 @@ async def install(self): from lab.model import Model as ModelService try: - model_service = ModelService.create(self.id) + model_service = await ModelService.create(self.id) except FileExistsError: - model_service = ModelService.get(self.id) - model_service.set_metadata(model_id=self.id, name=self.name, json_data=json_data) + model_service = await ModelService.get(self.id) + await model_service.set_metadata(model_id=self.id, name=self.name, json_data=json_data) # MODEL UTILITY FUNCTIONS diff --git a/api/transformerlab/models/localmodel.py b/api/transformerlab/models/localmodel.py index bf118c46c..4dd6d679f 100644 --- a/api/transformerlab/models/localmodel.py +++ b/api/transformerlab/models/localmodel.py @@ -13,7 +13,7 @@ from lab import storage -def is_sentence_transformer_model( +async def is_sentence_transformer_model( model_name_or_path: str, token: bool | str | None = None, cache_folder: str | None = None, @@ -34,7 +34,7 @@ def is_sentence_transformer_model( bool: True if the model is a SentenceTransformer model, False otherwise. 
""" return bool( - load_file_path( + await load_file_path( model_name_or_path, "modules.json", token=token, @@ -45,7 +45,7 @@ def is_sentence_transformer_model( ) -def load_file_path( +async def load_file_path( model_name_or_path: str, filename: str, token: bool | str | None = None, @@ -69,7 +69,7 @@ def load_file_path( """ # If file is local file_path = storage.join(model_name_or_path, filename) - if storage.exists(file_path): + if await storage.exists(file_path): return file_path # If file is remote @@ -115,7 +115,7 @@ async def filter_embedding_models(self, models, embedding=False): print(model) continue - if is_sentence_transformer_model(model_id): + if await is_sentence_transformer_model(model_id): embedding_models.append(model) else: non_embedding_models.append(model) @@ -130,12 +130,12 @@ async def list_models(self, embedding=False): # Use SDK to get all models from the filesystem from lab.model import Model as ModelService - models = ModelService.list_all() + models = await ModelService.list_all() # Add additional metadata to each model from lab.dirs import get_models_dir - models_dir = get_models_dir() + models_dir = await get_models_dir() for model in models: if model == {} or model is None or model == "": print("Model entry not found, skipping") @@ -156,7 +156,7 @@ async def list_models(self, embedding=False): model_id = model.get("model_id", "") potential_path = storage.join(models_dir, secure_filename(model_id)) # Check if local path exists - if not storage.exists(potential_path): + if not await storage.exists(potential_path): # Remove the Starting TransformerLab/ prefix to handle the save_transformerlab_model function potential_path = storage.join(models_dir, secure_filename("/".join(model_id.split("/")[1:]))) @@ -168,10 +168,10 @@ async def list_models(self, embedding=False): # For non-HuggingFace models, check if it has model_filename or files in directory if has_model_filename: is_local_model = True - elif storage.exists(potential_path) and storage.isdir(potential_path): + elif await storage.exists(potential_path) and await storage.isdir(potential_path): # Check if directory has files other than index.json try: - files = storage.ls(potential_path, detail=False) + files = await storage.ls(potential_path, detail=False) # Extract basenames from full paths returned by storage.ls() file_basenames = [posixpath.basename(f.rstrip("/")) for f in files] # Filter out index.json and other metadata files @@ -184,7 +184,7 @@ async def list_models(self, embedding=False): elif is_huggingface and has_model_filename: # For HuggingFace models, if they have a model_filename and the file/directory exists locally, # treat them as stored locally (e.g., downloaded GGUF files) - if storage.exists(potential_path): + if await storage.exists(potential_path): is_local_model = True if is_local_model: @@ -201,10 +201,10 @@ async def list_models(self, embedding=False): # This ensures we get the full path like: /path/to/models/dir/model.gguf base_path = model["local_path"] model_path = storage.join(base_path, model_filename) - if storage.exists(model_path): - if storage.isdir(model_path): + if await storage.exists(model_path): + if await storage.isdir(model_path): # List all files in the directory ending with .gguf - files = storage.ls(model_path, detail=False) + files = await storage.ls(model_path, detail=False) gguf_files = [ posixpath.basename(f.rstrip("/")) for f in files @@ -214,7 +214,7 @@ async def list_models(self, embedding=False): model_path = storage.join(model_path, gguf_files[0]) else: # 
Search for files ending with .gguf in the directory - files = storage.ls(model["local_path"], detail=False) + files = await storage.ls(model["local_path"], detail=False) gguf_files = [ posixpath.basename(f.rstrip("/")) for f in files @@ -223,8 +223,8 @@ async def list_models(self, embedding=False): if gguf_files: gguf_file = gguf_files[0] model_path = storage.join(base_path, gguf_file) - if storage.isdir(model_path): - files = storage.ls(model_path, detail=False) + if await storage.isdir(model_path): + files = await storage.ls(model_path, detail=False) gguf_files = [ posixpath.basename(f.rstrip("/")) for f in files @@ -267,14 +267,14 @@ async def build_provenance(self): provenance = {} from lab.dirs import get_models_dir - models_dir = get_models_dir() + models_dir = await get_models_dir() # Load the tlab_complete_provenance.json file if it exists complete_provenance_file = storage.join(models_dir, "_tlab_complete_provenance.json") - if storage.exists(complete_provenance_file): - with storage.open(complete_provenance_file, "r") as f: + if await storage.exists(complete_provenance_file): + async with await storage.open(complete_provenance_file, "r") as f: try: - provenance = json.load(f) + provenance = json.loads(await f.read()) except json.JSONDecodeError: print(f"Error loading {complete_provenance_file}: Invalid JSON format.") provenance = {} @@ -285,13 +285,13 @@ async def build_provenance(self): provenance, local_added_count = await self.check_provenance_for_local_models(provenance) if local_added_count != 0: # Save new provenance mapping - with storage.open(complete_provenance_file, "w") as f: - json.dump(provenance, f) + async with await storage.open(complete_provenance_file, "w") as f: + await f.write(json.dumps(provenance)) # Check if the provenance mapping is up to date # The -1 here indicates that we are not counting the _tlab_complete_provenance.json file in models_dir try: - entries = storage.ls(models_dir, detail=False) - dir_count = sum(1 for entry in entries if storage.isdir(entry)) + entries = await storage.ls(models_dir, detail=False) + dir_count = sum(1 for entry in entries if await storage.isdir(entry)) except Exception: dir_count = 0 if len(provenance) > 0 and dir_count + local_added_count - 1 == len(provenance): @@ -300,17 +300,17 @@ async def build_provenance(self): # If the provenance mapping is not built or models_dir has changed, we need to rebuild it # Scan all model directories try: - entries = storage.ls(models_dir, detail=False) + entries = await storage.ls(models_dir, detail=False) for entry_path in entries: - if storage.isdir(entry_path): + if await storage.isdir(entry_path): # Extract entry name from path entry_name = entry_path.rstrip("/").split("/")[-1] # Look for provenance file provenance_file = storage.join(models_dir, entry_name, "_tlab_provenance.json") try: - if storage.exists(provenance_file): - with storage.open(provenance_file, "r") as f: - prov_data = json.load(f) + if await storage.exists(provenance_file): + async with await storage.open(provenance_file, "r") as f: + prov_data = json.loads(await f.read()) if "md5_checksums" in prov_data: prov_data["parameters"]["md5_checksums"] = prov_data["md5_checksums"] @@ -339,7 +339,7 @@ async def check_provenance_for_local_models(self, provenance): # Get the list of all local models from lab.model import Model as ModelService - models = ModelService.list_all() + models = await ModelService.list_all() models_added_to_provenance = 0 # Iterate through models and check if they have provenance data and if they 
exist already in provenance for model_dict in models: @@ -349,13 +349,13 @@ async def check_provenance_for_local_models(self, provenance): ): # Check if the model_source is local source_path = model_dict.get("json_data", {}).get("source_id_or_path", "") - if model_dict.get("json_data", {}).get("source", "") == "local" and storage.exists(source_path): + if model_dict.get("json_data", {}).get("source", "") == "local" and await storage.exists(source_path): # Check if the model has a _tlab_provenance.json file provenance_file = storage.join(source_path, "_tlab_provenance.json") - if storage.exists(provenance_file): + if await storage.exists(provenance_file): # Load the provenance file - with storage.open(provenance_file, "r") as f: - prov_data = json.load(f) + async with await storage.open(provenance_file, "r") as f: + prov_data = json.loads(await f.read()) if "md5_checksums" in prov_data: prov_data["parameters"]["md5_checksums"] = prov_data["md5_checksums"] @@ -401,7 +401,7 @@ async def get_evals_by_model(self, model_id): """ from lab.dirs import get_models_dir - models_dir = get_models_dir() + models_dir = await get_models_dir() evaluations_by_model = {} # Extract just the model name if model_id contains a path @@ -415,9 +415,9 @@ async def get_evals_by_model(self, model_id): # Look for the model directory - prioritize exact matches model_dir = None try: - entries = storage.ls(models_dir, detail=False) + entries = await storage.ls(models_dir, detail=False) for entry_path in entries: - if storage.isdir(entry_path): + if await storage.isdir(entry_path): entry = entry_path.rstrip("/").split("/")[-1] # Exact match first, then check for suffixes if entry == search_model_id: @@ -429,12 +429,12 @@ async def get_evals_by_model(self, model_id): except Exception: pass - if model_dir and storage.exists(model_dir): + if model_dir and await storage.exists(model_dir): provenance_file = storage.join(model_dir, "_tlab_provenance.json") - if storage.exists(provenance_file): + if await storage.exists(provenance_file): try: - with storage.open(provenance_file, "r") as f: - provenance_data = json.load(f) + async with await storage.open(provenance_file, "r") as f: + provenance_data = json.loads(await f.read()) # Get evaluations from the same file evaluations = provenance_data.get("evaluations", []) @@ -480,9 +480,9 @@ async def list_model_provenance(self, model_id): # Save the provenance mapping as a json file from lab.dirs import get_models_dir - provenance_file = storage.join(get_models_dir(), "_tlab_complete_provenance.json") - with storage.open(provenance_file, "w") as f: - json.dump(provenance_mapping, f) + provenance_file = storage.join(await get_models_dir(), "_tlab_complete_provenance.json") + async with await storage.open(provenance_file, "w") as f: + await f.write(json.dumps(provenance_mapping)) # Trace the provenance chain leading to the given model chain = await self.trace_provenance(model_id, provenance_mapping) diff --git a/api/transformerlab/models/ollamamodel.py b/api/transformerlab/models/ollamamodel.py index c3f316986..7af2d70a7 100644 --- a/api/transformerlab/models/ollamamodel.py +++ b/api/transformerlab/models/ollamamodel.py @@ -178,10 +178,10 @@ async def install(self): # Create a directory for the model. Make sure it doesn't exist already. 
from lab.dirs import get_models_dir - output_path = storage.join(get_models_dir(), output_model_id) - if storage.exists(output_path): + output_path = storage.join(await get_models_dir(), output_model_id) + if await storage.exists(output_path): raise FileExistsError(errno.EEXIST, "Directory already exists", output_path) - storage.makedirs(output_path, exist_ok=True) + await storage.makedirs(output_path, exist_ok=True) # Create a link in the directory that points to the source blob # Note: symlinks may not work with remote storage, but this is for local filesystem @@ -196,9 +196,9 @@ async def install(self): except Exception as e: # If symlink fails, we could copy the file instead print(f"Warning: Could not create symlink, copying file instead: {e}") - with storage.open(link_name, "wb") as out_f: + async with await storage.open(link_name, "wb") as out_f: with open(input_model_path, "rb") as in_f: - out_f.write(in_f.read()) + await out_f.write(in_f.read()) # Create an index.json file so this can be read by the system (SDK format) model_description = { @@ -217,8 +217,8 @@ async def install(self): }, } model_info_file = storage.join(output_path, "index.json") - with storage.open(model_info_file, "w") as f: - json.dump(model_description, f) + async with await storage.open(model_info_file, "w") as f: + await f.write(json.dumps(model_description)) ######################### diff --git a/api/transformerlab/plugin_sdk/plugin_harness.py b/api/transformerlab/plugin_sdk/plugin_harness.py index eb909fcd9..1697eb66b 100644 --- a/api/transformerlab/plugin_sdk/plugin_harness.py +++ b/api/transformerlab/plugin_sdk/plugin_harness.py @@ -12,6 +12,7 @@ import sys import argparse import traceback +import asyncio parser = argparse.ArgumentParser() @@ -23,7 +24,7 @@ def set_config_env_vars(env_var: str, target_env_var: str = None, user_id: str = try: from transformerlab.plugin import get_db_config_value - value = get_db_config_value(env_var, user_id=user_id, team_id=team_id) + value = asyncio.run(get_db_config_value(env_var, user_id=user_id, team_id=team_id)) if value: os.environ[target_env_var] = value print(f"Set {target_env_var} from {'user' if user_id else 'team'} config: {value}") diff --git a/api/transformerlab/plugin_sdk/transformerlab/plugin.py b/api/transformerlab/plugin_sdk/transformerlab/plugin.py index 7f87f2a13..8ac79da58 100644 --- a/api/transformerlab/plugin_sdk/transformerlab/plugin.py +++ b/api/transformerlab/plugin_sdk/transformerlab/plugin.py @@ -1,3 +1,4 @@ +import asyncio import os import json import sqlite3 @@ -16,11 +17,26 @@ # useful constants # Use shared constant as sole source of truth DATABASE_FILE_NAME = f"{HOME_DIR}/llmlab.sqlite3" -WORKSPACE_DIR = get_workspace_dir() + +# Initialize WORKSPACE_DIR synchronously +# This is called during plugin initialization, not during API runtime +try: + WORKSPACE_DIR = asyncio.run(get_workspace_dir()) +except RuntimeError: + # If there's already an event loop running (shouldn't happen in plugin context) + # fall back to using the existing loop + try: + loop = asyncio.get_event_loop() + WORKSPACE_DIR = loop.run_until_complete(get_workspace_dir()) + except Exception as e: + print(f"Plugin Harness Error: Could not get WORKSPACE_DIR: {e}") + WORKSPACE_DIR = None + if WORKSPACE_DIR is None: print("Plugin Harness Error: WORKSPACE_DIR not available. 
Quitting.") exit(1) -TEMP_DIR = storage.join(get_workspace_dir(), "temp") + +TEMP_DIR = storage.join(WORKSPACE_DIR, "temp") # Maintain a singleton database connection db = None @@ -66,26 +82,30 @@ def get_dataset_path(dataset_id: str): Returns the ID or filesystem path to pass to load_dataset() for a given ID, using the dataset service instead of the deprecated DB table. """ - try: - ds = dataset_service.get(dataset_id) - metadata = ds.get_metadata() - except FileNotFoundError: - raise Exception(f"No dataset named {dataset_id} installed.") - - location = (metadata or {}).get("location", "huggingface") - if location == "local": - # Use service path resolution to ensure correctness + + async def _get(): try: - return ds.get_dir() - except Exception: - # Fallback to previous behavior if needed - return storage.join(get_workspace_dir(), "datasets", dataset_id) + ds = await dataset_service.get(dataset_id) + metadata = await ds.get_metadata() + except FileNotFoundError: + raise Exception(f"No dataset named {dataset_id} installed.") + + location = (metadata or {}).get("location", "huggingface") + if location == "local": + # Use service path resolution to ensure correctness + try: + return await ds.get_dir() + except Exception: + # Fallback to previous behavior if needed + return storage.join(await get_workspace_dir(), "datasets", dataset_id) + + # Otherwise assume it is a HuggingFace dataset id + return dataset_id - # Otherwise assume it is a HuggingFace dataset id - return dataset_id + return asyncio.run(_get()) -def get_db_config_value(key: str, team_id: Optional[str] = None, user_id: Optional[str] = None): +async def get_db_config_value(key: str, team_id: Optional[str] = None, user_id: Optional[str] = None): """ Returns the value of a config key from the database with priority: user-specific -> team-wide -> global config. @@ -105,7 +125,7 @@ def get_db_config_value(key: str, team_id: Optional[str] = None, user_id: Option # Extract team_id from workspace_dir if not provided if team_id is None: try: - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() if workspace_dir and "/orgs/" in workspace_dir: # Extract team_id from path like ~/.transformerlab/orgs//workspace parts = workspace_dir.split("/orgs/") @@ -180,24 +200,31 @@ def test_wandb_login(project_name: str = "TFL_Training"): def experiment_get(id): - try: - exp_obj = Experiment.get(id) - return exp_obj.get_json_data() - except Exception: - return None + async def _get(): + try: + exp_obj = await Experiment.get(id) + return await exp_obj.get_json_data() + except Exception: + return None + + return asyncio.run(_get()) def get_experiment_config(name: str): """ Returns the experiment config from the experiment name. 
""" - try: - exp_obj = Experiment.get(name) - json_data = exp_obj.get_json_data() - if json_data: - return json_data["config"], name - except Exception: - return None, name + + async def _get(): + try: + exp_obj = await Experiment.get(name) + json_data = await exp_obj.get_json_data() + if json_data: + return json_data["config"], name + except Exception: + return None, name + + return asyncio.run(_get()) def get_python_executable(plugin_dir): @@ -257,11 +284,14 @@ def generate_model_json( model_description["json_data"].update(json_data) # Output the json to the file - if not output_directory: - output_directory = storage.join(get_workspace_dir(), "models", model_id) - with storage.open(storage.join(output_directory, "index.json"), "w") as outfile: - json.dump(model_description, outfile) - + async def _write(): + nonlocal output_directory + if not output_directory: + output_directory = storage.join(await get_workspace_dir(), "models", model_id) + async with await storage.open(storage.join(output_directory, "index.json"), "w") as outfile: + await outfile.write(json.dumps(model_description)) + + asyncio.run(_write()) return model_description @@ -278,40 +308,43 @@ def prepare_dataset_files( if chat_template: tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - for split_name in datasets: - dataset_split = datasets[split_name] - print(f"Processing {split_name} dataset with {len(dataset_split)} examples.") - - output_file = storage.join(data_directory, f"{split_name}.jsonl") - with storage.open(output_file, "w") as f: - for i in range(len(dataset_split)): - example = dataset_split[i] - try: - rendered_text = format_template( - example=example, - formatting_template=formatting_template, - chat_template=chat_template, - tokenizer=tokenizer, - chat_column=chat_column, - ) - rendered_text = rendered_text.replace("\n", "\\n").replace("\r", "\\r") - f.write(json.dumps({"text": rendered_text}) + "\n") - except Exception: - print(f"Warning: Failed to process example {i} in '{split_name}'. Skipping.") - continue # Skip problematic examples - - # Print one example from the written jsonl file - try: - with storage.open(output_file, "r") as f: - first_line = f.readline() - if first_line: - parsed = json.loads(first_line) - print(f"Example from {split_name} split:") - print(parsed.get("text", first_line)) - else: - print(f"Example from {split_name} split: file is empty.") - except Exception as e: - print(f"Error reading example from {output_file}: {e}") + async def _process_datasets(): + for split_name in datasets: + dataset_split = datasets[split_name] + print(f"Processing {split_name} dataset with {len(dataset_split)} examples.") + + output_file = storage.join(data_directory, f"{split_name}.jsonl") + async with await storage.open(output_file, "w") as f: + for i in range(len(dataset_split)): + example = dataset_split[i] + try: + rendered_text = format_template( + example=example, + formatting_template=formatting_template, + chat_template=chat_template, + tokenizer=tokenizer, + chat_column=chat_column, + ) + rendered_text = rendered_text.replace("\n", "\\n").replace("\r", "\\r") + await f.write(json.dumps({"text": rendered_text}) + "\n") + except Exception: + print(f"Warning: Failed to process example {i} in '{split_name}'. 
Skipping.") + continue # Skip problematic examples + + # Print one example from the written jsonl file + try: + async with await storage.open(output_file, "r") as f: + first_line = await f.readline() + if first_line: + parsed = json.loads(first_line) + print(f"Example from {split_name} split:") + print(parsed.get("text", first_line)) + else: + print(f"Example from {split_name} split: file is empty.") + except Exception as e: + print(f"Error reading example from {output_file}: {e}") + + asyncio.run(_process_datasets()) def format_template( diff --git a/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/evals.py b/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/evals.py index 8ede71619..c3bf64386 100644 --- a/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/evals.py +++ b/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/evals.py @@ -1,3 +1,4 @@ +import asyncio import json import os import time @@ -56,24 +57,28 @@ def setup_eval_logging(self, wandb_project_name: str = "TLab_Evaluations", manua today = time.strftime("%Y%m%d-%H%M%S") from transformerlab.plugin import WORKSPACE_DIR as workspace_dir - # Create tensorboard directory structure - tensorboard_dir = storage.join(workspace_dir, "experiments", self.params.experiment_name, "tensorboards") - storage.makedirs(tensorboard_dir, exist_ok=True) + async def _setup_dirs(): + # Create tensorboard directory structure + tensorboard_dir = storage.join(workspace_dir, "experiments", self.params.experiment_name, "tensorboards") + await storage.makedirs(tensorboard_dir, exist_ok=True) - # Find directory based on eval name - combined_dir = None - for entry in storage.ls(tensorboard_dir): - dir_name = entry.rstrip("/").split("/")[-1] - if self.params.run_name == dir_name or self.params.run_name == dir_name.lower(): - if storage.isdir(entry): - combined_dir = storage.join(tensorboard_dir, dir_name) - break + # Find directory based on eval name + combined_dir = None + for entry in await storage.ls(tensorboard_dir): + dir_name = entry.rstrip("/").split("/")[-1] + if self.params.run_name == dir_name or self.params.run_name == dir_name.lower(): + if await storage.isdir(entry): + combined_dir = storage.join(tensorboard_dir, dir_name) + break - if combined_dir is None: - combined_dir = storage.join(tensorboard_dir, self.params.run_name) + if combined_dir is None: + combined_dir = storage.join(tensorboard_dir, self.params.run_name) - output_dir = storage.join(combined_dir, f"evaljob_{self.params.job_id}_{today}") - storage.makedirs(output_dir, exist_ok=True) + output_dir = storage.join(combined_dir, f"evaljob_{self.params.job_id}_{today}") + await storage.makedirs(output_dir, exist_ok=True) + return output_dir + + output_dir = asyncio.run(_setup_dirs()) # Store the writer and output directory as instance variables self.params["tensorboard_output_dir"] = output_dir @@ -153,7 +158,7 @@ def get_output_file_path(self, suffix="", is_plotting=False, dir_only=False): experiment_dir = storage.join(workspace_dir, "experiments", self.params.experiment_name) eval_dir = storage.join(experiment_dir, "evals", self.params.eval_name, self.params.job_id) - storage.makedirs(eval_dir, exist_ok=True) + asyncio.run(storage.makedirs(eval_dir, exist_ok=True)) if dir_only: return eval_dir @@ -194,27 +199,32 @@ def save_evaluation_results(self, metrics_df): raise ValueError(f"Missing required columns in metrics DataFrame: {missing_columns}") # Save full DataFrame to CSV - output_path = self.get_output_file_path() - with storage.open(output_path, "w", encoding="utf-8") as f: - 
metrics_df.to_csv(f, index=False) - print(f"Saved detailed evaluation results to {output_path}") + async def _save_results(): + output_path = self.get_output_file_path() + async with await storage.open(output_path, "w", encoding="utf-8") as f: + await f.write(metrics_df.to_csv(index=False)) + print(f"Saved detailed evaluation results to {output_path}") + + # Create and save plotting data + plot_data_path = self.get_output_file_path(is_plotting=True) - # Create and save plotting data - plot_data_path = self.get_output_file_path(is_plotting=True) + # Extract and format plotting data + plotting_data = metrics_df[["test_case_id", "metric_name", "score"]].copy() - # Extract and format plotting data - plotting_data = metrics_df[["test_case_id", "metric_name", "score"]].copy() + # Format metric names for better display (replace underscores with spaces and capitalize) + plotting_data["metric_name"] = plotting_data["metric_name"].apply(lambda x: x.replace("_", " ").title()) - # Format metric names for better display (replace underscores with spaces and capitalize) - plotting_data["metric_name"] = plotting_data["metric_name"].apply(lambda x: x.replace("_", " ").title()) + # Save as JSON + async with await storage.open(plot_data_path, "w", encoding="utf-8") as f: + await f.write(plotting_data.to_json(orient="records", lines=False)) + print(f"Saved plotting data to {plot_data_path}") - # Save as JSON - with storage.open(plot_data_path, "w", encoding="utf-8") as f: - plotting_data.to_json(f, orient="records", lines=False) - print(f"Saved plotting data to {plot_data_path}") + return output_path, plot_data_path - self.job.update_job_data_field("additional_output_path", output_path) - self.job.update_job_data_field("plot_data_path", plot_data_path) + output_path, plot_data_path = asyncio.run(_save_results()) + + asyncio.run(self.job.update_job_data_field("additional_output_path", output_path)) + asyncio.run(self.job.update_job_data_field("plot_data_path", plot_data_path)) # Add evaluation data to existing provenance file self.add_evaluation_to_provenance_file(metrics_df) @@ -276,55 +286,58 @@ def add_evaluation_to_provenance_file(self, metrics_df): # Add evaluation data to the existing provenance file in the model directory # Try to find the model directory using environment variables - from lab.dirs import get_workspace_dir - - workspace_dir = get_workspace_dir() - - models_dir = storage.join(workspace_dir, "models") - - # Look for the model directory - since we have the actual model path, we can be more precise - model_dir = None - for entry in storage.ls(models_dir): - entry_name = entry.rstrip("/").split("/")[-1] - if storage.isdir(entry): - # Exact match first, then check for suffixes - if entry_name == model_name: - model_dir = storage.join(models_dir, entry_name) - break - elif entry_name.endswith(f"_{model_name}"): - model_dir = storage.join(models_dir, entry_name) - break - - if not model_dir or not storage.exists(model_dir): - print( - "Unable to add evaluation details to model provenance file, since that is only supported for fine-tuned (locally trained) models." 
- ) - return - - provenance_path = storage.join(model_dir, "_tlab_provenance.json") - - # Load existing provenance data - existing_provenance = {} - if storage.exists(provenance_path): - try: - with storage.open(provenance_path, "r", encoding="utf-8") as f: - existing_provenance = json.load(f) - except Exception as e: - print(f"Error loading existing provenance: {e}") - existing_provenance = {} - - # Initialize evaluations list if it doesn't exist - if "evaluations" not in existing_provenance: - existing_provenance["evaluations"] = [] - - # Add new evaluation to the list - existing_provenance["evaluations"].append(evaluation_data) - - # Write updated provenance file - with storage.open(provenance_path, "w", encoding="utf-8") as f: - json.dump(existing_provenance, f, indent=2) - - print(f"Evaluation data added to provenance file: {provenance_path}") + async def _add_to_provenance(): + from lab.dirs import get_workspace_dir + + workspace_dir = await get_workspace_dir() + + models_dir = storage.join(workspace_dir, "models") + + # Look for the model directory - since we have the actual model path, we can be more precise + model_dir = None + for entry in await storage.ls(models_dir): + entry_name = entry.rstrip("/").split("/")[-1] + if await storage.isdir(entry): + # Exact match first, then check for suffixes + if entry_name == model_name: + model_dir = storage.join(models_dir, entry_name) + break + elif entry_name.endswith(f"_{model_name}"): + model_dir = storage.join(models_dir, entry_name) + break + + if not model_dir or not await storage.exists(model_dir): + print( + "Unable to add evaluation details to model provenance file, since that is only supported for fine-tuned (locally trained) models." + ) + return + + provenance_path = storage.join(model_dir, "_tlab_provenance.json") + + # Load existing provenance data + existing_provenance = {} + if await storage.exists(provenance_path): + try: + async with await storage.open(provenance_path, "r", encoding="utf-8") as f: + existing_provenance = json.loads(await f.read()) + except Exception as e: + print(f"Error loading existing provenance: {e}") + existing_provenance = {} + + # Initialize evaluations list if it doesn't exist + if "evaluations" not in existing_provenance: + existing_provenance["evaluations"] = [] + + # Add new evaluation to the list + existing_provenance["evaluations"].append(evaluation_data) + + # Write updated provenance file + async with await storage.open(provenance_path, "w", encoding="utf-8") as f: + await f.write(json.dumps(existing_provenance, indent=2)) + + print(f"Evaluation data added to provenance file: {provenance_path}") + + asyncio.run(_add_to_provenance()) tlab_evals = EvalsTLabPlugin() diff --git a/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/export.py b/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/export.py index b74fad074..64c426d86 100644 --- a/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/export.py +++ b/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/export.py @@ -1,3 +1,4 @@ +import asyncio import time import traceback @@ -52,16 +53,16 @@ def wrapper(*args, **kwargs): self.add_job_data("model_name", self.params.model_name) # Update starting progress - self.job.update_progress(progress_start) + asyncio.run(self.job.update_progress(progress_start)) try: # Call the wrapped function result = func(*args, **kwargs) # Update final progress and success status - self.job.update_progress(progress_end) - self.job.update_job_data_field("completion_status", "success") - 
self.job.update_job_data_field("completion_details", "Export completed successfully") + asyncio.run(self.job.update_progress(progress_end)) + asyncio.run(self.job.update_job_data_field("completion_status", "success")) + asyncio.run(self.job.update_job_data_field("completion_details", "Export completed successfully")) self.add_job_data("end_time", time.strftime("%Y-%m-%d %H:%M:%S")) return result @@ -72,8 +73,8 @@ def wrapper(*args, **kwargs): print(error_msg) # Log the error - self.job.update_job_data_field("completion_status", "failed") - self.job.update_job_data_field("completion_details", f"Error occured: {str(e)}") + asyncio.run(self.job.update_job_data_field("completion_status", "failed")) + asyncio.run(self.job.update_job_data_field("completion_details", f"Error occured: {str(e)}")) self.add_job_data("end_time", time.strftime("%Y-%m-%d %H:%M:%S")) raise diff --git a/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/generate.py b/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/generate.py index afced394c..9dcd78a88 100644 --- a/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/generate.py +++ b/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/generate.py @@ -1,3 +1,4 @@ +import asyncio import json import os import time @@ -63,7 +64,7 @@ def save_generated_dataset(self, df, additional_metadata=None, dataset_id=None, else: output_dir = self.get_output_file_path(dir_only=True) - storage.makedirs(output_dir, exist_ok=True) + asyncio.run(storage.makedirs(output_dir, exist_ok=True)) if is_image: lines = True @@ -98,11 +99,18 @@ def save_generated_dataset(self, df, additional_metadata=None, dataset_id=None, metadata_file = storage.join(output_dir, f"{self.params.run_name}_{self.params.job_id}_metadata.json") else: metadata_file = storage.join(output_dir, f"{dataset_id}_metadata.json") - with storage.open(metadata_file, "w", encoding="utf-8") as f: - json.dump(metadata, f, indent=2) - with storage.open(output_file, "w", encoding="utf-8") as f: - df.to_json(f, orient="records", lines=lines) + async def _save_metadata(): + async with await storage.open(metadata_file, "w", encoding="utf-8") as f: + await f.write(json.dumps(metadata, indent=2)) + + asyncio.run(_save_metadata()) + + async def _save_data(): + async with await storage.open(output_file, "w", encoding="utf-8") as f: + await f.write(df.to_json(orient="records", lines=lines)) + + asyncio.run(_save_data()) print(f"Generated data saved to {output_file}") # Upload to Transformer Lab as a dataset @@ -128,54 +136,61 @@ def upload_to_transformerlab(self, output_file_path, dataset_id=None): Returns: bool: Whether upload was successful """ - try: - # Determine dataset ID - if not dataset_id: - dataset_id = f"{self.params.run_name}_{self.params.job_id}" - # Create a new dataset using internal SDK (same logic as /data/new endpoint) - # Check if dataset already exists + async def _upload(): try: - _ = dataset_service.get(dataset_id) - print(f"Dataset '{dataset_id}' already exists, skipping creation") - except FileNotFoundError: - # Dataset doesn't exist, create it - dataset_path = dirs.dataset_dir_by_id(dataset_id) - if not storage.exists(dataset_path): - storage.makedirs(dataset_path, exist_ok=True) - - # Create filesystem metadata + # Determine dataset ID + if not dataset_id: + dataset_id_local = f"{self.params.run_name}_{self.params.job_id}" + else: + dataset_id_local = dataset_id + + # Create a new dataset using internal SDK (same logic as /data/new endpoint) + # Check if dataset already exists try: - ds = 
dataset_service.create(dataset_id) - ds.set_metadata( - location="local", - description="", - size=-1, - json_data={"generated": True}, - ) - except Exception as e: - print(f"Failed to write dataset metadata to SDK store: {type(e).__name__}: {e}") - - # Upload the file (same logic as /data/fileupload endpoint) - filename = os.path.basename(output_file_path) - target_path = storage.join(dirs.dataset_dir_by_id(dataset_id), filename) - - # Copy the file to the dataset directory - storage.makedirs(dirs.dataset_dir_by_id(dataset_id), exist_ok=True) - with storage.open(output_file_path, "rb") as src_file: - content = src_file.read() - with storage.open(target_path, "wb") as dst_file: - dst_file.write(content) - - # Adding dataset so it can be previewed - self.add_job_data("additional_output_path", output_file_path) - - print(f"Dataset '{dataset_id}' uploaded successfully to TransformerLab") - return True - - except Exception as e: - print(f"Error uploading to TransformerLab: {e}") - raise + _ = await dataset_service.get(dataset_id_local) + print(f"Dataset '{dataset_id_local}' already exists, skipping creation") + except FileNotFoundError: + # Dataset doesn't exist, create it + dataset_path = await dirs.dataset_dir_by_id(dataset_id_local) + if not await storage.exists(dataset_path): + await storage.makedirs(dataset_path, exist_ok=True) + + # Create filesystem metadata + try: + ds = await dataset_service.create(dataset_id_local) + await ds.set_metadata( + location="local", + description="", + size=-1, + json_data={"generated": True}, + ) + except Exception as e: + print(f"Failed to write dataset metadata to SDK store: {type(e).__name__}: {e}") + + # Upload the file (same logic as /data/fileupload endpoint) + filename = os.path.basename(output_file_path) + dataset_dir = await dirs.dataset_dir_by_id(dataset_id_local) + target_path = storage.join(dataset_dir, filename) + + # Copy the file to the dataset directory + await storage.makedirs(dataset_dir, exist_ok=True) + async with await storage.open(output_file_path, "rb") as src_file: + content = await src_file.read() + async with await storage.open(target_path, "wb") as dst_file: + await dst_file.write(content) + + # Adding dataset so it can be previewed + self.add_job_data("additional_output_path", output_file_path) + + print(f"Dataset '{dataset_id_local}' uploaded successfully to TransformerLab") + return True + + except Exception as e: + print(f"Error uploading to TransformerLab: {e}") + raise + + return asyncio.run(_upload()) def get_output_file_path(self, suffix="", dataset_id=None, dir_only=False): """Get path for saving generated outputs @@ -191,17 +206,21 @@ def get_output_file_path(self, suffix="", dataset_id=None, dir_only=False): from lab.dirs import get_workspace_dir - workspace_dir = get_workspace_dir() + async def _get_dir(): + workspace_dir = await get_workspace_dir() - experiment_dir = storage.join(workspace_dir, "experiments", self.params.experiment_name) - dataset_dir = storage.join(experiment_dir, "datasets") + experiment_dir = storage.join(workspace_dir, "experiments", self.params.experiment_name) + dataset_dir = storage.join(experiment_dir, "datasets") - # Create a specific directory for this generation job - if dataset_id is None: - gen_dir = storage.join(dataset_dir, f"{self.params.run_name}_{self.params.job_id}") - else: - gen_dir = storage.join(dataset_dir, dataset_id) - storage.makedirs(gen_dir, exist_ok=True) + # Create a specific directory for this generation job + if dataset_id is None: + gen_dir = 
storage.join(dataset_dir, f"{self.params.run_name}_{self.params.job_id}") + else: + gen_dir = storage.join(dataset_dir, dataset_id) + await storage.makedirs(gen_dir, exist_ok=True) + return gen_dir + + gen_dir = asyncio.run(_get_dir()) if dir_only: return gen_dir diff --git a/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/tlab_plugin.py b/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/tlab_plugin.py index ba2c6ee49..c0bc9ca58 100644 --- a/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/tlab_plugin.py +++ b/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/tlab_plugin.py @@ -107,18 +107,18 @@ def wrapper(*args, **kwargs): # Update final progress and success status self.progress_update(progress_end) - job_data = self.job.get_json_data() + job_data = asyncio.run(self.job.get_json_data()) if job_data.get("job_data", {}).get("completion_status", "") != "success": - self.job.update_job_data_field("completion_status", "success") + asyncio.run(self.job.update_job_data_field("completion_status", "success")) - job_data = self.job.get_json_data() - if job_data.get("job_data", {}).get("completion_status", "") != "Job completed successfully": - self.job.update_job_data_field("completion_details", "Job completed successfully") + job_data = asyncio.run(self.job.get_json_data()) + if job_data.get("job_data", {}).get("completion_details", "") != "Job completed successfully": + asyncio.run(self.job.update_job_data_field("completion_details", "Job completed successfully")) - job_data = self.job.get_json_data() + job_data = asyncio.run(self.job.get_json_data()) if ( - job_data.get("job_data", {}).get("end_time", "") is not None - and job_data.get("job_data", {}).get("end_time", "") != "" + job_data.get("job_data", {}).get("end_time", "") is None + or job_data.get("job_data", {}).get("end_time", "") == "" ): self.add_job_data("end_time", time.strftime("%Y-%m-%d %H:%M:%S")) @@ -137,8 +137,10 @@ def wrapper(*args, **kwargs): print(error_msg) # Update job with failure status - self.job.update_job_data_field("completion_status", "failed") - self.job.update_job_data_field("completion_details", "Error occurred while executing job") + asyncio.run(self.job.update_job_data_field("completion_status", "failed")) + asyncio.run( + self.job.update_job_data_field("completion_details", "Error occurred while executing job") + ) self.add_job_data("end_time", time.strftime("%Y-%m-%d %H:%M:%S")) if manual_logging and getattr(self.params, "wandb_run") is not None: self.wandb_run.finish() @@ -197,8 +199,8 @@ async def run_async(): # Update final progress and success status self.progress_update(progress_end) - self.job.update_job_data_field("completion_status", "success") - self.job.update_job_data_field("completion_details", "Job completed successfully") + await self.job.update_job_data_field("completion_status", "success") + await self.job.update_job_data_field("completion_details", "Job completed successfully") self.add_job_data("end_time", time.strftime("%Y-%m-%d %H:%M:%S")) if manual_logging and getattr(self, "wandb_run") is not None: self.wandb_run.finish() @@ -215,8 +217,8 @@ async def run_async(): print(error_msg) # Update job with failure status - self.job.update_job_data_field("completion_status", "failed") - self.job.update_job_data_field("completion_details", "Error occurred while executing job") + await self.job.update_job_data_field("completion_status", "failed") + await self.job.update_job_data_field("completion_details", "Error occurred while executing job") self.add_job_data("end_time", 
time.strftime("%Y-%m-%d %H:%M:%S")) if manual_logging and getattr(self, "wandb_run") is not None: self.wandb_run.finish() @@ -237,17 +239,17 @@ async def run_async(): def progress_update(self, progress: int): """Update job progress using SDK directly""" - job_data = self.job.get_job_data() + job_data = asyncio.run(self.job.get_job_data()) if job_data.get("sweep_progress") is not None: if int(job_data.get("sweep_progress")) != 100: - self.job.update_job_data_field("sweep_subprogress", progress) + asyncio.run(self.job.update_job_data_field("sweep_subprogress", progress)) return - self.job.update_progress(progress) + asyncio.run(self.job.update_progress(progress)) # Check stop status using SDK - job_data = self.job.get_job_data() + job_data = asyncio.run(self.job.get_job_data()) if job_data.get("stop", False): - self.job.update_status("STOPPED") + asyncio.run(self.job.update_status("STOPPED")) raise KeyboardInterrupt("Job stopped by user") def get_experiment_config(self, experiment_name: str): @@ -256,7 +258,7 @@ def get_experiment_config(self, experiment_name: str): def add_job_data(self, key: str, value: Any): """Add data to job using SDK directly""" - self.job.update_job_data_field(key, value) + asyncio.run(self.job.update_job_data_field(key, value)) def load_dataset(self, dataset_types: List[str] = ["train"], config_name: str = None): """Decorator for loading datasets with error handling""" @@ -264,8 +266,8 @@ def load_dataset(self, dataset_types: List[str] = ["train"], config_name: str = self._ensure_args_parsed() if not self.params.dataset_name: - self.job.update_job_data_field("completion_status", "failed") - self.job.update_job_data_field("completion_details", "Dataset name not provided") + asyncio.run(self.job.update_job_data_field("completion_status", "failed")) + asyncio.run(self.job.update_job_data_field("completion_details", "Dataset name not provided")) self.add_job_data("end_time", time.strftime("%Y-%m-%d %H:%M:%S")) raise ValueError("Dataset name not provided") @@ -274,18 +276,25 @@ def load_dataset(self, dataset_types: List[str] = ["train"], config_name: str = dataset_target = get_dataset_path(self.params.dataset_name) # If this is a directory, prepare data_files excluding index.json and hidden files - is_dir = isinstance(dataset_target, str) and ( - storage.isdir(dataset_target) if storage.exists(dataset_target) else os.path.isdir(dataset_target) - ) + async def _check_dir(): + if await storage.exists(dataset_target): + return await storage.isdir(dataset_target) + return os.path.isdir(dataset_target) + + is_dir = isinstance(dataset_target, str) and asyncio.run(_check_dir()) data_files_map = None if is_dir: try: - if storage.exists(dataset_target): - entries_full = storage.ls(dataset_target) - # normalize to basenames - entries = [e.rstrip("/").split("/")[-1] for e in entries_full] - else: - entries = os.listdir(dataset_target) + + async def _get_entries(): + if await storage.exists(dataset_target): + entries_full = await storage.ls(dataset_target) + # normalize to basenames + return [e.rstrip("/").split("/")[-1] for e in entries_full] + else: + return os.listdir(dataset_target) + + entries = asyncio.run(_get_entries()) except Exception: entries = [] @@ -297,17 +306,21 @@ def load_dataset(self, dataset_types: List[str] = ["train"], config_name: str = lower = name.lower() if not (lower.endswith(".json") or lower.endswith(".jsonl") or lower.endswith(".csv")): continue - full_path = ( - storage.join(dataset_target, name) - if storage.exists(dataset_target) - else 
os.path.join(dataset_target, name) - ) - if storage.exists(dataset_target): - if storage.isfile(full_path): - filtered_files.append(full_path) - else: - if os.path.isfile(full_path): - filtered_files.append(full_path) + + async def _check_file(): + if await storage.exists(dataset_target): + full_path = storage.join(dataset_target, name) + if await storage.isfile(full_path): + return full_path + else: + full_path = os.path.join(dataset_target, name) + if os.path.isfile(full_path): + return full_path + return None + + full_path = asyncio.run(_check_file()) + if full_path: + filtered_files.append(full_path) if len(filtered_files) > 0: data_files_map = {"train": filtered_files} @@ -390,8 +403,8 @@ def load_dataset(self, dataset_types: List[str] = ["train"], config_name: str = except Exception as e: error_msg = f"Error loading dataset: {str(e)}\n{traceback.format_exc()}" print(error_msg) - self.job.update_job_data_field("completion_status", "failed") - self.job.update_job_data_field("completion_details", "Failed to load dataset") + asyncio.run(self.job.update_job_data_field("completion_status", "failed")) + asyncio.run(self.job.update_job_data_field("completion_details", "Failed to load dataset")) self.add_job_data("end_time", time.strftime("%Y-%m-%d %H:%M:%S")) raise @@ -454,7 +467,7 @@ def load_evaluation_model(self, field_name="generation_model", model_type=None, return self._create_local_model_wrapper(custom_model, model_name) elif model_type == "claude": - anthropic_api_key = tlab_core.get_db_config_value("ANTHROPIC_API_KEY") + anthropic_api_key = asyncio.run(tlab_core.get_db_config_value("ANTHROPIC_API_KEY")) if not anthropic_api_key or anthropic_api_key.strip() == "": raise ValueError("Please set the Anthropic API Key from Settings.") @@ -462,14 +475,14 @@ def load_evaluation_model(self, field_name="generation_model", model_type=None, return self._create_commercial_model_wrapper("claude", generation_model) elif model_type == "azure": - azure_api_details = tlab_core.get_db_config_value("AZURE_OPENAI_DETAILS") + azure_api_details = asyncio.run(tlab_core.get_db_config_value("AZURE_OPENAI_DETAILS")) if not azure_api_details or azure_api_details.strip() == "": raise ValueError("Please set the Azure OpenAI Details from Settings.") return self._create_commercial_model_wrapper("azure", "") elif model_type == "openai": - openai_api_key = tlab_core.get_db_config_value("OPENAI_API_KEY") + openai_api_key = asyncio.run(tlab_core.get_db_config_value("OPENAI_API_KEY")) if not openai_api_key or openai_api_key.strip() == "": raise ValueError("Please set the OpenAI API Key from Settings.") @@ -477,7 +490,7 @@ def load_evaluation_model(self, field_name="generation_model", model_type=None, return self._create_commercial_model_wrapper("openai", generation_model) elif model_type == "custom": - custom_api_details = tlab_core.get_db_config_value("CUSTOM_MODEL_API_KEY") + custom_api_details = asyncio.run(tlab_core.get_db_config_value("CUSTOM_MODEL_API_KEY")) if not custom_api_details or custom_api_details.strip() == "": raise ValueError("Please set the Custom API Details from Settings.") @@ -531,7 +544,7 @@ def check_local_server(self, model_server=None): def _start_worker_sync(self, model_server=None): """Start the local model server and wait for it to be ready""" # Get experiment_id from the job - experiment_id = self.job.get_experiment_id() if self.job else None + experiment_id = asyncio.run(self.job.get_experiment_id()) if self.job else None params = { "model_name": self.params.model_name, @@ -675,7 +688,7 
@@ def __init__(self, model_type="claude", model_name="claude-3-7-sonnet-latest"): if model_type == "claude": self.chat_completions_url = "https://api.anthropic.com/v1/chat/completions" self.base_url = "https://api.anthropic.com/v1" - anthropic_api_key = tlab_core.get_db_config_value("ANTHROPIC_API_KEY") + anthropic_api_key = asyncio.run(tlab_core.get_db_config_value("ANTHROPIC_API_KEY")) self.api_key = anthropic_api_key if not anthropic_api_key or anthropic_api_key.strip() == "": raise ValueError("Please set the Anthropic API Key from Settings.") @@ -683,7 +696,7 @@ def __init__(self, model_type="claude", model_name="claude-3-7-sonnet-latest"): os.environ["ANTHROPIC_API_KEY"] = anthropic_api_key self.model = Anthropic() elif model_type == "azure": - azure_api_details = tlab_core.get_db_config_value("AZURE_OPENAI_DETAILS") + azure_api_details = asyncio.run(tlab_core.get_db_config_value("AZURE_OPENAI_DETAILS")) if not azure_api_details or azure_api_details.strip() == "": raise ValueError("Please set the Azure OpenAI Details from Settings.") azure_api_details = json.loads(azure_api_details) @@ -702,7 +715,7 @@ def __init__(self, model_type="claude", model_name="claude-3-7-sonnet-latest"): elif model_type == "openai": self.chat_completions_url = "https://api.openai.com/v1/chat/completions" self.base_url = "https://api.openai.com/v1" - openai_api_key = tlab_core.get_db_config_value("OPENAI_API_KEY") + openai_api_key = asyncio.run(tlab_core.get_db_config_value("OPENAI_API_KEY")) self.api_key = openai_api_key if not openai_api_key or openai_api_key.strip() == "": raise ValueError("Please set the OpenAI API Key from Settings.") @@ -711,7 +724,7 @@ def __init__(self, model_type="claude", model_name="claude-3-7-sonnet-latest"): self.model = OpenAI() elif model_type == "custom": - custom_api_details = tlab_core.get_db_config_value("CUSTOM_MODEL_API_KEY") + custom_api_details = asyncio.run(tlab_core.get_db_config_value("CUSTOM_MODEL_API_KEY")) if not custom_api_details or custom_api_details.strip() == "": raise ValueError("Please set the Custom API Details from Settings.") diff --git a/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/train.py b/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/train.py index fa43c044b..d13ed91cc 100644 --- a/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/train.py +++ b/api/transformerlab/plugin_sdk/transformerlab/sdk/v1/train.py @@ -1,3 +1,4 @@ +import asyncio import hashlib import json import time @@ -174,8 +175,11 @@ def load_config(self): import json # Load configuration from file - with storage.open(self.params.input_file, "r", encoding="utf-8") as json_file: - input_config = json.load(json_file) + async def _load_config(): + async with await storage.open(self.params.input_file, "r", encoding="utf-8") as json_file: + return json.loads(await json_file.read()) + + input_config = asyncio.run(_load_config()) if "config" in input_config: self.params._config = input_config["config"] @@ -190,8 +194,8 @@ def load_config(self): except Exception as e: error_msg = f"Error loading configuration: {str(e)}\n{traceback.format_exc()}" print(error_msg) - self.job.update_job_data_field("completion_status", "failed") - self.job.update_job_data_field("completion_details", "Error loading configuration") + asyncio.run(self.job.update_job_data_field("completion_status", "failed")) + asyncio.run(self.job.update_job_data_field("completion_details", "Error loading configuration")) self.add_job_data("end_time", time.strftime("%Y-%m-%d %H:%M:%S")) raise @@ -224,7 +228,7 @@ def 
setup_train_logging(self, wandb_project_name: str = "TLab_Training", manual_ print("Writing tensorboard logs to:", output_dir) # Ensure directory exists - storage.makedirs(self.params.tensorboard_output_dir, exist_ok=True) + asyncio.run(storage.makedirs(self.params.tensorboard_output_dir, exist_ok=True)) self.writer = SummaryWriter(self.params.tensorboard_output_dir) @@ -268,91 +272,6 @@ def setup_train_logging(self, wandb_project_name: str = "TLab_Training", manual_ self.report_to = report_to - # def _get_system_metrics(self): - # """Collect system metrics for logging (CPU, RAM, VRAM, etc.), using select macmon metrics on macOS if available.""" - # import psutil - # import torch - # import sys - - # metrics = {} - # if sys.platform == "darwin": - # try: - # from macmon import MacMon - # import json as _json - - # macmon = MacMon() - # data = macmon.get_metrics() - # if isinstance(data, str): - # mac_metrics = _json.loads(data) - # else: - # mac_metrics = data - # mm = mac_metrics - # if "cpu_power" in mm: - # metrics["system/cpu_power"] = mm["cpu_power"] - # if "gpu_power" in mm: - # metrics["system/gpu_power"] = mm["gpu_power"] - # if "ram_power" in mm: - # metrics["system/ram_power"] = mm["ram_power"] - # if "all_power" in mm: - # metrics["system/all_power"] = mm["all_power"] - # if "sys_power" in mm: - # metrics["system/sys_power"] = mm["sys_power"] - # if "gpu_usage" in mm and isinstance(mm["gpu_usage"], list) and len(mm["gpu_usage"]) == 2: - # metrics["system/gpu_usage_id"] = mm["gpu_usage"][0] - # metrics["system/gpu_usage_percent"] = mm["gpu_usage"][1] - # if "ecpu_usage" in mm and isinstance(mm["ecpu_usage"], list) and len(mm["ecpu_usage"]) == 2: - # metrics["system/ecpu_usage_id"] = mm["ecpu_usage"][0] - # metrics["system/ecpu_usage_percent"] = mm["ecpu_usage"][1] - # if "pcpu_usage" in mm and isinstance(mm["pcpu_usage"], list) and len(mm["pcpu_usage"]) == 2: - # metrics["system/pcpu_usage_id"] = mm["pcpu_usage"][0] - # metrics["system/pcpu_usage_percent"] = mm["pcpu_usage"][1] - # if "temp" in mm and isinstance(mm["temp"], dict): - # if "cpu_temp_avg" in mm["temp"]: - # metrics["system/cpu_temp_avg"] = mm["temp"]["cpu_temp_avg"] - # if "gpu_temp_avg" in mm["temp"]: - # metrics["system/gpu_temp_avg"] = mm["temp"]["gpu_temp_avg"] - # if "memory" in mm and isinstance(mm["memory"], dict): - # if "ram_total" in mm["memory"]: - # metrics["system/ram_total"] = mm["memory"]["ram_total"] - # if "ram_usage" in mm["memory"]: - # metrics["system/ram_usage"] = mm["memory"]["ram_usage"] - # if "swap_total" in mm["memory"]: - # metrics["system/swap_total"] = mm["memory"]["swap_total"] - # if "swap_usage" in mm["memory"]: - # metrics["system/swap_usage"] = mm["memory"]["swap_usage"] - # except Exception: - # # Fallback to psutil/torch if macmon fails - # metrics["system/cpu_percent"] = psutil.cpu_percent() - # metrics["system/ram_used_mb"] = psutil.virtual_memory().used / (1024 * 1024) - # metrics["system/ram_total_mb"] = psutil.virtual_memory().total / (1024 * 1024) - # metrics["system/ram_percent"] = psutil.virtual_memory().percent - # else: - # # CPU and RAM - # metrics["system/cpu_percent"] = psutil.cpu_percent() - # metrics["system/ram_used_mb"] = psutil.virtual_memory().used / (1024 * 1024) - # metrics["system/ram_total_mb"] = psutil.virtual_memory().total / (1024 * 1024) - # metrics["system/ram_percent"] = psutil.virtual_memory().percent - - # # Device-specific metrics - # if torch.cuda.is_available(): - # try: - # import pynvml - - # pynvml.nvmlInit() - # # Get metrics for the main GPU 
- # handle = pynvml.nvmlDeviceGetHandleByIndex(0) - # meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) - # metrics["system/vram_used_mb"] = meminfo.used / (1024 * 1024) - # metrics["system/vram_total_mb"] = meminfo.total / (1024 * 1024) - # metrics["system/vram_free_mb"] = meminfo.free / (1024 * 1024) - # util = pynvml.nvmlDeviceGetUtilizationRates(handle) - # metrics["system/gpu_utilization"] = util.gpu - # except Exception: - # metrics["system/vram_used_mb"] = -1 - # metrics["system/vram_total_mb"] = -1 - # metrics["system/gpu_utilization"] = -1 - # return metrics - def log_metric(self, metric_name: str, metric_value: float, step: int = None, logging_platforms: bool = True): """Log a metric to all reporting targets""" if logging_platforms: @@ -378,13 +297,17 @@ def log_metric(self, metric_name: str, metric_value: float, step: int = None, lo try: # Ensure output_dir exists output_dir = self.params.get("output_dir", "") - if output_dir and storage.exists(output_dir): - # Save metrics to a JSON file - metrics_path = storage.join(output_dir, "metrics.json") - with storage.open(metrics_path, "w", encoding="utf-8") as f: - json.dump(self._metrics, f, indent=2) - else: - print(f"Output directory not found or not specified: {output_dir}") + + async def _save_metrics(): + if output_dir and await storage.exists(output_dir): + # Save metrics to a JSON file + metrics_path = storage.join(output_dir, "metrics.json") + async with await storage.open(metrics_path, "w", encoding="utf-8") as f: + await f.write(json.dumps(self._metrics, indent=2)) + else: + print(f"Output directory not found or not specified: {output_dir}") + + asyncio.run(_save_metrics()) except Exception as e: print(f"Error saving metrics to file: {str(e)}") @@ -431,12 +354,15 @@ def create_transformerlab_model( # The actual filename will be set by the export process, so we don't set it here # For now, if it's GGUF and the file exists, use the filename if "GGUF" in model_architecture.upper() or model_architecture.upper() == "GGUF": - if storage.exists(fused_model_location): - if storage.isfile(fused_model_location): - # File-based model - use the filename - model_filename = posixpath.basename(fused_model_location) - # If it's a directory for GGUF, keep "." (directory-based) - # This shouldn't normally happen for GGUF, but handle it gracefully + + async def _check_gguf(): + if await storage.exists(fused_model_location): + if await storage.isfile(fused_model_location): + # File-based model - use the filename + return posixpath.basename(fused_model_location) + return "." 
+ + model_filename = asyncio.run(_check_gguf()) # If GGUF file doesn't exist yet, the export process will set the filename if generate_json: @@ -461,35 +387,38 @@ def create_transformerlab_model( print(f"Provenance file created at: {provenance_file}") def create_md5_checksum_model_files(self, fused_model_location): - def compute_md5(file_path): + async def compute_md5(file_path): md5 = hashlib.md5() - with storage.open(file_path, "rb") as f: + async with await storage.open(file_path, "rb") as f: while True: - chunk = f.read(8192) + chunk = await f.read(8192) if not chunk: break md5.update(chunk) return md5.hexdigest() - md5_objects = [] + async def _create_checksums(): + md5_objects = [] - if not storage.isdir(fused_model_location): - print("Fused model location is not a directory, skipping md5 within provenance") - return md5_objects + if not await storage.isdir(fused_model_location): + print("Fused model location is not a directory, skipping md5 within provenance") + return md5_objects - # Walk directory using storage - stack = [fused_model_location] - while stack: - current_dir = stack.pop() - for entry in storage.ls(current_dir): - if storage.isdir(entry): - stack.append(entry) - else: - file_path = entry - md5_hash = compute_md5(file_path) - md5_objects.append({"file_path": file_path, "md5_hash": md5_hash}) + # Walk directory using storage + stack = [fused_model_location] + while stack: + current_dir = stack.pop() + for entry in await storage.ls(current_dir): + if await storage.isdir(entry): + stack.append(entry) + else: + file_path = entry + md5_hash = await compute_md5(file_path) + md5_objects.append({"file_path": file_path, "md5_hash": md5_hash}) + + return md5_objects - return md5_objects + return asyncio.run(_create_checksums()) def create_provenance_file(self, model_location, model_name, model_architecture, md5_objects): """Create a _tlab_provenance.json file containing model provenance data""" @@ -512,11 +441,13 @@ def create_provenance_file(self, model_location, model_name, model_architecture, } # Write provenance to file - provenance_path = storage.join(model_location, "_tlab_provenance.json") - with storage.open(provenance_path, "w", encoding="utf-8") as f: - json.dump(provenance_data, f, indent=2) + async def _write_provenance(): + provenance_path = storage.join(model_location, "_tlab_provenance.json") + async with await storage.open(provenance_path, "w", encoding="utf-8") as f: + await f.write(json.dumps(provenance_data, indent=2)) + return provenance_path - return provenance_path + return asyncio.run(_write_provenance()) # Create an instance of the TrainerTLabPlugin class diff --git a/api/transformerlab/plugins/airllm_mlx_server/index.json b/api/transformerlab/plugins/airllm_mlx_server/index.json index 874f6806b..29fb061a4 100644 --- a/api/transformerlab/plugins/airllm_mlx_server/index.json +++ b/api/transformerlab/plugins/airllm_mlx_server/index.json @@ -4,7 +4,7 @@ "description": "Performing inference using AirLLM on your Apple Silicon machine", "plugin-format": "python", "type": "loader", - "version": "0.1.8", + "version": "0.1.9", "model_architectures": [ "MLX", "CohereForCausalLM", diff --git a/api/transformerlab/plugins/autotrain_sft_trainer/index.json b/api/transformerlab/plugins/autotrain_sft_trainer/index.json index 00f69045a..6017b6528 100644 --- a/api/transformerlab/plugins/autotrain_sft_trainer/index.json +++ b/api/transformerlab/plugins/autotrain_sft_trainer/index.json @@ -4,7 +4,7 @@ "description": "SFT training using Huggingface autotrain", "plugin-format": 
"python", "type": "trainer", - "version": "0.1.15", + "version": "0.1.16", "model_architectures": [ "LlamaForCausalLM", "MistralForCausalLM", diff --git a/api/transformerlab/plugins/basic_evals/index.json b/api/transformerlab/plugins/basic_evals/index.json index 4f47d8a06..4f7a6381d 100644 --- a/api/transformerlab/plugins/basic_evals/index.json +++ b/api/transformerlab/plugins/basic_evals/index.json @@ -5,7 +5,7 @@ "description": "Evaluating outputs of LLMs using basic defined metrics", "plugin-format": "python", "type": "evaluator", - "version": "0.1.9", + "version": "0.1.10", "git": "", "url": "", "files": ["main.py", "setup.sh"], diff --git a/api/transformerlab/plugins/batched_generation_datasets/index.json b/api/transformerlab/plugins/batched_generation_datasets/index.json index 829db4f75..7716449f2 100644 --- a/api/transformerlab/plugins/batched_generation_datasets/index.json +++ b/api/transformerlab/plugins/batched_generation_datasets/index.json @@ -4,7 +4,7 @@ "description": "Use a local or commercial LLM to generated outputs for the dataset generated.", "plugin-format": "python", "type": "generator", - "version": "0.1.12", + "version": "0.1.13", "git": "", "url": "", "files": ["main.py", "setup.sh"], diff --git a/api/transformerlab/plugins/common-eleuther-ai-lm-eval-harness-mlx/index.json b/api/transformerlab/plugins/common-eleuther-ai-lm-eval-harness-mlx/index.json index 2dac4e4fa..2cebb08d1 100644 --- a/api/transformerlab/plugins/common-eleuther-ai-lm-eval-harness-mlx/index.json +++ b/api/transformerlab/plugins/common-eleuther-ai-lm-eval-harness-mlx/index.json @@ -5,7 +5,7 @@ "plugin-format": "python", "type": "evaluator", "evalsType": "model", - "version": "0.1.13", + "version": "0.1.14", "git": "https://github.com/chimezie/lm-evaluation-harness-mlx", "url": "https://github.com/chimezie/lm-evaluation-harness-mlx", "supported_hardware_architectures": ["mlx"], diff --git a/api/transformerlab/plugins/common-eleuther-ai-lm-eval-harness/index.json b/api/transformerlab/plugins/common-eleuther-ai-lm-eval-harness/index.json index 1880eb955..898843004 100644 --- a/api/transformerlab/plugins/common-eleuther-ai-lm-eval-harness/index.json +++ b/api/transformerlab/plugins/common-eleuther-ai-lm-eval-harness/index.json @@ -5,7 +5,7 @@ "plugin-format": "python", "type": "evaluator", "evalsType": "model", - "version": "0.1.16", + "version": "0.1.17", "git": "https://github.com/EleutherAI/lm-evaluation-harness", "url": "https://github.com/EleutherAI/lm-evaluation-harness", "supported_hardware_architectures": ["cuda", "amd"], diff --git a/api/transformerlab/plugins/common-eleuther-ai-lm-eval-harness/main.py b/api/transformerlab/plugins/common-eleuther-ai-lm-eval-harness/main.py index c0c7625c3..1597c78bf 100644 --- a/api/transformerlab/plugins/common-eleuther-ai-lm-eval-harness/main.py +++ b/api/transformerlab/plugins/common-eleuther-ai-lm-eval-harness/main.py @@ -2,7 +2,9 @@ import re import subprocess import json +import asyncio from lab import storage +from lab.dirs import get_workspace_dir import pandas as pd import torch @@ -88,10 +90,10 @@ def run_evaluation(): model_args = f"model={model_name},trust_remote_code=True" if tlab_evals.params.model_adapter and tlab_evals.params.model_adapter.strip() != "": - from transformerlab.plugin import WORKSPACE_DIR + workspace_dir = asyncio.run(get_workspace_dir()) adapter_path = storage.join( - WORKSPACE_DIR, + workspace_dir, "adaptors", secure_filename(tlab_evals.params.model_name), tlab_evals.params.model_adapter, @@ -115,10 +117,10 @@ def 
run_evaluation(): model_args = f"pretrained={model_name},trust_remote_code=True" if tlab_evals.params.model_adapter and tlab_evals.params.model_adapter.strip() != "": - from transformerlab.plugin import WORKSPACE_DIR + workspace_dir = asyncio.run(get_workspace_dir()) adapter_path = storage.join( - WORKSPACE_DIR, + workspace_dir, "adaptors", secure_filename(tlab_evals.params.model_name), tlab_evals.params.model_adapter, diff --git a/api/transformerlab/plugins/dataset_imagegen/index.json b/api/transformerlab/plugins/dataset_imagegen/index.json index 27df7ac27..ebec52a69 100644 --- a/api/transformerlab/plugins/dataset_imagegen/index.json +++ b/api/transformerlab/plugins/dataset_imagegen/index.json @@ -4,7 +4,7 @@ "description": "Generate images using a text-to-image model from a set of user-provided prompts.", "plugin-format": "python", "type": "generator", - "version": "0.1.7", + "version": "0.1.8", "git": "", "url": "", "files": ["main.py", "setup.sh"], diff --git a/api/transformerlab/plugins/dataset_imagegen/main.py b/api/transformerlab/plugins/dataset_imagegen/main.py index 89e550835..fc4f739bf 100644 --- a/api/transformerlab/plugins/dataset_imagegen/main.py +++ b/api/transformerlab/plugins/dataset_imagegen/main.py @@ -2,8 +2,10 @@ import shutil import requests import pandas as pd +import asyncio from tqdm import tqdm from transformerlab.sdk.v1.generate import tlab_gen +from lab.dirs import get_workspace_dir @tlab_gen.job_wrapper(progress_start=0, progress_end=100) @@ -21,9 +23,9 @@ def run_generation(): negative_column = tlab_gen.params.negative_prompt_column.strip() dataset_id = tlab_gen.params.get("output_dataset_name") - from transformerlab.plugin import WORKSPACE_DIR + workspace_dir = asyncio.run(get_workspace_dir()) - output_dir = os.path.join(WORKSPACE_DIR, "datasets", dataset_id) + output_dir = os.path.join(workspace_dir, "datasets", dataset_id) os.makedirs(output_dir, exist_ok=True) # Load dataset diff --git a/api/transformerlab/plugins/deepeval_llm_judge/index.json b/api/transformerlab/plugins/deepeval_llm_judge/index.json index db7fa4739..450d8c0aa 100644 --- a/api/transformerlab/plugins/deepeval_llm_judge/index.json +++ b/api/transformerlab/plugins/deepeval_llm_judge/index.json @@ -5,7 +5,7 @@ "plugin-format": "python", "evalsType": "dataset", "type": "evaluator", - "version": "0.2.18", + "version": "0.2.19", "git": "https://github.com/confident-ai/deepeval", "url": "https://github.com/confident-ai/deepeval", "files": ["main.py", "setup.sh"], diff --git a/api/transformerlab/plugins/deepeval_objective/index.json b/api/transformerlab/plugins/deepeval_objective/index.json index 1d0490274..bc324cbe3 100644 --- a/api/transformerlab/plugins/deepeval_objective/index.json +++ b/api/transformerlab/plugins/deepeval_objective/index.json @@ -5,7 +5,7 @@ "plugin-format": "python", "type": "evaluator", "evalsType": "dataset", - "version": "0.2.15", + "version": "0.2.16", "git": "https://github.com/confident-ai/deepeval", "url": "https://github.com/confident-ai/deepeval", "files": ["main.py", "setup.sh"], diff --git a/api/transformerlab/plugins/diffusion_trainer/index.json b/api/transformerlab/plugins/diffusion_trainer/index.json index b48cce62e..0d59be36d 100644 --- a/api/transformerlab/plugins/diffusion_trainer/index.json +++ b/api/transformerlab/plugins/diffusion_trainer/index.json @@ -4,7 +4,7 @@ "description": "A plugin for fine-tuning Stable Diffusion using LoRA adapters.", "plugin-format": "python", "type": "trainer", - "version": "0.1.9", + "version": "0.1.10", "git": "", "url": "", 
"model_architectures": [ diff --git a/api/transformerlab/plugins/diffusion_trainer/main.py b/api/transformerlab/plugins/diffusion_trainer/main.py index ee87c94d7..3510e1851 100644 --- a/api/transformerlab/plugins/diffusion_trainer/main.py +++ b/api/transformerlab/plugins/diffusion_trainer/main.py @@ -2,6 +2,7 @@ import random import json import gc +import asyncio import numpy as np import torch @@ -29,7 +30,7 @@ from lab.dirs import get_workspace_dir from lab import storage -workspace_dir = get_workspace_dir() +workspace_dir = asyncio.run(get_workspace_dir()) def cleanup_pipeline(): diff --git a/api/transformerlab/plugins/dllm_evaluator/index.json b/api/transformerlab/plugins/dllm_evaluator/index.json index b668da5ae..50a6438fc 100644 --- a/api/transformerlab/plugins/dllm_evaluator/index.json +++ b/api/transformerlab/plugins/dllm_evaluator/index.json @@ -5,7 +5,7 @@ "plugin-format": "python", "type": "evaluator", "evalsType": "model", - "version": "0.0.3", + "version": "0.0.4", "git": "https://github.com/ZHZisZZ/dllm", "url": "https://github.com/ZHZisZZ/dllm", "supported_hardware_architectures": ["cuda"], diff --git a/api/transformerlab/plugins/dllm_server/index.json b/api/transformerlab/plugins/dllm_server/index.json index 8df7a97e3..dbe0154a3 100644 --- a/api/transformerlab/plugins/dllm_server/index.json +++ b/api/transformerlab/plugins/dllm_server/index.json @@ -4,7 +4,7 @@ "description": "Inference server for diffusion language models (dLLM) - supports LLaDA, Dream, and other diffusion-based text generation models", "plugin-format": "python", "type": "loader", - "version": "0.0.3", + "version": "0.0.4", "supports": ["text_diffusion"], "model_architectures": [ "BertForMaskedLM", diff --git a/api/transformerlab/plugins/dllm_trainer_multi_gpu/index.json b/api/transformerlab/plugins/dllm_trainer_multi_gpu/index.json index 253066c4b..e61805c2c 100644 --- a/api/transformerlab/plugins/dllm_trainer_multi_gpu/index.json +++ b/api/transformerlab/plugins/dllm_trainer_multi_gpu/index.json @@ -4,7 +4,7 @@ "description": "A training plugin using dllm library for SFT training with support for bert, dream, and llada training methods", "plugin-format": "python", "type": "trainer", - "version": "0.0.3", + "version": "0.0.5", "model_architectures": [ "BertForMaskedLM", "ModernBertForMaskedLM", diff --git a/api/transformerlab/plugins/dllm_trainer_multi_gpu/main.py b/api/transformerlab/plugins/dllm_trainer_multi_gpu/main.py index 456eb159a..fa74cb279 100644 --- a/api/transformerlab/plugins/dllm_trainer_multi_gpu/main.py +++ b/api/transformerlab/plugins/dllm_trainer_multi_gpu/main.py @@ -1,9 +1,11 @@ import os import subprocess +import asyncio from functools import partial from transformerlab.sdk.v1.train import tlab_trainer -from transformerlab.plugin import WORKSPACE_DIR, get_python_executable +from transformerlab.plugin import get_python_executable +from lab.dirs import get_workspace_dir # Add custom arguments tlab_trainer.add_argument( @@ -383,7 +385,8 @@ def convert_to_messages(row): adaptor_name = tlab_trainer.params.get("adaptor_name", "default") fused_model_name = f"{model_id_short}_{adaptor_name}" - fused_model_location = os.path.join(WORKSPACE_DIR, "models", fused_model_name) + workspace_dir = asyncio.run(get_workspace_dir()) + fused_model_location = os.path.join(workspace_dir, "models", fused_model_name) trainer.save_model(output_dir=fused_model_location) if hasattr(trainer, "processing_class") and trainer.processing_class: diff --git 
a/api/transformerlab/plugins/dpo_orpo_simpo_trainer_llama_factory/index.json b/api/transformerlab/plugins/dpo_orpo_simpo_trainer_llama_factory/index.json index 8a202b725..2f6fe76e4 100644 --- a/api/transformerlab/plugins/dpo_orpo_simpo_trainer_llama_factory/index.json +++ b/api/transformerlab/plugins/dpo_orpo_simpo_trainer_llama_factory/index.json @@ -4,7 +4,7 @@ "description": "An implementation of several Preference Optimization methods using Llama Factory.", "plugin-format": "python", "type": "trainer", - "version": "0.0.16", + "version": "0.0.17", "model_architectures": [ "LlamaForCausalLM", "MistralForCausalLM", diff --git a/api/transformerlab/plugins/dpo_orpo_simpo_trainer_llama_factory/main.py b/api/transformerlab/plugins/dpo_orpo_simpo_trainer_llama_factory/main.py index 4be4451b4..902125224 100644 --- a/api/transformerlab/plugins/dpo_orpo_simpo_trainer_llama_factory/main.py +++ b/api/transformerlab/plugins/dpo_orpo_simpo_trainer_llama_factory/main.py @@ -14,6 +14,7 @@ import json import yaml import re +import asyncio from transformerlab.sdk.v1.train import tlab_trainer from transformerlab.plugin import get_python_executable @@ -33,7 +34,7 @@ @tlab_trainer.job_wrapper(progress_start=0, progress_end=100, wandb_project_name="LlamaFactory_DPO") def run_train(): # Directory for storing temporary working files - workspace_dir = get_workspace_dir() + workspace_dir = asyncio.run(get_workspace_dir()) data_directory = storage.join(workspace_dir, "temp", "llama_factory_reward", "data") storage.makedirs(data_directory, exist_ok=True) @@ -194,7 +195,7 @@ def fuse_model(): """Fuse the adapter with the base model""" print("Now fusing the adaptor with the model.") - workspace_dir = get_workspace_dir() + workspace_dir = asyncio.run(get_workspace_dir()) data_directory = storage.join(workspace_dir, "temp", "llama_factory_reward", "data") model_name = tlab_trainer.params.model_name adaptor_name = tlab_trainer.params.adaptor_name diff --git a/api/transformerlab/plugins/eleuther-ai-lm-evaluation-harness-mlx/index.json b/api/transformerlab/plugins/eleuther-ai-lm-evaluation-harness-mlx/index.json index 0ba4c0fa6..de10f950b 100644 --- a/api/transformerlab/plugins/eleuther-ai-lm-evaluation-harness-mlx/index.json +++ b/api/transformerlab/plugins/eleuther-ai-lm-evaluation-harness-mlx/index.json @@ -4,7 +4,7 @@ "description": "A harness for evaluating language models using MLX.", "plugin-format": "python", "type": "evaluator", - "version": "0.2.14", + "version": "0.2.15", "git": "https://github.com/chimezie/lm-evaluation-harness-mlx", "url": "https://github.com/chimezie/lm-evaluation-harness-mlx", "supported_hardware_architectures": ["mlx"], diff --git a/api/transformerlab/plugins/eleuther-ai-lm-evaluation-harness/index.json b/api/transformerlab/plugins/eleuther-ai-lm-evaluation-harness/index.json index 2325aa9e2..7188d6aee 100644 --- a/api/transformerlab/plugins/eleuther-ai-lm-evaluation-harness/index.json +++ b/api/transformerlab/plugins/eleuther-ai-lm-evaluation-harness/index.json @@ -4,7 +4,7 @@ "description": "A harness for evaluating language models.", "plugin-format": "python", "type": "evaluator", - "version": "0.2.17", + "version": "0.2.18", "git": "https://github.com/EleutherAI/lm-evaluation-harness", "url": "https://github.com/EleutherAI/lm-evaluation-harness", "supported_hardware_architectures": ["cuda", "amd"], diff --git a/api/transformerlab/plugins/eleuther-ai-lm-evaluation-harness/main.py b/api/transformerlab/plugins/eleuther-ai-lm-evaluation-harness/main.py index eb455dfbc..fc68a98d5 100644 
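The hunks above and below follow the same migration pattern: synchronous plugin code now reaches the async lab.dirs / lab.storage helpers by driving a short-lived event loop with asyncio.run(). A minimal illustrative sketch of that bridging pattern, using stub coroutines in place of the real lab helpers (the names and paths below are placeholders for illustration, not part of this diff):

import asyncio
from typing import Optional

# Stubs standing in for the async helpers this diff migrates to
# (lab.dirs.get_workspace_dir and lab.storage.exists); illustration only.
async def get_workspace_dir() -> str:
    return "/workspace"

async def exists(path: str) -> bool:
    return path.startswith("/workspace")

def resolve_adapter_path(model_name: str, adapter: str) -> Optional[str]:
    """Synchronous plugin entry point bridging into async helpers,
    mirroring the asyncio.run(get_workspace_dir()) calls in this diff."""
    async def _resolve() -> Optional[str]:
        workspace_dir = await get_workspace_dir()
        candidate = f"{workspace_dir}/adaptors/{model_name}/{adapter}"
        return candidate if await exists(candidate) else None

    # asyncio.run() creates and closes a fresh event loop for this call;
    # it must not be invoked from code already running inside a loop.
    return asyncio.run(_resolve())

if __name__ == "__main__":
    print(resolve_adapter_path("example-model", "example-adapter"))
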
--- a/api/transformerlab/plugins/eleuther-ai-lm-evaluation-harness/main.py +++ b/api/transformerlab/plugins/eleuther-ai-lm-evaluation-harness/main.py @@ -2,7 +2,9 @@ import re import subprocess import json +import asyncio from lab import storage +from lab.dirs import get_workspace_dir import pandas as pd import torch @@ -87,10 +89,10 @@ def run_evaluation(): model_args = f"model={model_name},trust_remote_code=True" if tlab_evals.params.model_adapter and tlab_evals.params.model_adapter.strip() != "": - from transformerlab.plugin import WORKSPACE_DIR + workspace_dir = asyncio.run(get_workspace_dir()) adapter_path = storage.join( - WORKSPACE_DIR, + workspace_dir, "adaptors", secure_filename(tlab_evals.params.model_name), tlab_evals.params.model_adapter, @@ -114,10 +116,10 @@ def run_evaluation(): model_args = f"pretrained={model_name},trust_remote_code=True" if tlab_evals.params.model_adapter and tlab_evals.params.model_adapter.strip() != "": - from transformerlab.plugin import WORKSPACE_DIR + workspace_dir = asyncio.run(get_workspace_dir()) adapter_path = storage.join( - WORKSPACE_DIR, + workspace_dir, "adaptors", secure_filename(tlab_evals.params.model_name), tlab_evals.params.model_adapter, diff --git a/api/transformerlab/plugins/embedding_model_trainer/index.json b/api/transformerlab/plugins/embedding_model_trainer/index.json index 9a35b6dd6..4f8b93521 100644 --- a/api/transformerlab/plugins/embedding_model_trainer/index.json +++ b/api/transformerlab/plugins/embedding_model_trainer/index.json @@ -5,7 +5,7 @@ "plugin-format": "python", "type": "trainer", "train_type": "embedding", - "version": "0.0.11", + "version": "0.0.12", "model_architectures": [ "BertModel", "SentenceTransformer", diff --git a/api/transformerlab/plugins/embedding_model_trainer/main.py b/api/transformerlab/plugins/embedding_model_trainer/main.py index 92b990f03..7d44500e9 100644 --- a/api/transformerlab/plugins/embedding_model_trainer/main.py +++ b/api/transformerlab/plugins/embedding_model_trainer/main.py @@ -1,6 +1,7 @@ import os import random import torch +import asyncio from datasets import Dataset from sentence_transformers import ( @@ -144,7 +145,7 @@ def train_embedding_model(): final_model_name = f"{template_name}_{job_id}" # Define output directory - workspace_dir = get_workspace_dir() + workspace_dir = asyncio.run(get_workspace_dir()) output_dir = storage.join(workspace_dir, "models", final_model_name) storage.makedirs(output_dir, exist_ok=True) diff --git a/api/transformerlab/plugins/fastchat_server/index.json b/api/transformerlab/plugins/fastchat_server/index.json index b949fbb35..fbe7550f4 100644 --- a/api/transformerlab/plugins/fastchat_server/index.json +++ b/api/transformerlab/plugins/fastchat_server/index.json @@ -4,7 +4,7 @@ "description": "Fastchat loads models for inference using Huggingface Transformers for generation.", "plugin-format": "python", "type": "loader", - "version": "1.0.30", + "version": "1.0.31", "supports": [ "chat", "completion", diff --git a/api/transformerlab/plugins/fastchat_vision_server/index.json b/api/transformerlab/plugins/fastchat_vision_server/index.json index 77a332076..1178feda9 100644 --- a/api/transformerlab/plugins/fastchat_vision_server/index.json +++ b/api/transformerlab/plugins/fastchat_vision_server/index.json @@ -4,7 +4,7 @@ "description": "Fastchat loads vision models for inference using Huggingface Transformers for generation.", "plugin-format": "python", "type": "loader", - "version": "1.0.9", + "version": "1.0.10", "model_architectures": [ 
"LlavaForConditionalGeneration", "Mistral3ForConditionalGeneration" diff --git a/api/transformerlab/plugins/generate_rag_outputs/index.json b/api/transformerlab/plugins/generate_rag_outputs/index.json index a567a66c9..ad5fdab75 100644 --- a/api/transformerlab/plugins/generate_rag_outputs/index.json +++ b/api/transformerlab/plugins/generate_rag_outputs/index.json @@ -4,7 +4,7 @@ "description": "Run RAG on dataset queries and generate evaluation results.", "plugin-format": "python", "type": "generator", - "version": "0.1.7", + "version": "0.1.8", "git": "https://github.com/transformerlab/transformerlab", "url": "https://github.com/transformerlab/transformerlab", "files": ["main.py", "setup.sh"], diff --git a/api/transformerlab/plugins/gguf_exporter/index.json b/api/transformerlab/plugins/gguf_exporter/index.json index aeb74d62f..463c68228 100644 --- a/api/transformerlab/plugins/gguf_exporter/index.json +++ b/api/transformerlab/plugins/gguf_exporter/index.json @@ -4,7 +4,7 @@ "description": "Exports the current model to GGUF format so it can be run on computers without a GPU.", "plugin-format": "python", "type": "exporter", - "version": "0.2.8", + "version": "0.2.9", "model_architectures": [ "CohereForCausalLM", "FalconForCausalLM", diff --git a/api/transformerlab/plugins/gguf_exporter/main.py b/api/transformerlab/plugins/gguf_exporter/main.py index 764e73496..175307ed8 100644 --- a/api/transformerlab/plugins/gguf_exporter/main.py +++ b/api/transformerlab/plugins/gguf_exporter/main.py @@ -1,4 +1,5 @@ # This plugin exports a model to GGUF format so you can interact and train on a MBP with Apple Silicon +import asyncio import os import subprocess import contextlib @@ -30,7 +31,7 @@ def gguf_export(): output_dir = tlab_exporter.params.get("output_dir") # Create output file - storage.makedirs(output_dir, exist_ok=True) + asyncio.run(storage.makedirs(output_dir, exist_ok=True)) plugin_dir = os.path.realpath(os.path.dirname(__file__)) python_executable = get_python_executable(plugin_dir) diff --git a/api/transformerlab/plugins/grpo_trainer_multi_gpu/index.json b/api/transformerlab/plugins/grpo_trainer_multi_gpu/index.json index 4bf23230d..c75493603 100644 --- a/api/transformerlab/plugins/grpo_trainer_multi_gpu/index.json +++ b/api/transformerlab/plugins/grpo_trainer_multi_gpu/index.json @@ -4,7 +4,7 @@ "description": "A GPRO trainer based using a Multi GPU setup trained using TRL and Accelerate", "plugin-format": "python", "type": "trainer", - "version": "0.0.12", + "version": "0.0.13", "model_architectures": [ "LlamaForCausalLM", "Qwen2ForCausalLM", diff --git a/api/transformerlab/plugins/grpo_trainer_multi_gpu/main.py b/api/transformerlab/plugins/grpo_trainer_multi_gpu/main.py index 1276f4ff8..6383c570a 100644 --- a/api/transformerlab/plugins/grpo_trainer_multi_gpu/main.py +++ b/api/transformerlab/plugins/grpo_trainer_multi_gpu/main.py @@ -2,6 +2,7 @@ import time import re import subprocess +import asyncio from transformerlab.sdk.v1.train import tlab_trainer from transformerlab.plugin import get_python_executable @@ -133,7 +134,7 @@ def train_model(): adam_beta1 = float(tlab_trainer.params.adam_beta1) adam_beta2 = float(tlab_trainer.params.adam_beta2) adam_epsilon = float(tlab_trainer.params.adam_epsilon) - workspace_dir = get_workspace_dir() + workspace_dir = asyncio.run(get_workspace_dir()) # Use storage for workspace paths if relative; allow absolute passthrough output_dir = ( tlab_trainer.params.output_dir diff --git a/api/transformerlab/plugins/image_diffusion/index.json 
b/api/transformerlab/plugins/image_diffusion/index.json index 1065c9e4e..3b79f7b1f 100644 --- a/api/transformerlab/plugins/image_diffusion/index.json +++ b/api/transformerlab/plugins/image_diffusion/index.json @@ -4,7 +4,7 @@ "description": "Generate images in the Diffusion tab using this plugin", "plugin-format": "python", "type": "diffusion", - "version": "0.0.7", + "version": "0.0.8", "git": "", "url": "", "files": ["main.py", "diffusion_worker.py", "setup.sh"], diff --git a/api/transformerlab/plugins/image_diffusion/main.py b/api/transformerlab/plugins/image_diffusion/main.py index ccfe639ca..92af3d6dc 100644 --- a/api/transformerlab/plugins/image_diffusion/main.py +++ b/api/transformerlab/plugins/image_diffusion/main.py @@ -420,14 +420,14 @@ def get_pipeline( # Load LoRA adaptor if provided - same code for local and HF Hub! if adaptor and adaptor.strip(): try: - workspace_dir = get_workspace_dir() + workspace_dir = asyncio.run(get_workspace_dir()) adaptor_dir = storage.join( workspace_dir, "adaptors", secure_filename(model), ) adaptor_path = storage.join(adaptor_dir, secure_filename(adaptor)) - if storage.exists(adaptor_path): + if asyncio.run(storage.exists(adaptor_path)): pipe.load_lora_weights(adaptor_path) # if not isinstance(pipe, StableDiffusionXLPipeline): # pipe.load_lora_weights(adaptor_path) @@ -560,7 +560,7 @@ def get_python_executable(): def get_diffusion_dir(experiment_name: str = None): """Get the diffusion directory path""" - workspace_dir = get_workspace_dir() + workspace_dir = asyncio.run(get_workspace_dir()) if experiment_name is not None: return storage.join(workspace_dir, "experiments", experiment_name, "diffusion") else: @@ -584,11 +584,15 @@ def ensure_directories(experiment_name: str = None): images_dir = get_images_dir(experiment_name) history_file_path = get_history_file_path(experiment_name) - storage.makedirs(diffusion_dir, exist_ok=True) - storage.makedirs(images_dir, exist_ok=True) - if not storage.exists(history_file_path): - with storage.open(history_file_path, "a"): - pass + asyncio.run(storage.makedirs(diffusion_dir, exist_ok=True)) + asyncio.run(storage.makedirs(images_dir, exist_ok=True)) + if not asyncio.run(storage.exists(history_file_path)): + + async def _create_file(): + async with await storage.open(history_file_path, "a"): + pass + + asyncio.run(_create_file()) def save_to_history(item: ImageHistoryItem, experiment_name: str = None): @@ -596,21 +600,24 @@ def save_to_history(item: ImageHistoryItem, experiment_name: str = None): ensure_directories(experiment_name) history_file = get_history_file_path(experiment_name) - # Load existing history - history = [] - if storage.exists(history_file): - try: - with storage.open(history_file, "r") as f: - history = json.load(f) - except (json.JSONDecodeError, FileNotFoundError): - history = [] + async def _save(): + # Load existing history + history = [] + if await storage.exists(history_file): + try: + async with await storage.open(history_file, "r") as f: + history = json.loads(await f.read()) + except (json.JSONDecodeError, FileNotFoundError): + history = [] + + # Add new item to the beginning of the list + history.insert(0, item.model_dump()) - # Add new item to the beginning of the list - history.insert(0, item.model_dump()) + # Save updated history + async with await storage.open(history_file, "w") as f: + await f.write(json.dumps(history, indent=2)) - # Save updated history - with storage.open(history_file, "w") as f: - json.dump(history, f, indent=2) + asyncio.run(_save()) def 
should_use_diffusion_worker(model) -> bool: @@ -684,8 +691,12 @@ async def run_multi_gpu_generation( # Save config to temporary file ensure_directories(experiment_name) config_path = storage.join(get_diffusion_dir(experiment_name), secure_filename(f"config_{generation_id}.json")) - with storage.open(config_path, "w") as f: - json.dump(config, f, indent=2) + + async def _save_config(): + async with await storage.open(config_path, "w") as f: + await f.write(json.dumps(config, indent=2)) + + await _save_config() # Get worker script path # current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -749,10 +760,10 @@ async def run_multi_gpu_generation( if return_code == -9 or "CUDA out of memory" in combined_output or "OutOfMemoryError" in combined_output: # Try to load any partial result to get OOM details result_path = os.path.join(images_folder, "result.json") - if storage.exists(result_path): + if await storage.exists(result_path): try: - with storage.open(result_path, "r") as f: - worker_result = json.load(f) + async with await storage.open(result_path, "r") as f: + worker_result = json.loads(await f.read()) if worker_result.get("error_type") == "OOM": oom_suggestions = worker_result.get("suggestions", []) suggestion_text = "\n".join([f" • {s}" for s in oom_suggestions]) @@ -769,11 +780,11 @@ async def run_multi_gpu_generation( # Load result from worker result_path = os.path.join(images_folder, "result.json") - if not storage.exists(result_path): + if not await storage.exists(result_path): raise RuntimeError("Worker did not produce result file") - with storage.open(result_path, "r") as f: - worker_result = json.load(f) + async with await storage.open(result_path, "r") as f: + worker_result = json.loads(await f.read()) if not worker_result.get("success", False): error_msg = worker_result.get("error", "Unknown error") @@ -788,8 +799,8 @@ async def run_multi_gpu_generation( # Clean up config file try: - if storage.exists(config_path): - storage.rm(config_path) + if await storage.exists(config_path): + await storage.rm(config_path) except Exception: pass @@ -806,8 +817,8 @@ async def run_multi_gpu_generation( except Exception as e: # Clean up config file on error try: - if storage.exists(config_path): - storage.rm(config_path) + if await storage.exists(config_path): + await storage.rm(config_path) except Exception: pass raise e @@ -873,7 +884,7 @@ async def diffusion_generate_job(): images_folder = os.path.normpath(os.path.join(get_images_dir(experiment_name), generation_id)) if not images_folder.startswith(get_images_dir(experiment_name)): raise HTTPException(status_code=400, detail="Invalid path for images_folder") - storage.makedirs(images_folder, exist_ok=True) + await storage.makedirs(images_folder, exist_ok=True) # Determine pipeline type based on flags and provided images controlnet_id = request.is_controlnet or "off" @@ -1273,8 +1284,8 @@ def run_upscale(): } output_path = os.path.join(images_folder, "tmp_json.json") - with storage.open(output_path, "w") as f: - json.dump(output_data, f, indent=2) + async with await storage.open(output_path, "w") as f: + await f.write(json.dumps(output_data, indent=2)) tlab_diffusion.progress_update(100) diff --git a/api/transformerlab/plugins/inference_evals/index.json b/api/transformerlab/plugins/inference_evals/index.json index 07d059d25..8c227485f 100644 --- a/api/transformerlab/plugins/inference_evals/index.json +++ b/api/transformerlab/plugins/inference_evals/index.json @@ -5,7 +5,7 @@ "plugin-format": "python", "type": "evaluator", "evalsType": 
"model", - "version": "0.1.10", + "version": "0.1.11", "git": "", "url": "", "files": ["main.py", "setup.sh"], diff --git a/api/transformerlab/plugins/llama_cpp_server/index.json b/api/transformerlab/plugins/llama_cpp_server/index.json index b8a887e13..97c06d05f 100644 --- a/api/transformerlab/plugins/llama_cpp_server/index.json +++ b/api/transformerlab/plugins/llama_cpp_server/index.json @@ -4,7 +4,7 @@ "description": "Runs llama-cpp-python server that can run GGUF models that work well on CPU only machines.", "plugin-format": "python", "type": "loader", - "version": "0.1.17", + "version": "0.1.18", "model_architectures": ["GGUF"], "supported_hardware_architectures": ["cpu", "cuda", "mlx", "amd"], "files": ["main.py", "setup.sh"], diff --git a/api/transformerlab/plugins/llama_trainer/index.json b/api/transformerlab/plugins/llama_trainer/index.json index 1b649c18a..ce61c194b 100644 --- a/api/transformerlab/plugins/llama_trainer/index.json +++ b/api/transformerlab/plugins/llama_trainer/index.json @@ -4,7 +4,7 @@ "description": "A training script adapted from https://www.philschmid.de/instruction-tune-llama-2 for training Llama2 using PeFT", "plugin-format": "python", "type": "trainer", - "version": "1.0.40", + "version": "1.0.41", "model_architectures": [ "LlamaForCausalLM", "Qwen2ForCausalLM", diff --git a/api/transformerlab/plugins/llama_trainer/main.py b/api/transformerlab/plugins/llama_trainer/main.py index 75bff4ed3..9a1013816 100644 --- a/api/transformerlab/plugins/llama_trainer/main.py +++ b/api/transformerlab/plugins/llama_trainer/main.py @@ -1,6 +1,7 @@ import time import os from random import randrange +import asyncio import torch import shutil @@ -338,7 +339,8 @@ def train_model(): model_id = model_id.split("/")[-1] adaptor_name = tlab_trainer.params.get("adaptor_name", "default") fused_model_name = f"{model_id}_{adaptor_name}" - fused_model_location = storage.join(get_workspace_dir(), "models", fused_model_name) + workspace_dir = asyncio.run(get_workspace_dir()) + fused_model_location = storage.join(workspace_dir, "models", fused_model_name) peft_model = PeftModel.from_pretrained(model, tlab_trainer.params.adaptor_output_dir) merged_model = peft_model.merge_and_unload() merged_model.save_pretrained(fused_model_location) diff --git a/api/transformerlab/plugins/llama_trainer_multi_gpu/index.json b/api/transformerlab/plugins/llama_trainer_multi_gpu/index.json index a32af35bd..b33ed8b1e 100644 --- a/api/transformerlab/plugins/llama_trainer_multi_gpu/index.json +++ b/api/transformerlab/plugins/llama_trainer_multi_gpu/index.json @@ -4,7 +4,7 @@ "description": "A training script adapted from https://www.philschmid.de/instruction-tune-llama-2 for training Llama2 using PeFT", "plugin-format": "python", "type": "trainer", - "version": "0.1.23", + "version": "0.1.24", "model_architectures": [ "LlamaForCausalLM", "Qwen2ForCausalLM", diff --git a/api/transformerlab/plugins/llama_trainer_multi_gpu/main.py b/api/transformerlab/plugins/llama_trainer_multi_gpu/main.py index 92c292508..19035b429 100644 --- a/api/transformerlab/plugins/llama_trainer_multi_gpu/main.py +++ b/api/transformerlab/plugins/llama_trainer_multi_gpu/main.py @@ -4,6 +4,7 @@ from random import randrange import torch.nn as nn from functools import partial +import asyncio from transformerlab.sdk.v1.train import tlab_trainer @@ -292,7 +293,8 @@ def train_model(): model_id = model_id.split("/")[-1] adaptor_name = tlab_trainer.params.get("adaptor_name", "default") fused_model_name = f"{model_id}_{adaptor_name}" - 
fused_model_location = storage.join(get_workspace_dir(), "models", fused_model_name) + workspace_dir = asyncio.run(get_workspace_dir()) + fused_model_location = storage.join(workspace_dir, "models", fused_model_name) peft_model = PeftModel.from_pretrained(model, tlab_trainer.params.adaptor_output_dir) merged_model = peft_model.merge_and_unload() merged_model.save_pretrained(fused_model_location) diff --git a/api/transformerlab/plugins/llamafile_exporter/index.json b/api/transformerlab/plugins/llamafile_exporter/index.json index 6aeb05b59..990632069 100644 --- a/api/transformerlab/plugins/llamafile_exporter/index.json +++ b/api/transformerlab/plugins/llamafile_exporter/index.json @@ -4,7 +4,7 @@ "description": "Exports the current model to a fully contained self-executing llamafile.", "plugin-format": "python", "type": "exporter", - "version": "0.1.10", + "version": "0.1.11", "model_architectures": ["GGUF"], "supported_hardware_architectures": ["cpu", "cuda", "mlx", "amd"], "export_architecture": "llamafile", diff --git a/api/transformerlab/plugins/llamaindex_simple_document_search/index.json b/api/transformerlab/plugins/llamaindex_simple_document_search/index.json index 9d43c559d..29c485ebe 100644 --- a/api/transformerlab/plugins/llamaindex_simple_document_search/index.json +++ b/api/transformerlab/plugins/llamaindex_simple_document_search/index.json @@ -4,7 +4,7 @@ "description": "", "plugin-format": "python", "type": "rag", - "version": "0.0.17", + "version": "0.0.18", "model_architectures": [], "supported_hardware_architectures": ["cpu", "cuda", "mlx", "amd"], "files": ["main.py", "setup.sh"], diff --git a/api/transformerlab/plugins/mlx_audio_server/index.json b/api/transformerlab/plugins/mlx_audio_server/index.json index 667af8e8a..66b666599 100644 --- a/api/transformerlab/plugins/mlx_audio_server/index.json +++ b/api/transformerlab/plugins/mlx_audio_server/index.json @@ -4,7 +4,7 @@ "description": "A text-to-speech (TTS), speech-to-text(STT) library built on Apple's MLX framework, providing efficient speech synthesis on Apple Silicon.", "plugin-format": "python", "type": "loader", - "version": "0.1.5", + "version": "0.1.6", "supports": ["Text-to-Speech", "Audio", "Speech-to-Text"], "model_architectures": ["MLXTextToSpeech", "StyleTTS2", "MLXSpeechToText"], "supported_hardware_architectures": ["mlx"], diff --git a/api/transformerlab/plugins/mlx_audio_server/main.py b/api/transformerlab/plugins/mlx_audio_server/main.py index 994da760a..39afbaa96 100644 --- a/api/transformerlab/plugins/mlx_audio_server/main.py +++ b/api/transformerlab/plugins/mlx_audio_server/main.py @@ -89,7 +89,7 @@ async def generate(self, params): lang_code = params.get("lang_code", None) stream = params.get("stream", False) - experiment_dir = get_experiments_dir() + experiment_dir = await get_experiments_dir() audio_dir_name = secure_filename(params.get("audio_dir", "audio")) audio_dir = storage.join(experiment_dir, audio_dir_name) storage.makedirs(name=audio_dir, exist_ok=True) @@ -150,7 +150,7 @@ async def generate(self, params): model = params.get("model", None) format = params.get("format", "txt") output_path_name = secure_filename(params.get("output_path", "transcriptions")) - transcriptions_dir = storage.join(get_workspace_dir(), output_path_name) + transcriptions_dir = storage.join(await get_workspace_dir(), output_path_name) storage.makedirs(name=transcriptions_dir, exist_ok=True) # Generate a UUID for this file name: diff --git a/api/transformerlab/plugins/mlx_exporter/index.json 
b/api/transformerlab/plugins/mlx_exporter/index.json index 74bf745b2..89b438011 100644 --- a/api/transformerlab/plugins/mlx_exporter/index.json +++ b/api/transformerlab/plugins/mlx_exporter/index.json @@ -4,7 +4,7 @@ "description": "Exports the current model to MLX format so it can be run on Apple Silicon.", "plugin-format": "python", "type": "exporter", - "version": "1.0.26", + "version": "1.0.27", "model_architectures": [ "CohereForCausalLM", "DeepseekV2ForCausalLM", diff --git a/api/transformerlab/plugins/mlx_lora_trainer/index.json b/api/transformerlab/plugins/mlx_lora_trainer/index.json index 113c590ad..165ea9695 100644 --- a/api/transformerlab/plugins/mlx_lora_trainer/index.json +++ b/api/transformerlab/plugins/mlx_lora_trainer/index.json @@ -4,7 +4,7 @@ "description": "MLX Machine learning research on your laptop or in a data center - by Apple", "plugin-format": "python", "type": "trainer", - "version": "0.4.27", + "version": "0.4.28", "model_architectures": [ "MLX", "CohereForCausalLM", diff --git a/api/transformerlab/plugins/mlx_lora_trainer/main.py b/api/transformerlab/plugins/mlx_lora_trainer/main.py index 2eabc0130..d1f5e1088 100644 --- a/api/transformerlab/plugins/mlx_lora_trainer/main.py +++ b/api/transformerlab/plugins/mlx_lora_trainer/main.py @@ -9,6 +9,7 @@ import subprocess import os import time +import asyncio # Import tlab_trainer from the SDK @@ -91,10 +92,10 @@ def train_mlx_lora(): print(lora_config) # Directory for storing temporary working files - workspace_dir = get_workspace_dir() + workspace_dir = asyncio.run(get_workspace_dir()) data_directory = storage.join(workspace_dir, "plugins", "mlx_lora_trainer", "data") - if not storage.exists(data_directory): - storage.makedirs(data_directory) + if not asyncio.run(storage.exists(data_directory)): + asyncio.run(storage.makedirs(data_directory)) prepare_dataset_files( data_directory=data_directory, @@ -108,11 +109,11 @@ def train_mlx_lora(): # Set output directory for the adaptor adaptor_output_dir = tlab_trainer.params.get("adaptor_output_dir", "") if adaptor_output_dir == "" or adaptor_output_dir is None: - workspace_dir = get_workspace_dir() + workspace_dir = asyncio.run(get_workspace_dir()) adaptor_output_dir = storage.join(workspace_dir, "adaptors", tlab_trainer.params.model_name, adaptor_name) print("Using default adaptor output directory:", adaptor_output_dir) - if not storage.exists(adaptor_output_dir): - storage.makedirs(adaptor_output_dir) + if not asyncio.run(storage.exists(adaptor_output_dir)): + asyncio.run(storage.makedirs(adaptor_output_dir)) # Get Python executable (from venv if available) python_executable = get_python_executable(plugin_dir) @@ -230,12 +231,12 @@ def train_mlx_lora(): if "/" in model_name: model_name = model_name.split("/")[-1] fused_model_name = f"{model_name}_{adaptor_name}" - workspace_dir = get_workspace_dir() + workspace_dir = asyncio.run(get_workspace_dir()) fused_model_location = storage.join(workspace_dir, "models", fused_model_name) # Make the directory to save the fused model - if not storage.exists(fused_model_location): - storage.makedirs(fused_model_location) + if not asyncio.run(storage.exists(fused_model_location)): + asyncio.run(storage.makedirs(fused_model_location)) fuse_popen_command = [ python_executable, diff --git a/api/transformerlab/plugins/mlx_rlaif_trainer/index.json b/api/transformerlab/plugins/mlx_rlaif_trainer/index.json index 6a596cd01..766ccd1d4 100644 --- a/api/transformerlab/plugins/mlx_rlaif_trainer/index.json +++ 
b/api/transformerlab/plugins/mlx_rlaif_trainer/index.json @@ -4,7 +4,7 @@ "description": "MLX PPO (Proximal Policy Optimization) Reinforcement Learning from AI Feedback (RLAIF) trainer for MLX models.", "plugin-format": "python", "type": "trainer", - "version": "0.1.10", + "version": "0.1.11", "model_architectures": [ "MLX", "LlamaForCausalLM", diff --git a/api/transformerlab/plugins/mlx_rlaif_trainer/main.py b/api/transformerlab/plugins/mlx_rlaif_trainer/main.py index c37b13c51..82ce00429 100644 --- a/api/transformerlab/plugins/mlx_rlaif_trainer/main.py +++ b/api/transformerlab/plugins/mlx_rlaif_trainer/main.py @@ -13,6 +13,7 @@ from transformerlab.plugin import get_python_executable from lab.dirs import get_workspace_dir from lab import storage +import asyncio @tlab_trainer.job_wrapper(wandb_project_name="TLab_RLAIF", manual_logging=True) @@ -37,7 +38,7 @@ def train_mlx_rlaif(): datasets = tlab_trainer.load_dataset(["train"]) # Directory for storing temporary working files - workspace_dir = get_workspace_dir() + workspace_dir = asyncio.run(get_workspace_dir()) data_directory = storage.join(workspace_dir, "plugins", "mlx_rlaif_trainer", "data") if not storage.exists(data_directory): storage.makedirs(data_directory) diff --git a/api/transformerlab/plugins/mlx_server/index.json b/api/transformerlab/plugins/mlx_server/index.json index 2aa87b0d0..a5ec2f171 100644 --- a/api/transformerlab/plugins/mlx_server/index.json +++ b/api/transformerlab/plugins/mlx_server/index.json @@ -4,7 +4,7 @@ "description": "MLX Machine learning research on your laptop or in a data center - by Apple", "plugin-format": "python", "type": "loader", - "version": "0.1.47", + "version": "0.1.48", "supports": [ "chat", "completion", diff --git a/api/transformerlab/plugins/mlx_vlm_server/index.json b/api/transformerlab/plugins/mlx_vlm_server/index.json index 3087b2fb1..68c988782 100644 --- a/api/transformerlab/plugins/mlx_vlm_server/index.json +++ b/api/transformerlab/plugins/mlx_vlm_server/index.json @@ -4,7 +4,7 @@ "description": "MLX Machine learning multimodal research on your laptop or in a data center - by Blaizzy", "plugin-format": "python", "type": "loader", - "version": "0.1.28", + "version": "0.1.29", "model_architectures": ["LlavaForConditionalGeneration"], "supported_hardware_architectures": ["mlx"], "files": ["main.py", "setup.sh"], diff --git a/api/transformerlab/plugins/nanotron_pretrainer/index.json b/api/transformerlab/plugins/nanotron_pretrainer/index.json index f838c2593..788c16c7b 100644 --- a/api/transformerlab/plugins/nanotron_pretrainer/index.json +++ b/api/transformerlab/plugins/nanotron_pretrainer/index.json @@ -5,7 +5,7 @@ "plugin-format": "python", "type": "trainer", "train_type": "pretraining", - "version": "0.1.8", + "version": "0.1.9", "git": "https://github.com/huggingface/nanotron", "url": "https://github.com/huggingface/nanotron", "files": ["main.py", "setup.sh", "config.py"], diff --git a/api/transformerlab/plugins/nanotron_pretrainer/main.py b/api/transformerlab/plugins/nanotron_pretrainer/main.py index b1f9c5904..1c70c7ee0 100644 --- a/api/transformerlab/plugins/nanotron_pretrainer/main.py +++ b/api/transformerlab/plugins/nanotron_pretrainer/main.py @@ -3,10 +3,11 @@ import yaml import re import torch +import asyncio from typing import Dict, Tuple - from transformerlab.sdk.v1.train import tlab_trainer +from lab.dirs import get_workspace_dir def get_gpu_count(): @@ -33,9 +34,9 @@ def generate_nanotron_config(): """ # Format the run name with job ID run_name = 
tlab_trainer.params.get("template_name", "nanotron_run") + "_" + str(tlab_trainer.params.job_id) - from transformerlab.plugin import WORKSPACE_DIR + workspace_dir = asyncio.run(get_workspace_dir()) - checkpoint_path = os.path.join(WORKSPACE_DIR, "models", "pretrained", run_name, "checkpoints") + checkpoint_path = os.path.join(workspace_dir, "models", "pretrained", run_name, "checkpoints") MODEL_SIZES: Dict[str, Tuple[int, int, int, int, int]] = { # (layers, hidden, heads, kv_heads, ffn_size) @@ -206,9 +207,9 @@ def train_model(): run_name = tlab_trainer.params.get("template_name", "nanotron_run") + "_" + str(tlab_trainer.params.job_id) # Create output directories - from transformerlab.plugin import WORKSPACE_DIR + workspace_dir = asyncio.run(get_workspace_dir()) - output_path = os.path.join(WORKSPACE_DIR, "models", "pretrained", run_name, "nanotron_config_files") + output_path = os.path.join(workspace_dir, "models", "pretrained", run_name, "nanotron_config_files") os.makedirs(output_path, exist_ok=True) # Save the configuration to a YAML file config_path = os.path.join(output_path, f"{run_name}.yaml") @@ -225,9 +226,9 @@ def train_model(): # Get GPU count num_gpus = get_gpu_count() # Create run_train.py script - from transformerlab.plugin import WORKSPACE_DIR + workspace_dir = asyncio.run(get_workspace_dir()) - run_train_path = os.path.join(WORKSPACE_DIR, "plugins", "nanotron_pretrainer", "nanotron", "run_train.py") + run_train_path = os.path.join(workspace_dir, "plugins", "nanotron_pretrainer", "nanotron", "run_train.py") # Run training with torchrun env = os.environ.copy() @@ -307,23 +308,23 @@ def train_model(): tlab_trainer.progress_update(100) # Convert Nanotron checkpoint to HF format - from transformerlab.plugin import WORKSPACE_DIR + workspace_dir = asyncio.run(get_workspace_dir()) - checkpoint_path = os.path.join(WORKSPACE_DIR, "models", "pretrained", run_name, "checkpoints") + checkpoint_path = os.path.join(workspace_dir, "models", "pretrained", run_name, "checkpoints") try: with open(os.path.join(checkpoint_path, "latest.txt"), "r") as f: latest_checkpoint = f.read().strip() - from transformerlab.plugin import WORKSPACE_DIR + workspace_dir = asyncio.run(get_workspace_dir()) - save_path = os.path.join(WORKSPACE_DIR, "models", run_name) + save_path = os.path.join(workspace_dir, "models", run_name) latest_checkpoint_path = os.path.join(checkpoint_path, latest_checkpoint) print("Latest checkpoint path:", latest_checkpoint_path) print("Save path:", save_path) - from transformerlab.plugin import WORKSPACE_DIR + workspace_dir = asyncio.run(get_workspace_dir()) - convert_script_path = os.path.join(WORKSPACE_DIR, "plugins", "nanotron_pretrainer", "convert_nanotron_to_hf.py") + convert_script_path = os.path.join(workspace_dir, "plugins", "nanotron_pretrainer", "convert_nanotron_to_hf.py") cmd_convert = [ "torchrun", diff --git a/api/transformerlab/plugins/ollama_server/index.json b/api/transformerlab/plugins/ollama_server/index.json index 6c6d3c400..48c7cccce 100644 --- a/api/transformerlab/plugins/ollama_server/index.json +++ b/api/transformerlab/plugins/ollama_server/index.json @@ -4,7 +4,7 @@ "description": "Connects to your instance of ollama to run GGUF models that can host models across GPU and/or CPU.", "plugin-format": "python", "type": "loader", - "version": "0.1.11", + "version": "0.1.12", "supports": [ "chat", "completion", diff --git a/api/transformerlab/plugins/red_teaming/index.json b/api/transformerlab/plugins/red_teaming/index.json index d05e92067..8e90e7120 100644 --- 
a/api/transformerlab/plugins/red_teaming/index.json +++ b/api/transformerlab/plugins/red_teaming/index.json @@ -4,7 +4,7 @@ "description": "Evaluating outputs of LLMs using objective metrics", "plugin-format": "python", "type": "evaluator", - "version": "0.1.11", + "version": "0.1.12", "git": "https://github.com/confident-ai/deepeval", "url": "https://github.com/confident-ai/deepeval", "files": ["main.py", "setup.sh"], diff --git a/api/transformerlab/plugins/reward_modeling_llama_factory/index.json b/api/transformerlab/plugins/reward_modeling_llama_factory/index.json index 7bd43c808..46d9bb6d7 100644 --- a/api/transformerlab/plugins/reward_modeling_llama_factory/index.json +++ b/api/transformerlab/plugins/reward_modeling_llama_factory/index.json @@ -4,7 +4,7 @@ "description": "An implementation of RLHF (Reward Modeling) using Llama Factory.", "plugin-format": "python", "type": "trainer", - "version": "0.0.14", + "version": "0.0.15", "model_architectures": [ "GemmaForCausalLM", "LlamaForCausalLM", diff --git a/api/transformerlab/plugins/reward_modeling_llama_factory/main.py b/api/transformerlab/plugins/reward_modeling_llama_factory/main.py index 1fd1f8392..9a9324f5c 100644 --- a/api/transformerlab/plugins/reward_modeling_llama_factory/main.py +++ b/api/transformerlab/plugins/reward_modeling_llama_factory/main.py @@ -13,6 +13,7 @@ import json import yaml import re +import asyncio from transformerlab.sdk.v1.train import tlab_trainer from transformerlab.plugin import get_python_executable @@ -25,7 +26,7 @@ print("Plugin dir:", plugin_dir) # Directory for storing temporary working files -workspace_dir = get_workspace_dir() +workspace_dir = asyncio.run(get_workspace_dir()) data_directory = storage.join(workspace_dir, "temp", "llama_factory_reward", "data") if not storage.exists(data_directory): storage.makedirs(data_directory) diff --git a/api/transformerlab/plugins/sample_plugin/index.json b/api/transformerlab/plugins/sample_plugin/index.json index 1135c57ed..2b4e16543 100644 --- a/api/transformerlab/plugins/sample_plugin/index.json +++ b/api/transformerlab/plugins/sample_plugin/index.json @@ -4,7 +4,7 @@ "description": "This plugin just takes in the inputs provided, and prints them out. 
Use it as a template to see what parameters are available to a script.", "plugin-format": "python", "type": "trainer", - "version": "1.0.4", + "version": "1.0.5", "model_architectures": ["LlamaForCausalLM"], "supported_hardware_architectures": ["cpu", "cuda", "mlx"], "files": ["main.py", "setup.sh"], diff --git a/api/transformerlab/plugins/sft_llama_factory/index.json b/api/transformerlab/plugins/sft_llama_factory/index.json index 480f9ba8b..0474846b7 100644 --- a/api/transformerlab/plugins/sft_llama_factory/index.json +++ b/api/transformerlab/plugins/sft_llama_factory/index.json @@ -4,7 +4,7 @@ "description": "An implementation of Supervised Finetuning using Llama Factory.", "plugin-format": "python", "type": "trainer", - "version": "0.0.17", + "version": "0.0.18", "model_architectures": [ "GemmaForCausalLM", "LlamaForCausalLM", diff --git a/api/transformerlab/plugins/sft_llama_factory/main.py b/api/transformerlab/plugins/sft_llama_factory/main.py index 04dca0915..745283121 100644 --- a/api/transformerlab/plugins/sft_llama_factory/main.py +++ b/api/transformerlab/plugins/sft_llama_factory/main.py @@ -14,6 +14,7 @@ import json import yaml import re +import asyncio from transformerlab.sdk.v1.train import tlab_trainer from transformerlab.plugin import get_python_executable @@ -35,7 +36,7 @@ @tlab_trainer.job_wrapper() def run_train(): # Directory for storing temporary working files - workspace_dir = get_workspace_dir() + workspace_dir = asyncio.run(get_workspace_dir()) data_directory = storage.join(workspace_dir, "temp", "llama_factory", "data") if not storage.exists(data_directory): storage.makedirs(data_directory) @@ -219,7 +220,7 @@ def fuse_model(): """Fuse the adapter with the base model""" print("Now fusing the adaptor with the model.") - workspace_dir = get_workspace_dir() + workspace_dir = asyncio.run(get_workspace_dir()) data_directory = storage.join(workspace_dir, "temp", "llama_factory", "data") model_name = tlab_trainer.params.model_name adaptor_name = tlab_trainer.params.adaptor_name diff --git a/api/transformerlab/plugins/sglang_server/index.json b/api/transformerlab/plugins/sglang_server/index.json index e80572919..c96d82739 100644 --- a/api/transformerlab/plugins/sglang_server/index.json +++ b/api/transformerlab/plugins/sglang_server/index.json @@ -4,7 +4,7 @@ "description": "SGLang Server loads models for multimodal inference using Huggingface Transformers for generation.", "plugin-format": "python", "type": "loader", - "version": "0.0.9", + "version": "0.0.10", "supports": [ "chat", "completion", diff --git a/api/transformerlab/plugins/synthesizer_docs/index.json b/api/transformerlab/plugins/synthesizer_docs/index.json index ba3163de0..4e8f8669d 100644 --- a/api/transformerlab/plugins/synthesizer_docs/index.json +++ b/api/transformerlab/plugins/synthesizer_docs/index.json @@ -4,7 +4,7 @@ "description": "Use LLMs to create synthetic data for your usecases from reference documents.", "plugin-format": "python", "type": "generator", - "version": "0.1.23", + "version": "0.1.24", "git": "https://github.com/confident-ai/deepeval", "url": "https://github.com/confident-ai/deepeval", "files": ["main.py", "setup.sh"], diff --git a/api/transformerlab/plugins/synthesizer_docs/main.py b/api/transformerlab/plugins/synthesizer_docs/main.py index 703d4dae6..dfd41ab2b 100644 --- a/api/transformerlab/plugins/synthesizer_docs/main.py +++ b/api/transformerlab/plugins/synthesizer_docs/main.py @@ -1,6 +1,7 @@ import traceback import pandas as pd from typing import List +import asyncio from 
deepeval.models import DeepEvalBaseEmbeddingModel from deepeval.synthesizer import Synthesizer @@ -47,7 +48,7 @@ def get_model_name(self): def get_docs_list(docs: str, experiment_name: str) -> List[str]: """Get list of document paths from comma-separated string of doc names""" docs_list = docs.split(",") - workspace_dir = get_workspace_dir() + workspace_dir = asyncio.run(get_workspace_dir()) documents_dir = storage.join(workspace_dir, "experiments", experiment_name, "documents") result_docs = [] diff --git a/api/transformerlab/plugins/synthesizer_raw_text/index.json b/api/transformerlab/plugins/synthesizer_raw_text/index.json index 176254497..641e6ad90 100644 --- a/api/transformerlab/plugins/synthesizer_raw_text/index.json +++ b/api/transformerlab/plugins/synthesizer_raw_text/index.json @@ -4,7 +4,7 @@ "description": "Use LLMs to create synthetic data for your usecases from reference context. Paste all your reference context to generate a dataset.", "plugin-format": "python", "type": "generator", - "version": "0.1.17", + "version": "0.1.18", "git": "https://github.com/confident-ai/deepeval", "url": "https://github.com/confident-ai/deepeval", "files": ["main.py", "setup.sh"], diff --git a/api/transformerlab/plugins/synthesizer_scratch/index.json b/api/transformerlab/plugins/synthesizer_scratch/index.json index dc85e6e42..a31604b41 100644 --- a/api/transformerlab/plugins/synthesizer_scratch/index.json +++ b/api/transformerlab/plugins/synthesizer_scratch/index.json @@ -4,7 +4,7 @@ "description": "Use LLMs to create synthetic data for your usecases from scratch using just a concept of a dataset.", "plugin-format": "python", "type": "generator", - "version": "0.1.18", + "version": "0.1.19", "git": "https://github.com/confident-ai/deepeval", "url": "https://github.com/confident-ai/deepeval", "files": ["main.py", "setup.sh"], diff --git a/api/transformerlab/plugins/synthetic_dataset_kit/index.json b/api/transformerlab/plugins/synthetic_dataset_kit/index.json index 97b181e63..2d26ee564 100644 --- a/api/transformerlab/plugins/synthetic_dataset_kit/index.json +++ b/api/transformerlab/plugins/synthetic_dataset_kit/index.json @@ -4,7 +4,7 @@ "description": "Generates QA pairs, chain-of-thought, or summaries using Meta's synthetic-data-kit and vLLM-based or proxy-backed models.", "plugin-format": "python", "type": "generator", - "version": "0.1.7", + "version": "0.1.8", "git": "", "url": "", "files": ["main.py", "setup.sh"], diff --git a/api/transformerlab/plugins/synthetic_dataset_kit/main.py b/api/transformerlab/plugins/synthetic_dataset_kit/main.py index fb9fc5218..ee3efc80d 100644 --- a/api/transformerlab/plugins/synthetic_dataset_kit/main.py +++ b/api/transformerlab/plugins/synthetic_dataset_kit/main.py @@ -22,6 +22,7 @@ import shutil import pandas as pd from pathlib import Path +import asyncio from transformerlab.sdk.v1.generate import tlab_gen from lab.dirs import get_workspace_dir @@ -58,7 +59,7 @@ def run_generation(): prompt_template = tlab_gen.params.get("prompt_template", "") api_base = tlab_gen.params.get("vllm_api_base", "http://localhost:8338/v1") port = str(api_base.rsplit(":", 1)[-1].rstrip("/v1")) - workspace = get_workspace_dir() + workspace = asyncio.run(get_workspace_dir()) experiment = tlab_gen.params.experiment_name documents_dir = storage.join(workspace, "experiments", experiment, "documents") doc_filenames = [d.strip() for d in docs_str.split(",") if d.strip()] @@ -70,14 +71,14 @@ def run_generation(): # Prompt selector based on generation_type DEFAULT_PROMPTS = { "summary": 
"Summarize this document in 3-5 sentences, focusing on the main topic and key concepts.", - "qa_generation": """ + "qa_generation": """ Create {num_pairs} question-answer pairs from this text for LLM training. Rules: 1. Questions must be about important facts in the text 2. Answers must be directly supported by the text 3. Return JSON format only: - + [ {{ "question": "Question 1?", @@ -88,7 +89,7 @@ def run_generation(): "answer": "Answer 2." }} ] - + Text: {text} """, @@ -98,35 +99,35 @@ def run_generation(): - Relevance (0-2): relevance to content - Clarity (0-2): clear language - Usefulness (0-3): value for model learning - + YOU MUST RETURN A VALID JSON OBJECT OR ARRAY WITH THIS EXACT SCHEMA: {{ "question": "Exact question text", "answer": "Exact answer text", "rating": 8 }} - + OR FOR MULTIPLE PAIRS: [ {{"question": "Q1", "answer": "A1", "rating": 8}}, {{"question": "Q2", "answer": "A2", "rating": 9}} ] - + *** YOUR RESPONSE MUST BE VALID JSON AND NOTHING ELSE - NO EXPLANATION, NO MARKDOWN *** - + QA pairs to rate: - {pairs} + {pairs} """, "cot_generation": """ Create {num_examples} complex reasoning examples from this text that demonstrate chain-of-thought thinking. - + Each example should have: 1. A challenging question that requires step-by-step reasoning 2. Detailed reasoning steps that break down the problem 3. A concise final answer - + Return JSON format only: - + [ {{ "question": "Complex question about the text?", @@ -139,17 +140,17 @@ def run_generation(): "answer": "Final answer drawn from the reasoning." }} ] - + Text: {text} """, "cot_enhancement": """ You are an expert reasoning assistant. Your task is to enhance the given conversations by adding chain-of-thought reasoning. - + For each conversation, add detailed step-by-step reasoning to the assistant's responses while preserving the original answer. - + {include_simple_steps} = Whether to add reasoning to simple responses too. If false, only add reasoning to complex responses. - + Return the enhanced conversations as a JSON array matching this format: [ [ @@ -163,9 +164,9 @@ def run_generation(): {{"role": "assistant", "content": "Let me work through this:\n\n1. I'll start by...\n2. 
Next...\n\nIn conclusion, [original answer]"}} ] ] - + Original conversations: - {conversations} + {conversations} """, } prompt_lookup = {"qa": "qa_generation", "cot": "cot_generation", "summary": "summary"} diff --git a/api/transformerlab/plugins/synthetic_dataset_rag/index.json b/api/transformerlab/plugins/synthetic_dataset_rag/index.json index f021934ea..e86de8de5 100644 --- a/api/transformerlab/plugins/synthetic_dataset_rag/index.json +++ b/api/transformerlab/plugins/synthetic_dataset_rag/index.json @@ -4,7 +4,7 @@ "description": "Generate synthetic question-answer pairs from documents for RAG system evaluation.", "plugin-format": "python", "type": "generator", - "version": "0.1.14", + "version": "0.1.15", "git": "", "url": "", "files": ["main.py", "setup.sh"], diff --git a/api/transformerlab/plugins/synthetic_dataset_rag/main.py b/api/transformerlab/plugins/synthetic_dataset_rag/main.py index 09816eda4..4f9cc4561 100644 --- a/api/transformerlab/plugins/synthetic_dataset_rag/main.py +++ b/api/transformerlab/plugins/synthetic_dataset_rag/main.py @@ -2,6 +2,7 @@ import random from typing import List import pandas as pd +import asyncio import fitz from langchain.docstore.document import Document as LangchainDocument @@ -33,7 +34,7 @@ def get_docs_list(docs: str) -> List[dict]: Supports text files, PDFs, and other document formats """ docs_list = docs.split(",") - workspace_dir = get_workspace_dir() + workspace_dir = asyncio.run(get_workspace_dir()) documents_dir = storage.join(workspace_dir, "experiments", tlab_gen.params.experiment_name, "documents") # Use the markdown files if they exist diff --git a/api/transformerlab/plugins/t5_lora_trainer/index.json b/api/transformerlab/plugins/t5_lora_trainer/index.json index dff6a0904..2a695b3ec 100644 --- a/api/transformerlab/plugins/t5_lora_trainer/index.json +++ b/api/transformerlab/plugins/t5_lora_trainer/index.json @@ -4,7 +4,7 @@ "description": "A training script adapted from https://www.philschmid.de/fine-tune-flan-t5-peft . 
Works with T5 models using transformers.Seq2SeqTrainer.", "plugin-format": "python", "type": "trainer", - "version": "1.0.11", + "version": "1.0.12", "model_architectures": ["T5ForConditionalGeneration"], "supported_hardware_architectures": ["cuda", "amd"], "git": "", diff --git a/api/transformerlab/plugins/unsloth_grpo_trainer/index.json b/api/transformerlab/plugins/unsloth_grpo_trainer/index.json index 36ebc4cf6..2ae5f8c31 100644 --- a/api/transformerlab/plugins/unsloth_grpo_trainer/index.json +++ b/api/transformerlab/plugins/unsloth_grpo_trainer/index.json @@ -4,7 +4,7 @@ "description": "A GRPO trainer based on the unsloth grpo training notebooks", "plugin-format": "python", "type": "trainer", - "version": "0.0.14", + "version": "0.0.15", "model_architectures": [ "LlamaForCausalLM", "Qwen2ForCausalLM", diff --git a/api/transformerlab/plugins/unsloth_text_to_speech_server/index.json b/api/transformerlab/plugins/unsloth_text_to_speech_server/index.json index d21e01586..41a20c4a8 100644 --- a/api/transformerlab/plugins/unsloth_text_to_speech_server/index.json +++ b/api/transformerlab/plugins/unsloth_text_to_speech_server/index.json @@ -4,7 +4,7 @@ "description": "A text-to-speech (TTS) audio generation server, supporting efficient, high-quality speech synthesis.", "plugin-format": "python", "type": "loader", - "version": "0.0.14", + "version": "0.0.15", "supports": ["Text-to-Speech", "Audio"], "model_architectures": ["CsmForConditionalGeneration", "LlamaForCausalLM"], "supported_hardware_architectures": ["cuda", "amd"], diff --git a/api/transformerlab/plugins/unsloth_text_to_speech_server/main.py b/api/transformerlab/plugins/unsloth_text_to_speech_server/main.py index 28f3b33b4..45bfd9420 100644 --- a/api/transformerlab/plugins/unsloth_text_to_speech_server/main.py +++ b/api/transformerlab/plugins/unsloth_text_to_speech_server/main.py @@ -117,7 +117,7 @@ async def generate(self, params): else: logger.info("No reference audio provided, performing standard TTS") - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() # Make sure the path is still inside the workspace directory # For both local and fsspec paths, check if audio_dir starts with workspace_dir is_safe = audio_dir.startswith(workspace_dir) diff --git a/api/transformerlab/plugins/unsloth_text_to_speech_trainer/index.json b/api/transformerlab/plugins/unsloth_text_to_speech_trainer/index.json index 08934d644..c84d688b9 100644 --- a/api/transformerlab/plugins/unsloth_text_to_speech_trainer/index.json +++ b/api/transformerlab/plugins/unsloth_text_to_speech_trainer/index.json @@ -4,7 +4,7 @@ "description": "A Text-to-Speech (TTS) trainer based on the unsloth audio training notebooks", "plugin-format": "python", "type": "trainer", - "version": "0.0.8", + "version": "0.0.9", "model_architectures": ["CsmForConditionalGeneration", "LlamaForCausalLM"], "supports": [], "supported_hardware_architectures": ["cuda", "amd"], diff --git a/api/transformerlab/plugins/vllm_server/index.json b/api/transformerlab/plugins/vllm_server/index.json index 941792c2f..ede0e708c 100644 --- a/api/transformerlab/plugins/vllm_server/index.json +++ b/api/transformerlab/plugins/vllm_server/index.json @@ -4,7 +4,7 @@ "description": "vLLM is a fast and easy-to-use library for LLM inference and serving.", "plugin-format": "python", "type": "loader", - "version": "1.0.25", + "version": "1.0.26", "supports": [ "chat", "completion", diff --git a/api/transformerlab/plugins/vllm_server/main.py b/api/transformerlab/plugins/vllm_server/main.py index 
dccd89e2f..88d3468af 100644 --- a/api/transformerlab/plugins/vllm_server/main.py +++ b/api/transformerlab/plugins/vllm_server/main.py @@ -8,6 +8,7 @@ import threading import site import shutil +import asyncio from fastchat.constants import TEMP_IMAGE_DIR @@ -86,7 +87,8 @@ def inject_ninja_into_path(): print("Starting vLLM server...", file=sys.stderr) # Use storage.makedirs for workspace paths (TLAB_TEMP_IMAGE_DIR), os.makedirs for system paths -if TEMP_IMAGE_DIR.startswith(get_workspace_dir()): +workspace_dir = asyncio.run(get_workspace_dir()) +if TEMP_IMAGE_DIR.startswith(workspace_dir): storage.makedirs(TEMP_IMAGE_DIR, exist_ok=True) else: os.makedirs(TEMP_IMAGE_DIR, exist_ok=True) diff --git a/api/transformerlab/plugins/vlm_trainer/index.json b/api/transformerlab/plugins/vlm_trainer/index.json index 814b34de6..a8a6e3d65 100644 --- a/api/transformerlab/plugins/vlm_trainer/index.json +++ b/api/transformerlab/plugins/vlm_trainer/index.json @@ -4,7 +4,7 @@ "description": "A training script for vision-language models (VLMs) using Huggingface TRL and PeFT. Supports models like Qwen2-VL.", "plugin-format": "python", "type": "trainer", - "version": "0.1.8", + "version": "0.1.9", "model_architectures": [ "Qwen2ForVision2Seq", "Qwen2VLForVision2Seq", diff --git a/api/transformerlab/plugins/vlm_trainer/main.py b/api/transformerlab/plugins/vlm_trainer/main.py index 11962cf0f..14f27a3e0 100644 --- a/api/transformerlab/plugins/vlm_trainer/main.py +++ b/api/transformerlab/plugins/vlm_trainer/main.py @@ -3,6 +3,7 @@ import torch from lab.dirs import get_workspace_dir from lab import storage +import asyncio if torch.cuda.is_available(): os.environ["CUDA_VISIBLE_DEVICES"] = "0" @@ -198,7 +199,7 @@ def collate_fn(examples): model_id_short = model_id adaptor_name = tlab_trainer.params.get("adaptor_name", "default") fused_model_name = f"{model_id_short}_{adaptor_name}" - workspace_dir = get_workspace_dir() + workspace_dir = asyncio.run(get_workspace_dir()) fused_model_location = storage.join(workspace_dir, "models", fused_model_name) peft_model = PeftModel.from_pretrained(model, args.adaptor_output_dir) merged_model = peft_model.merge_and_unload() diff --git a/api/transformerlab/plugins/wd14_captioner/index.json b/api/transformerlab/plugins/wd14_captioner/index.json index 5183a6d95..9432bd99a 100644 --- a/api/transformerlab/plugins/wd14_captioner/index.json +++ b/api/transformerlab/plugins/wd14_captioner/index.json @@ -4,7 +4,7 @@ "description": "Generate tags/captions for image datasets using the WD14 tagger (Danbooru-style tagging via ONNX).", "plugin-format": "python", "type": "generator", - "version": "0.1.9", + "version": "0.1.10", "git": "", "url": "", "files": ["main.py", "setup.sh"], diff --git a/api/transformerlab/plugins/wd14_captioner/main.py b/api/transformerlab/plugins/wd14_captioner/main.py index 4079bb5b1..ca6c2ecab 100644 --- a/api/transformerlab/plugins/wd14_captioner/main.py +++ b/api/transformerlab/plugins/wd14_captioner/main.py @@ -3,6 +3,7 @@ import shutil import subprocess from pathlib import Path +import asyncio import pandas as pd from tqdm import tqdm @@ -16,7 +17,7 @@ @tlab_gen.job_wrapper() def run_generation(): # ----- Constants ----- - workspace = get_workspace_dir() + workspace = asyncio.run(get_workspace_dir()) REPO_ROOT = storage.join(workspace, "plugins", "wd14_captioner", "sd-caption-wd14", "sd-scripts") SCRIPT_PATH = storage.join( diff --git a/api/transformerlab/plugins/yourbench_data_gen/index.json b/api/transformerlab/plugins/yourbench_data_gen/index.json index 
10e291499..2e9a2fd98 100644 --- a/api/transformerlab/plugins/yourbench_data_gen/index.json +++ b/api/transformerlab/plugins/yourbench_data_gen/index.json @@ -4,7 +4,7 @@ "description": "Generates and uploads a dataset to Huggingface with the specified configurations.", "plugin-format": "python", "type": "generator", - "version": "0.1.6", + "version": "0.1.8", "supported_hardware_architectures": ["cpu", "cuda", "mlx"], "git": "", "url": "", diff --git a/api/transformerlab/plugins/yourbench_data_gen/main.py b/api/transformerlab/plugins/yourbench_data_gen/main.py index 43ea4c896..c860ad660 100644 --- a/api/transformerlab/plugins/yourbench_data_gen/main.py +++ b/api/transformerlab/plugins/yourbench_data_gen/main.py @@ -1,3 +1,4 @@ +import asyncio import os import yaml import subprocess @@ -20,12 +21,12 @@ def generate_config(): print("Model loaded successfully") tlab_gen.progress_update(30) - workspace_dir = get_workspace_dir() + workspace_dir = asyncio.run(get_workspace_dir()) tlab_gen.params.documents_dir = storage.join( workspace_dir, "experiments", tlab_gen.params.experiment_name, "documents", docs ) - if not storage.isdir(tlab_gen.params.documents_dir): + if not asyncio.run(storage.isdir(tlab_gen.params.documents_dir)): raise FileNotFoundError("Please provide a directory containing all your files instead of individual files") base_url = getattr(trlab_model, "base_url", None) @@ -126,8 +127,8 @@ def run_yourbench(): output_dir = tlab_gen.get_output_file_path(dir_only=True) tlab_gen.params.local_dataset_dir = output_dir tlab_gen.params.output_dir = storage.join(output_dir, "temp") - if not storage.exists(tlab_gen.params.output_dir): - storage.makedirs(tlab_gen.params.output_dir) + if not asyncio.run(storage.exists(tlab_gen.params.output_dir)): + asyncio.run(storage.makedirs(tlab_gen.params.output_dir)) config_path = storage.join(output_dir, f"yourbench_config_{tlab_gen.params.job_id}.yaml") # Generate the configuration @@ -137,8 +138,11 @@ def run_yourbench(): tlab_gen.progress_update(20) # Write the configuration to a file - with storage.open(config_path, "w") as config_file: - yaml.dump(config, config_file, default_flow_style=False) + async def _write_config(): + async with await storage.open(config_path, "w") as config_file: + await config_file.write(yaml.dump(config, default_flow_style=False)) + + asyncio.run(_write_config()) print(f"Configuration written to {config_path}") tlab_gen.add_job_data("config_file", config_path) @@ -146,7 +150,7 @@ def run_yourbench(): tlab_gen.progress_update(30) # Get the yourbench directory path - workspace_dir = get_workspace_dir() + workspace_dir = asyncio.run(get_workspace_dir()) current_dir = storage.join(workspace_dir, "plugins", "yourbench_data_gen") # Run yourbench with the configuration diff --git a/api/transformerlab/routers/batched_prompts.py b/api/transformerlab/routers/batched_prompts.py index 5a1410f08..340d263fc 100644 --- a/api/transformerlab/routers/batched_prompts.py +++ b/api/transformerlab/routers/batched_prompts.py @@ -50,7 +50,7 @@ async def list_prompts(): batched_prompts = [] from lab.dirs import get_batched_prompts_dir - batched_prompts_dir = get_batched_prompts_dir() + batched_prompts_dir = await get_batched_prompts_dir() for file in os.listdir(batched_prompts_dir): if file.endswith(".json"): with open(os.path.join(batched_prompts_dir, file), "r") as f: @@ -74,7 +74,7 @@ async def new_prompt(name: Annotated[str, Body()], prompts: Annotated[Union[list slug = slugify(name) from lab.dirs import get_batched_prompts_dir - prompts_dir = 
get_batched_prompts_dir() + prompts_dir = await get_batched_prompts_dir() prompt_file = os.path.join(prompts_dir, f"{slug}.json") with open(prompt_file, "w") as f: @@ -91,7 +91,7 @@ async def delete_prompt(prompt_id: str): prompt_id = secure_filename(prompt_id) from lab.dirs import get_batched_prompts_dir - prompts_dir = get_batched_prompts_dir() + prompts_dir = await get_batched_prompts_dir() prompt_file = os.path.join(prompts_dir, f"{prompt_id}.json") if os.path.exists(prompt_file): diff --git a/api/transformerlab/routers/compute_provider.py b/api/transformerlab/routers/compute_provider.py index ab65683ec..b0e4ae931 100644 --- a/api/transformerlab/routers/compute_provider.py +++ b/api/transformerlab/routers/compute_provider.py @@ -94,18 +94,18 @@ async def upload_task_file_for_provider( raise HTTPException(status_code=404, detail="Provider not found") try: - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() if not workspace_dir: raise RuntimeError("Workspace directory is not configured") # uploads/task/{task_id}/ uploads_root = storage.join(workspace_dir, "uploads", "task") - storage.makedirs(uploads_root, exist_ok=True) + await storage.makedirs(uploads_root, exist_ok=True) import uuid task_dir = storage.join(uploads_root, str(task_id)) - storage.makedirs(task_dir, exist_ok=True) + await storage.makedirs(task_dir, exist_ok=True) # Use original filename with a random suffix to avoid collisions original_name = file.filename or "uploaded_file" @@ -118,8 +118,8 @@ async def upload_task_file_for_provider( # Persist file contents await file.seek(0) content = await file.read() - with storage.open(stored_path, "wb") as f: - f.write(content) + async with await storage.open(stored_path, "wb") as f: + await f.write(content) return ProviderTemplateFileUploadResponse( status="success", @@ -256,7 +256,7 @@ async def get_usage_report( # Get all experiments in the current workspace try: - experiments_data = Experiment.get_all() + experiments_data = await Experiment.get_all() experiments = [exp.get("id") for exp in experiments_data if exp.get("id")] except Exception as e: print(f"Error getting experiments: {e}") @@ -267,7 +267,7 @@ async def get_usage_report( for experiment_id in experiments: try: - jobs = job_service.jobs_get_all(experiment_id=experiment_id, type="REMOTE") + jobs = await job_service.jobs_get_all(experiment_id=experiment_id, type="REMOTE") for job in jobs: job_data = job.get("job_data", {}) or {} @@ -720,7 +720,7 @@ async def _create_sweep_parent_job( provider_display_name = request.provider_name or provider.name - parent_job_id = job_service.job_create( + parent_job_id = await job_service.job_create( type="SWEEP", status="RUNNING", experiment_id=request.experiment_id, @@ -748,7 +748,7 @@ async def _create_sweep_parent_job( for key, value in parent_job_data.items(): if value is not None: - job_service.job_update_job_data_insert_key_value(parent_job_id, key, value, request.experiment_id) + await job_service.job_update_job_data_insert_key_value(parent_job_id, key, value, request.experiment_id) return parent_job_id @@ -824,7 +824,7 @@ async def _launch_sweep_jobs( formatted_cluster_name = f"{_sanitize_cluster_basename(base_name)}-{run_suffix}-job-{parent_job_id}" # Create child job - child_job_id = job_service.job_create( + child_job_id = await job_service.job_create( type="REMOTE", status="QUEUED", experiment_id=request.experiment_id, @@ -838,7 +838,7 @@ async def _launch_sweep_jobs( # Get TFL_STORAGE_URI tfl_storage_uri = None try: - storage_root = 
storage.root_uri() + storage_root = await storage.root_uri() if storage_root and any( storage_root.startswith(prefix) for prefix in ("s3://", "gs://", "gcs://", "abfs://") ): @@ -863,7 +863,7 @@ async def _launch_sweep_jobs( env_vars["AWS_PROFILE"] = aws_profile if request.github_repo_url: - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() github_pat = read_github_pat_from_workspace(workspace_dir) github_setup = generate_github_clone_setup( repo_url=request.github_repo_url, @@ -907,7 +907,7 @@ async def _launch_sweep_jobs( for key, value in child_job_data.items(): if value is not None: - job_service.job_update_job_data_insert_key_value( + await job_service.job_update_job_data_insert_key_value( child_job_id, key, value, request.experiment_id ) @@ -940,7 +940,7 @@ async def _launch_sweep_jobs( launch_result = provider_instance.launch_cluster(formatted_cluster_name, cluster_config) if isinstance(launch_result, dict): - job_service.job_update_job_data_insert_key_value( + await job_service.job_update_job_data_insert_key_value( child_job_id, "provider_launch_result", launch_result, @@ -948,7 +948,7 @@ async def _launch_sweep_jobs( ) request_id = launch_result.get("request_id") if request_id: - job_service.job_update_job_data_insert_key_value( + await job_service.job_update_job_data_insert_key_value( child_job_id, "orchestrator_request_id", request_id, @@ -971,10 +971,10 @@ async def _launch_sweep_jobs( child_job_ids.append(str(child_job_id)) # Update parent job with child job IDs and running count - job_service.job_update_job_data_insert_key_value( + await job_service.job_update_job_data_insert_key_value( parent_job_id, "sweep_job_ids", child_job_ids, request.experiment_id ) - job_service.job_update_job_data_insert_key_value( + await job_service.job_update_job_data_insert_key_value( parent_job_id, "sweep_running", len(child_job_ids), request.experiment_id ) @@ -1070,7 +1070,11 @@ async def launch_template_on_provider( # Interactive templates should start directly in INTERACTIVE state instead of LAUNCHING initial_status = "INTERACTIVE" if request.subtype == "interactive" else "LAUNCHING" - job_id = job_service.job_create(type="REMOTE", status=initial_status, experiment_id=request.experiment_id) + job_id = await job_service.job_create( + type="REMOTE", + status=initial_status, + experiment_id=request.experiment_id, + ) # Create quota hold if minutes_requested is provided quota_hold = None @@ -1119,8 +1123,8 @@ async def launch_template_on_provider( # Add GitHub clone setup if enabled if request.github_repo_url: - workspace_dir = get_workspace_dir() - github_pat = read_github_pat_from_workspace(workspace_dir) + workspace_dir = await get_workspace_dir() + github_pat = await read_github_pat_from_workspace(workspace_dir) github_setup = generate_github_clone_setup( repo_url=request.github_repo_url, directory=request.github_directory, @@ -1141,7 +1145,7 @@ async def launch_template_on_provider( # Get TFL_STORAGE_URI from storage context tfl_storage_uri = None try: - storage_root = storage.root_uri() + storage_root = await storage.root_uri() # Check if it's a remote URI (not a local path) if storage_root and any(storage_root.startswith(prefix) for prefix in ("s3://", "gs://", "gcs://", "abfs://")): tfl_storage_uri = storage_root @@ -1180,7 +1184,7 @@ async def launch_template_on_provider( for key, value in job_data.items(): if value is not None: - job_service.job_update_job_data_insert_key_value(job_id, key, value, request.experiment_id) + await 
job_service.job_update_job_data_insert_key_value(job_id, key, value, request.experiment_id) disk_size = None if request.disk_space: @@ -1227,7 +1231,7 @@ async def launch_template_on_provider( request_id = None if isinstance(launch_result, dict): - job_service.job_update_job_data_insert_key_value( + await job_service.job_update_job_data_insert_key_value( job_id, "provider_launch_result", launch_result, @@ -1235,7 +1239,7 @@ async def launch_template_on_provider( ) request_id = launch_result.get("request_id") if request_id: - job_service.job_update_job_data_insert_key_value( + await job_service.job_update_job_data_insert_key_value( job_id, "orchestrator_request_id", request_id, @@ -1264,7 +1268,7 @@ async def check_provider_job_status( team_id = user_and_team["team_id"] # Get the job - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) if not job: raise HTTPException(status_code=404, detail="Job not found") @@ -1350,7 +1354,7 @@ async def check_provider_job_status( try: # Set end_time when marking job as complete end_time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()) - job_service.job_update_job_data_insert_key_value(job_id, "end_time", end_time_str, experiment_id) + await job_service.job_update_job_data_insert_key_value(job_id, "end_time", end_time_str, experiment_id) # Pass session to job_update_status so quota tracking uses the same session await job_service.job_update_status(job_id, "COMPLETE", experiment_id=experiment_id, session=session) # Commit the session to ensure quota tracking is persisted @@ -1417,7 +1421,7 @@ async def ensure_quota_recorded_for_completed_jobs( } # Get all REMOTE jobs for the experiment - jobs = job_service.jobs_get_all(type="REMOTE", experiment_id=experiment_id) + jobs = await job_service.jobs_get_all(type="REMOTE", experiment_id=experiment_id) jobs_processed = 0 jobs_recorded = 0 @@ -1451,7 +1455,7 @@ async def _update_sweep_job_status(job_id: str, experiment_id: str): Helper function to update a single sweep job's status by checking child jobs. Returns the updated job data. 
""" - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) if not job: return None @@ -1472,7 +1476,7 @@ async def _update_sweep_job_status(job_id: str, experiment_id: str): queued_count = 0 for child_job_id in sweep_job_ids: - child_job = job_service.job_get(child_job_id) + child_job = await job_service.job_get(child_job_id) if not child_job: continue @@ -1487,26 +1491,26 @@ async def _update_sweep_job_status(job_id: str, experiment_id: str): queued_count += 1 # Update parent job with current counts - job_service.job_update_job_data_insert_key_value(job_id, "sweep_completed", completed_count, experiment_id) - job_service.job_update_job_data_insert_key_value(job_id, "sweep_running", running_count, experiment_id) - job_service.job_update_job_data_insert_key_value(job_id, "sweep_failed", failed_count, experiment_id) - job_service.job_update_job_data_insert_key_value(job_id, "sweep_queued", queued_count, experiment_id) + await job_service.job_update_job_data_insert_key_value(job_id, "sweep_completed", completed_count, experiment_id) + await job_service.job_update_job_data_insert_key_value(job_id, "sweep_running", running_count, experiment_id) + await job_service.job_update_job_data_insert_key_value(job_id, "sweep_failed", failed_count, experiment_id) + await job_service.job_update_job_data_insert_key_value(job_id, "sweep_queued", queued_count, experiment_id) # Calculate progress percentage progress = int((completed_count / sweep_total * 100)) if sweep_total > 0 else 0 - job_service.job_update_sweep_progress(job_id, progress, experiment_id) + await job_service.job_update_sweep_progress(job_id, progress, experiment_id) # Check if all jobs are done all_complete = completed_count + failed_count == sweep_total if all_complete and job.get("status") == "RUNNING": # Mark parent as complete if all children are done - job_service.job_update_job_data_insert_key_value( + await job_service.job_update_job_data_insert_key_value( job_id, "end_time", time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()), experiment_id ) await job_service.job_update_status(job_id, "COMPLETE", experiment_id=experiment_id) # Get the updated job data after status updates - return job_service.job_get(job_id) + return await job_service.job_get(job_id) @router.get("/jobs/sweep-status") @@ -1520,7 +1524,7 @@ async def check_sweep_status_all( Only updates status for running/launching jobs. """ # Get all SWEEP jobs for this experiment - all_sweep_jobs = job_service.jobs_get_all(experiment_id=experiment_id, type="SWEEP", status="") + all_sweep_jobs = await job_service.jobs_get_all(experiment_id=experiment_id, type="SWEEP", status="") # Update status for each running/launching sweep job updated_jobs = [] @@ -1558,7 +1562,7 @@ async def check_sweep_status( Check status of a specific sweep job by polling all child jobs and updating parent job status. Returns current sweep status with counts and the updated job data. 
""" - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) if not job: raise HTTPException(status_code=404, detail="Job not found") @@ -1605,7 +1609,7 @@ async def get_sweep_results( """ # Get the parent sweep job - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) if not job: raise HTTPException(status_code=404, detail="Job not found") @@ -1629,7 +1633,7 @@ async def get_sweep_results( best_job_id = None for child_job_id in sweep_job_ids: - child_job = job_service.job_get(child_job_id) + child_job = await job_service.job_get(child_job_id) if not child_job: continue @@ -1696,7 +1700,7 @@ async def get_sweep_results( } # Store results in parent job - job_service.job_update_job_data_insert_key_value(job_id, "sweep_results", aggregated_results, experiment_id) + await job_service.job_update_job_data_insert_key_value(job_id, "sweep_results", aggregated_results, experiment_id) return { "status": "success", @@ -1723,7 +1727,7 @@ async def resume_from_checkpoint( import time # Get the original job - original_job = job_service.job_get(job_id) + original_job = await job_service.job_get(job_id) if not original_job or str(original_job.get("experiment_id")) != str(experimentId): raise HTTPException(status_code=404, detail="Job not found") @@ -1749,9 +1753,9 @@ async def resume_from_checkpoint( ) # Verify checkpoint exists using workspace-aware path resolution - checkpoints_dir = get_job_checkpoints_dir(job_id) + checkpoints_dir = await get_job_checkpoints_dir(job_id) checkpoint_path = storage.join(checkpoints_dir, request.checkpoint) - if not storage.exists(checkpoint_path): + if not await storage.exists(checkpoint_path): raise HTTPException(status_code=404, detail=f"Checkpoint '{request.checkpoint}' not found") # Get provider @@ -1762,11 +1766,13 @@ async def resume_from_checkpoint( # Create new REMOTE job initial_status = "INTERACTIVE" if job_data.get("subtype") == "interactive" else "LAUNCHING" - new_job_id = job_service.job_create(type="REMOTE", status=initial_status, experiment_id=experimentId, job_data={}) + new_job_id = await job_service.job_create( + type="REMOTE", status=initial_status, experiment_id=experimentId, job_data={} + ) # Set parent_job_id and resumed_from_checkpoint in job_data - job_service.job_update_job_data_insert_key_value(new_job_id, "parent_job_id", job_id, experimentId) - job_service.job_update_job_data_insert_key_value( + await job_service.job_update_job_data_insert_key_value(new_job_id, "parent_job_id", job_id, experimentId) + await job_service.job_update_job_data_insert_key_value( new_job_id, "resumed_from_checkpoint", request.checkpoint, experimentId ) @@ -1797,7 +1803,7 @@ async def resume_from_checkpoint( for field in config_fields: value = job_data.get(field) if value is not None: - job_service.job_update_job_data_insert_key_value(new_job_id, field, value, experimentId) + await job_service.job_update_job_data_insert_key_value(new_job_id, field, value, experimentId) # Relaunch via provider - replicate launch logic from compute_provider.py try: @@ -1846,7 +1852,7 @@ async def resume_from_checkpoint( # Add GitHub clone setup if enabled github_repo_url = job_data.get("github_repo_url") if github_repo_url: - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() github_pat = read_github_pat_from_workspace(workspace_dir) github_setup = generate_github_clone_setup( repo_url=github_repo_url, @@ -1888,7 +1894,7 @@ async def resume_from_checkpoint( for key, value in launch_job_data.items(): if value is 
not None: - job_service.job_update_job_data_insert_key_value(new_job_id, key, value, experimentId) + await job_service.job_update_job_data_insert_key_value(new_job_id, key, value, experimentId) # Build ClusterConfig disk_size = None diff --git a/api/transformerlab/routers/data.py b/api/transformerlab/routers/data.py index f3252a31e..c719d5beb 100644 --- a/api/transformerlab/routers/data.py +++ b/api/transformerlab/routers/data.py @@ -35,10 +35,10 @@ sandboxed_jinja2_environment = SandboxedEnvironment() -def log(msg): - global_log_path = get_global_log_path() - with storage.open(global_log_path, "a") as f: - f.write(msg + "\n") +async def log(msg): + global_log_path = await get_global_log_path() + async with await storage.open(global_log_path, "a") as f: + await f.write(msg + "\n") router = APIRouter(prefix="/data", tags=["datasets"]) @@ -68,10 +68,10 @@ class ErrorResponse(BaseModel): }, ) async def dataset_gallery() -> Any: - gallery = galleries.get_data_gallery() + gallery = await galleries.get_data_gallery() # list datasets from filesystem store try: - local_datasets = dataset_service.list_all() + local_datasets = await dataset_service.list_all() except Exception: local_datasets = [] @@ -85,7 +85,8 @@ async def dataset_gallery() -> Any: async def dataset_info(dataset_id: str): # Read from filesystem store try: - d = dataset_service.get(dataset_id).get_metadata() + d_obj = await dataset_service.get(dataset_id) + d = await d_obj.get_metadata() except FileNotFoundError: d = None if d is None: @@ -94,7 +95,8 @@ async def dataset_info(dataset_id: str): # This means it is a custom dataset the user uploaded if d.get("location") == "local": try: - dataset = dataset_service_module.load_local_dataset(dirs.dataset_dir_by_id(dataset_id)) + dataset_dir = await dirs.dataset_dir_by_id(dataset_id) + dataset = await dataset_service_module.load_local_dataset(dataset_dir) except EmptyDatasetError: return {"status": "error", "message": "The dataset is empty."} split = list(dataset.keys())[0] @@ -162,7 +164,8 @@ async def dataset_preview( ) -> Any: # Read from filesystem store try: - d = dataset_service.get(dataset_id).get_metadata() + d_obj = await dataset_service.get(dataset_id) + d = await d_obj.get_metadata() except FileNotFoundError: d = None dataset_len = 0 @@ -170,7 +173,8 @@ async def dataset_preview( try: if d.get("location") == "local": - dataset = dataset_service_module.load_local_dataset(dirs.dataset_dir_by_id(dataset_id), streaming=streaming) + dataset_dir = await dirs.dataset_dir_by_id(dataset_id) + dataset = await dataset_service_module.load_local_dataset(dataset_dir, streaming=streaming) else: dataset_config = (d.get("json_data") or {}).get("dataset_config", None) config_name = (d.get("json_data") or {}).get("config_name", None) @@ -313,7 +317,8 @@ def convert_audio_array_to_wav(array, sampling_rate): async def load_and_slice_dataset(dataset_id: str, offset: int, limit: int): try: - d = dataset_service.get(dataset_id).get_metadata() + d_obj = await dataset_service.get(dataset_id) + d = await d_obj.get_metadata() except FileNotFoundError: d = None dataset_len = 0 @@ -321,7 +326,8 @@ async def load_and_slice_dataset(dataset_id: str, offset: int, limit: int): # This means it is a custom dataset the user uploaded if d and d.get("location") == "local": try: - dataset = dataset_service_module.load_local_dataset(dirs.dataset_dir_by_id(dataset_id)) + dataset_dir = await dirs.dataset_dir_by_id(dataset_id) + dataset = await dataset_service_module.load_local_dataset(dataset_dir) except Exception as e: 
print(f"Error loading dataset: {type(e).__name__}: {e}") return {"status": "error", "message": "An internal error has occurred."} @@ -453,30 +459,33 @@ async def dataset_edit_with_template( offset: int = Query(0, ge=0, description="Starting index"), limit: int = Query(10, ge=1, le=1000, description="Max items to fetch"), ): - dataset_dir = dirs.dataset_dir_by_id(slugify(dataset_id)) - if not storage.exists(dataset_dir): + dataset_dir = await dirs.dataset_dir_by_id(slugify(dataset_id)) + if not await storage.exists(dataset_dir): return {"status": "error", "message": "Dataset directory not found"} rows = [] index = 0 - for root, _, files in storage.walk(dataset_dir): + async for root, _, files in storage.walk(dataset_dir): for file in files: if file.lower().endswith((".json", ".jsonl", ".csv")): # Convert root to string for storage.open metadata_path = storage.join(root, file) try: if file.endswith(".jsonl"): - with storage.open(metadata_path, "r", encoding="utf-8") as f: - data = [json.loads(line) for line in f] + async with await storage.open(metadata_path, "r", encoding="utf-8") as f: + content = await f.read() + data = [json.loads(line) for line in content.splitlines()] elif file.endswith(".json"): - with storage.open(metadata_path, "r", encoding="utf-8") as f: - data = json.load(f) + async with await storage.open(metadata_path, "r", encoding="utf-8") as f: + content = await f.read() + data = json.loads(content) if isinstance(data, dict): data = [data] elif file.endswith(".csv"): - with storage.open(metadata_path, "r", encoding="utf-8") as f: - reader = csv.DictReader(f) + async with await storage.open(metadata_path, "r", encoding="utf-8") as f: + content = await f.read() + reader = csv.DictReader(content.splitlines()) data = [row for row in reader] else: continue @@ -504,15 +513,15 @@ async def dataset_edit_with_template( continue # Use storage.exists for check, but PIL.open for actual image reading - if not storage.exists(image_path): - log(f"Image not found: {image_path}") + if not await storage.exists(image_path): + await log(f"Image not found: {image_path}") continue try: # For PIL, we need a local file, not a storage URI # Download to temp if needed or use PIL directly with storage - with storage.open(image_path, "rb") as f: - img_bytes = f.read() + async with await storage.open(image_path, "rb") as f: + img_bytes = await f.read() img = PILImage.open(BytesIO(img_bytes)) buffer = BytesIO() img.save(buffer, format="JPEG") @@ -563,17 +572,17 @@ async def dataset_edit_with_template( summary="Save edited metadata and create a new dataset with reorganized files and updated metadata.", ) async def save_metadata(dataset_id: str, new_dataset_id: str, file: UploadFile): - old_dataset_dir = dirs.dataset_dir_by_id(slugify(dataset_id)) - if not storage.exists(old_dataset_dir): + old_dataset_dir = await dirs.dataset_dir_by_id(slugify(dataset_id)) + if not await storage.exists(old_dataset_dir): return {"status": "error", "message": "Source dataset not found"} new_dataset_id = slugify(new_dataset_id) - new_dataset_dir = dirs.dataset_dir_by_id(new_dataset_id) + new_dataset_dir = await dirs.dataset_dir_by_id(new_dataset_id) - if storage.exists(new_dataset_dir): + if await storage.exists(new_dataset_dir): return {"status": "error", "message": "New dataset already exists"} - storage.makedirs(new_dataset_dir, exist_ok=True) + await storage.makedirs(new_dataset_dir, exist_ok=True) # Read updates updates_raw = await file.read() @@ -585,22 +594,25 @@ async def save_metadata(dataset_id: str, 
new_dataset_id: str, file: UploadFile): # Scan source metadata source_map = {} - for root, _, files in storage.walk(old_dataset_dir): + async for root, _, files in storage.walk(old_dataset_dir): for f in files: if f.lower().endswith((".json", ".jsonl", ".csv")): metadata_path = storage.join(root, f) try: if f.endswith(".jsonl"): - with storage.open(metadata_path, "r", encoding="utf-8") as meta_file: - data = [json.loads(line) for line in meta_file] + async with await storage.open(metadata_path, "r", encoding="utf-8") as meta_file: + content = await meta_file.read() + data = [json.loads(line) for line in content.splitlines()] elif f.endswith(".json"): - with storage.open(metadata_path, "r", encoding="utf-8") as meta_file: - data = json.load(meta_file) + async with await storage.open(metadata_path, "r", encoding="utf-8") as meta_file: + content = await meta_file.read() + data = json.loads(content) if isinstance(data, dict): data = [data] elif f.endswith(".csv"): - with storage.open(metadata_path, "r", encoding="utf-8") as meta_file: - reader = csv.DictReader(meta_file) + async with await storage.open(metadata_path, "r", encoding="utf-8") as meta_file: + content = await meta_file.read() + reader = csv.DictReader(content.splitlines()) data = [row for row in reader] else: continue @@ -639,25 +651,25 @@ async def save_metadata(dataset_id: str, new_dataset_id: str, file: UploadFile): source_info = source_map.get(file_name) if not source_info: - log(f"Warning: Source info not found for {file_name}, skipping") + await log(f"Warning: Source info not found for {file_name}, skipping") continue source_path = storage.join(source_info["metadata_root"], file_name) - if not storage.exists(source_path): - log(f"Warning: Source image file not found {source_path}, skipping") + if not await storage.exists(source_path): + await log(f"Warning: Source image file not found {source_path}, skipping") continue if final_label == "": dest_folder = storage.join(new_dataset_dir, final_split) else: dest_folder = storage.join(new_dataset_dir, final_split, final_label) - storage.makedirs(dest_folder, exist_ok=True) + await storage.makedirs(dest_folder, exist_ok=True) # Get just the filename for dest file_basename = Path(file_name).name dest_path = storage.join(dest_folder, file_basename) try: - storage.copy_file(source_path, dest_path) + await storage.copy_file(source_path, dest_path) except Exception as e: print(f"Failed to copy {source_path} to {dest_path}: {e}") return {"status": "error", "message": "Failed to copy from source to destination"} @@ -681,10 +693,10 @@ async def save_metadata(dataset_id: str, new_dataset_id: str, file: UploadFile): folder = storage.join(new_dataset_dir, split, label) metadata_file = storage.join(folder, "metadata.jsonl") try: - with storage.open(metadata_file, "w", encoding="utf-8") as f: + async with await storage.open(metadata_file, "w", encoding="utf-8") as f: for entry in entries: full_entry = {col: entry.get(col, "") for col in all_columns} - f.write(json.dumps(full_entry) + "\n") + await f.write(json.dumps(full_entry) + "\n") except Exception as e: print(f"Failed to write metadata file {metadata_file}: {e}") return {"status": "error", "message": "Failed to write metadata file!"} @@ -704,7 +716,7 @@ async def save_metadata(dataset_id: str, new_dataset_id: str, file: UploadFile): async def dataset_download(dataset_id: str, config_name: str = None): # Ensure we don't already have this dataset in filesystem store try: - _ = dataset_service.get(dataset_id) + _ = await 
dataset_service.get(dataset_id) return {"status": "error", "message": f"A dataset with the name {dataset_id} already exists"} except FileNotFoundError: pass @@ -712,7 +724,7 @@ async def dataset_download(dataset_id: str, config_name: str = None): # Try to get the dataset info from the gallery gallery = [] json_data = {} - gallery = galleries.get_data_gallery() + gallery = await galleries.get_data_gallery() for dataset in gallery: if dataset["huggingfacerepo"] == dataset_id: json_data = dataset @@ -726,24 +738,24 @@ async def dataset_download(dataset_id: str, config_name: str = None): ds_builder = load_dataset_builder(path=dataset_id, name=config_name, trust_remote_code=True) else: ds_builder = load_dataset_builder(dataset_id, trust_remote_code=True) - log(f"Dataset builder loaded for dataset_id: {dataset_id}") + await log(f"Dataset builder loaded for dataset_id: {dataset_id}") except ValueError as e: - log(f"ValueError occurred: {type(e).__name__}: {e}") + await log(f"ValueError occurred: {type(e).__name__}: {e}") if "Config name is missing" in str(e): return {"status": "error", "message": "Please enter the folder_name of the dataset from huggingface"} else: return {"status": "error", "message": "An internal error has occurred!"} except DatasetNotFoundError as e: - log(f"DatasetNotFoundError occurred: {e}") + await log(f"DatasetNotFoundError occurred: {e}") return { "status": "error", "message": f"Dataset '{dataset_id}' not found or is private. Please check the dataset ID.", } except Exception as e: - log(f"Exception occurred: {type(e).__name__}: {e}") + await log(f"Exception occurred: {type(e).__name__}: {e}") return {"status": "error", "message": "An internal error has occurred!"} dataset_size = ds_builder.info.download_size @@ -766,55 +778,56 @@ async def dataset_download(dataset_id: str, config_name: str = None): # Create filesystem metadata try: try: - sdk_ds = dataset_service.get(dataset_id) + sdk_ds = await dataset_service.get(dataset_id) except FileNotFoundError: - sdk_ds = dataset_service.create(dataset_id) - sdk_ds.set_metadata( + sdk_ds = await dataset_service.create(dataset_id) + await sdk_ds.set_metadata( location="huggingfacehub", description=ds_builder.info.description or "", size=dataset_size, json_data=json_data, ) - log(f"Dataset created in filesystem for dataset_id: {dataset_id}") + await log(f"Dataset created in filesystem for dataset_id: {dataset_id}") except Exception as e: print(f"Failed to write dataset metadata to SDK store: {type(e).__name__}: {e}") # Download the dataset # Later on we can move this to a job async def load_dataset_thread(dataset_id, config_name=None): - logFile = storage.open(get_global_log_path(), "a") - flushLogFile = FlushFile(logFile) - with contextlib.redirect_stdout(flushLogFile), contextlib.redirect_stderr(flushLogFile): - try: - if config_name is not None: - dataset = load_dataset(path=dataset_id, name=config_name, trust_remote_code=True) - else: - dataset = load_dataset(dataset_id, trust_remote_code=True) - print(f"Dataset downloaded for dataset_id: {dataset_id}") - return dataset - - except ValueError as e: - error_msg = f"{type(e).__name__}: {e}" - print(error_msg) - raise ValueError(e) - - except Exception as e: - error_msg = f"{type(e).__name__}: {e}" - print(error_msg) - raise + global_log_path = await get_global_log_path() + async with await storage.open(global_log_path, "a") as logFile: + flushLogFile = FlushFile(logFile) + with contextlib.redirect_stdout(flushLogFile), contextlib.redirect_stderr(flushLogFile): + try: + if 
config_name is not None: + dataset = load_dataset(path=dataset_id, name=config_name, trust_remote_code=True) + else: + dataset = load_dataset(dataset_id, trust_remote_code=True) + print(f"Dataset downloaded for dataset_id: {dataset_id}") + return dataset + + except ValueError as e: + error_msg = f"{type(e).__name__}: {e}" + print(error_msg) + raise ValueError(e) + + except Exception as e: + error_msg = f"{type(e).__name__}: {e}" + print(error_msg) + raise try: dataset = await load_dataset_thread(dataset_id, config_name) except ValueError as e: - log(f"Exception occurred while downloading dataset: {type(e).__name__}: {e}") + await log(f"Exception occurred while downloading dataset: {type(e).__name__}: {e}") if "Config name is missing" in str(e): return {"status": "error", "message": "Please enter the folder_name of the dataset from huggingface"} else: return {"status": "error", "message": "An internal error has occurred!"} except Exception as e: - log(f"Exception occurred while downloading dataset: {type(e).__name__}: {e}") + await log(f"Exception occurred while downloading dataset: {type(e).__name__}: {e}") return {"status": "error", "message": "An internal error has occurred!"} return {"status": "success"} @@ -824,7 +837,7 @@ async def load_dataset_thread(dataset_id, config_name=None): async def dataset_list(generated: bool = True): # Filesystem-only list try: - merged_list = dataset_service.list_all() + merged_list = await dataset_service.list_all() except Exception: merged_list = [] @@ -850,7 +863,7 @@ async def dataset_list(generated: bool = True): @router.get("/generated_datasets_list", summary="List available generated datasets.") async def generated_datasets_list(): try: - merged_list = dataset_service.list_all() + merged_list = await dataset_service.list_all() except Exception: merged_list = [] result = [] @@ -872,20 +885,20 @@ async def dataset_new(dataset_id: str, generated: bool = False): # Check to make sure we don't have a dataset with this name (filesystem) try: - _ = dataset_service.get(dataset_id) + _ = await dataset_service.get(dataset_id) return {"status": "error", "message": f"A dataset with the name {dataset_id} already exists"} except FileNotFoundError: pass # Now make a directory that maps to the above dataset_id # Check if the directory already exists - dataset_path = dirs.dataset_dir_by_id(dataset_id) - if not storage.exists(dataset_path): - storage.makedirs(dataset_path, exist_ok=True) + dataset_path = await dirs.dataset_dir_by_id(dataset_id) + if not await storage.exists(dataset_path): + await storage.makedirs(dataset_path, exist_ok=True) # Create filesystem metadata try: - ds = dataset_service.create(dataset_id) - ds.set_metadata( + ds = await dataset_service.create(dataset_id) + await ds.set_metadata( location="local", description="", size=-1, @@ -900,7 +913,8 @@ async def dataset_new(dataset_id: str, generated: bool = False): async def dataset_delete(dataset_id: str): dataset_id = secure_filename(dataset_id) # delete directory and contents. 
ignore_errors because we don't care if the directory doesn't exist - storage.rm_tree(dirs.dataset_dir_by_id(dataset_id)) + dataset_dir = await dirs.dataset_dir_by_id(dataset_id) + await storage.rm_tree(dataset_dir) return {"status": "success"} @@ -925,11 +939,12 @@ async def create_upload_file(dataset_id: str, files: list[UploadFile]): try: content = await file.read() - target_path = storage.join(dirs.dataset_dir_by_id(dataset_id), str(file.filename)) + dataset_dir = await dirs.dataset_dir_by_id(dataset_id) + target_path = storage.join(dataset_dir, str(file.filename)) # aiofiles doesn't support URIs, so we need to use storage.open instead - storage.makedirs(storage.join(dirs.dataset_dir_by_id(dataset_id)), exist_ok=True) - with storage.open(target_path, "wb") as out_file: - out_file.write(content) + await storage.makedirs(dataset_dir, exist_ok=True) + async with await storage.open(target_path, "wb") as out_file: + await out_file.write(content) uploaded_filenames.append(str(file.filename)) except Exception: raise HTTPException(status_code=403, detail="There was a problem uploading the file") @@ -937,8 +952,8 @@ async def create_upload_file(dataset_id: str, files: list[UploadFile]): # Update dataset metadata with uploaded files if uploaded_filenames: try: - ds = dataset_service.get(dataset_id) - current_data = ds.get_metadata() + ds = await dataset_service.get(dataset_id) + current_data = await ds.get_metadata() json_data = current_data.get("json_data", {}) # Add files list if not present or merge with existing existing_files = json_data.get("files", []) @@ -947,7 +962,7 @@ async def create_upload_file(dataset_id: str, files: list[UploadFile]): else: all_files = uploaded_filenames json_data["files"] = all_files - ds.set_metadata(json_data=json_data) + await ds.set_metadata(json_data=json_data) except Exception as e: print(f"Failed to update dataset metadata with files: {type(e).__name__}: {e}") diff --git a/api/transformerlab/routers/experiment/conversations.py b/api/transformerlab/routers/experiment/conversations.py index 499728680..4a77fea59 100644 --- a/api/transformerlab/routers/experiment/conversations.py +++ b/api/transformerlab/routers/experiment/conversations.py @@ -15,19 +15,20 @@ @router.get(path="/list") async def get_conversations(experimentId: str): - exp_obj = Experiment.get(experimentId) - experiment_dir = exp_obj.get_dir() + exp_obj = await Experiment.get(experimentId) + experiment_dir = await exp_obj.get_dir() conversation_dir = storage.join(experiment_dir, "conversations") # make directory if it does not exist: - if not storage.exists(conversation_dir): - storage.makedirs(conversation_dir, exist_ok=True) + if not await storage.exists(conversation_dir): + await storage.makedirs(conversation_dir, exist_ok=True) # now get a list of all the files in the conversations directory conversations_files = [] try: - for entry in storage.ls(conversation_dir, detail=False): + entries = await storage.ls(conversation_dir, detail=False) + for entry in entries: filename = entry.rstrip("/").split("/")[-1] if filename.endswith(".json"): conversations_files.append(filename) @@ -39,12 +40,13 @@ async def get_conversations(experimentId: str): # now read each conversation and create a list of all conversations # and their contents for i in range(len(conversations_files)): - with storage.open(storage.join(conversation_dir, conversations_files[i]), "r") as f: + async with await storage.open(storage.join(conversation_dir, conversations_files[i]), "r") as f: new_conversation = {} 
new_conversation["id"] = conversations_files[i] # remove .json from end of id new_conversation["id"] = new_conversation["id"][:-5] - new_conversation["contents"] = json.load(f) + content = await f.read() + new_conversation["contents"] = json.loads(content) # use file timestamp to get a date try: # fsspec detail listing could be used; fallback to 0 @@ -64,53 +66,54 @@ async def save_conversation( experimentId: str, conversation_id: Annotated[str, Body()], conversation: Annotated[str, Body()] ): conversation_id = secure_filename(conversation_id) - exp_obj = Experiment.get(experimentId) - experiment_dir = exp_obj.get_dir() + exp_obj = await Experiment.get(experimentId) + experiment_dir = await exp_obj.get_dir() conversation_dir = storage.join(experiment_dir, "conversations") final_path = storage.join(conversation_dir, conversation_id + ".json") # now save the conversation - with storage.open(final_path, "w") as f: - f.write(conversation) + async with await storage.open(final_path, "w") as f: + await f.write(conversation) return {"message": f"Conversation {conversation_id} saved"} @router.delete(path="/delete") async def delete_conversation(experimentId: str, conversation_id: str): - exp_obj = Experiment.get(experimentId) - experiment_dir = exp_obj.get_dir() + exp_obj = await Experiment.get(experimentId) + experiment_dir = await exp_obj.get_dir() conversation_id = secure_filename(conversation_id) conversation_dir = storage.join(experiment_dir, "conversations") final_path = storage.join(conversation_dir, conversation_id + ".json") # now delete the conversation - if storage.exists(final_path): - storage.rm(final_path) + if await storage.exists(final_path): + await storage.rm(final_path) return {"message": f"Conversation {conversation_id} deleted"} @router.get(path="/list_audio") async def list_audio(experimentId: str): - exp_obj = Experiment.get(experimentId) - experiment_dir = exp_obj.get_dir() + exp_obj = await Experiment.get(experimentId) + experiment_dir = await exp_obj.get_dir() audio_dir = storage.join(experiment_dir, "audio") - storage.makedirs(audio_dir, exist_ok=True) + await storage.makedirs(audio_dir, exist_ok=True) # now get a list of all the json files in the audio directory audio_files_metadata = [] try: - entries = storage.ls(audio_dir, detail=False) + entries = await storage.ls(audio_dir, detail=False) for entry in entries: filename = entry.rstrip("/").split("/")[-1] if filename.endswith(".json"): file_path = storage.join(audio_dir, filename) - with storage.open(file_path, "r") as f: + async with await storage.open(file_path, "r") as f: try: - data = json.load(f) + content = await f.read() + data = json.loads(content) # Add the file modification time for sorting data["id"] = filename[:-5] # Remove .json from the filename # fsspec doesn't always provide mtime, use 0 as fallback @@ -129,8 +132,8 @@ async def list_audio(experimentId: str): @router.get(path="/download_audio") async def download_audio(experimentId: str, filename: str, audioFolder: str = "audio"): - exp_obj = Experiment.get(experimentId) - experiment_dir = exp_obj.get_dir() + exp_obj = await Experiment.get(experimentId) + experiment_dir = await exp_obj.get_dir() # Use the provided audioFolder parameter, defaulting to "audio" audioFolder = secure_filename(audioFolder) @@ -140,7 +143,7 @@ async def download_audio(experimentId: str, filename: str, audioFolder: str = "a filename = secure_filename(filename) file_path = storage.join(audio_dir, filename) - if not storage.exists(file_path): + if not await 
storage.exists(file_path): return {"message": f"Audio file {filename} does not exist in experiment {experimentId}"} # FileResponse needs a local file path, so use the path string directly @@ -179,21 +182,21 @@ async def delete_audio(experimentId: str, id: str): {"message": "Experiment does not exist"} {"message": "Audio file does not exist in experiment "} """ - exp_obj = Experiment.get(experimentId) - experiment_dir = exp_obj.get_dir() + exp_obj = await Experiment.get(experimentId) + experiment_dir = await exp_obj.get_dir() audio_dir = storage.join(experiment_dir, "audio") # Delete the metadata file (.json) id = secure_filename(id) metadata_path = storage.join(audio_dir, id + ".json") - if not storage.exists(metadata_path): + if not await storage.exists(metadata_path): return {"message": f"Audio file {id} does not exist in experiment {experimentId}"} - storage.rm(metadata_path) + await storage.rm(metadata_path) # Delete the audio file (.wav) audio_path = storage.join(audio_dir, id + ".wav") - if storage.exists(audio_path): - storage.rm(audio_path) + if await storage.exists(audio_path): + await storage.rm(audio_path) return {"message": f"Audio file {id} deleted from experiment {experimentId}"} @@ -201,15 +204,15 @@ async def delete_audio(experimentId: str, id: str): @router.get("/list_transcription") async def list_transcription(experimentId: str): # Get experiment object and directory - exp_obj = Experiment.get(experimentId) - experiment_dir = exp_obj.get_dir() + exp_obj = await Experiment.get(experimentId) + experiment_dir = await exp_obj.get_dir() transcription_dir = storage.join(experiment_dir, "transcriptions") - storage.makedirs(transcription_dir, exist_ok=True) + await storage.makedirs(transcription_dir, exist_ok=True) # List all .json files in the transcription directory transcription_files_metadata = [] try: - entries = storage.ls(transcription_dir, detail=True) + entries = await storage.ls(transcription_dir, detail=True) for entry in entries: # Handle both dict (detail=True) and string (detail=False) formats if isinstance(entry, dict): @@ -220,9 +223,10 @@ async def list_transcription(experimentId: str): # Use the full path from entry if available, otherwise construct it if not file_path or file_path == filename: file_path = storage.join(transcription_dir, filename) - with storage.open(file_path, "r") as f: + async with await storage.open(file_path, "r") as f: try: - data = json.load(f) + content = await f.read() + data = json.loads(content) # Add the file modification time for sorting data["id"] = filename[:-5] # Remove .json from the filename # Extract mtime from file metadata, fallback to 0 if not available @@ -236,9 +240,10 @@ async def list_transcription(experimentId: str): filename = entry.rstrip("/").split("/")[-1] if "/" in entry else entry if filename.endswith(".json"): file_path = storage.join(transcription_dir, filename) - with storage.open(file_path, "r") as f: + async with await storage.open(file_path, "r") as f: try: - data = json.load(f) + content = await f.read() + data = json.loads(content) data["id"] = filename[:-5] # Remove .json from the filename # fsspec doesn't always provide mtime, use 0 as fallback data["file_date"] = 0 @@ -253,24 +258,24 @@ async def list_transcription(experimentId: str): @router.get("/download_transcription") async def download_transcription(experimentId: str, filename: str): - exp_obj = Experiment.get(experimentId) - experiment_dir = exp_obj.get_dir() + exp_obj = await Experiment.get(experimentId) + experiment_dir = await 
exp_obj.get_dir() text_dir = storage.join(experiment_dir, "transcriptions") filename = secure_filename(filename) file_path = storage.join(text_dir, filename) - if not storage.exists(file_path): + if not await storage.exists(file_path): return {"message": f"Text file {filename} does not exist in experiment {experimentId}"} return FileResponse(path=file_path, filename=filename, media_type="text/plain") @router.delete("/delete_transcription") async def delete_transcription(experimentId: str, id: str): - exp_obj = Experiment.get(experimentId) - experiment_dir = exp_obj.get_dir() + exp_obj = await Experiment.get(experimentId) + experiment_dir = await exp_obj.get_dir() text_dir = storage.join(experiment_dir, "transcriptions") id = secure_filename(id) text_path = storage.join(text_dir, id + ".json") - if not storage.exists(text_path): + if not await storage.exists(text_path): return {"message": f"Text file {id} does not exist in experiment {experimentId}"} - storage.rm(text_path) + await storage.rm(text_path) return {"message": f"Text file {id} deleted from experiment {experimentId}"} diff --git a/api/transformerlab/routers/experiment/diffusion.py b/api/transformerlab/routers/experiment/diffusion.py index 59d94c693..47910a63e 100644 --- a/api/transformerlab/routers/experiment/diffusion.py +++ b/api/transformerlab/routers/experiment/diffusion.py @@ -153,11 +153,23 @@ def _setup_diffusion_logger(): # File handler try: + import asyncio from lab.dirs import get_global_log_path - file_handler = logging.FileHandler(get_global_log_path(), encoding="utf-8") - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) + # Check if there's already an event loop running + # If so, we can't use asyncio.run() and must skip file handler setup + # to avoid the "coroutine was never awaited" warning + try: + asyncio.get_running_loop() + # There's a running loop, skip file handler setup + # (can't use asyncio.run() when loop is already running) + except RuntimeError: + # No running event loop, safe to use asyncio.run() + # Create and immediately await the coroutine + log_path = asyncio.run(get_global_log_path()) + file_handler = logging.FileHandler(log_path, encoding="utf-8") + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) except Exception: pass # Continue without file logging if there's an issue @@ -278,9 +290,9 @@ class CreateDatasetRequest(BaseModel): HISTORY_FILE = "history.json" -def get_diffusion_dir(experiment_name: str = None, workspace_dir: str | None = None): +async def get_diffusion_dir(experiment_name: str = None, workspace_dir: str | None = None): """Get the diffusion directory path""" - base = workspace_dir or get_workspace_dir() + base = workspace_dir or await get_workspace_dir() if experiment_name is not None: # New experiment-specific path @@ -290,32 +302,32 @@ def get_diffusion_dir(experiment_name: str = None, workspace_dir: str | None = N return storage.join(base, "diffusion") -def get_images_dir(experiment_name: str = None, workspace_dir: str | None = None): +async def get_images_dir(experiment_name: str = None, workspace_dir: str | None = None): """Get the images directory path""" - return storage.join(get_diffusion_dir(experiment_name, workspace_dir), "images") + return storage.join(await get_diffusion_dir(experiment_name, workspace_dir), "images") -def get_history_file_path(experiment_name: str = None, workspace_dir: str | None = None): +async def get_history_file_path(experiment_name: str = None, workspace_dir: str | None = None): """Get the history file 
path""" # Create a history file in the diffusion directory if it doesn't exist - return storage.join(get_diffusion_dir(experiment_name, workspace_dir), HISTORY_FILE) + return storage.join(await get_diffusion_dir(experiment_name, workspace_dir), HISTORY_FILE) -def ensure_directories(experiment_name: str = None, workspace_dir: str | None = None): +async def ensure_directories(experiment_name: str = None, workspace_dir: str | None = None): """Ensure diffusion and images directories exist""" - diffusion_dir = get_diffusion_dir(experiment_name, workspace_dir) - images_dir = get_images_dir(experiment_name, workspace_dir) - history_file_path = get_history_file_path(experiment_name, workspace_dir) - - storage.makedirs(diffusion_dir, exist_ok=True) - storage.makedirs(images_dir, exist_ok=True) - if not storage.exists(history_file_path): - with storage.open(history_file_path, "a"): + diffusion_dir = await get_diffusion_dir(experiment_name, workspace_dir) + images_dir = await get_images_dir(experiment_name, workspace_dir) + history_file_path = await get_history_file_path(experiment_name, workspace_dir) + + await storage.makedirs(diffusion_dir, exist_ok=True) + await storage.makedirs(images_dir, exist_ok=True) + if not await storage.exists(history_file_path): + async with await storage.open(history_file_path, "a"): # Create the history file if it doesn't exist pass -def load_history( +async def load_history( limit: int = 50, offset: int = 0, experiment_name: str = None, workspace_dir: str | None = None ) -> HistoryResponse: """Load image generation history from both new and old paths for backward compatibility""" @@ -324,20 +336,20 @@ def load_history( # Load from new experiment-specific path if experiment info is provided if experiment_name is not None: try: - new_history_file = get_history_file_path(experiment_name, workspace_dir) - if storage.exists(new_history_file): - with storage.open(new_history_file, "r") as f: - new_history_data = json.load(f) + new_history_file = await get_history_file_path(experiment_name, workspace_dir) + if await storage.exists(new_history_file): + async with await storage.open(new_history_file, "r") as f: + new_history_data = json.loads(await f.read()) all_images.extend(new_history_data) except (json.JSONDecodeError, FileNotFoundError): pass # Load from legacy global path for backward compatibility try: - legacy_history_file = get_history_file_path(None, workspace_dir) # No experiment → legacy path - if storage.exists(legacy_history_file): - with storage.open(legacy_history_file, "r") as f: - legacy_history_data = json.load(f) + legacy_history_file = await get_history_file_path(None, workspace_dir) # No experiment → legacy path + if await storage.exists(legacy_history_file): + async with await storage.open(legacy_history_file, "r") as f: + legacy_history_data = json.loads(await f.read()) all_images.extend(legacy_history_data) except (json.JSONDecodeError, FileNotFoundError): pass @@ -363,7 +375,7 @@ def load_history( return HistoryResponse(images=items, total=total_items) -def find_image_by_id( +async def find_image_by_id( image_id: str, experiment_name: str = None, workspace_dir: str | None = None ) -> ImageHistoryItem | None: """Find a specific image by ID without loading all history, searching both new and old paths""" @@ -371,10 +383,10 @@ def find_image_by_id( # Search in new experiment-specific path first if experiment info is provided if experiment_name is not None: try: - new_history_file = get_history_file_path(experiment_name, workspace_dir) - if 
storage.exists(new_history_file): - with storage.open(new_history_file, "r") as f: - history = json.load(f) + new_history_file = await get_history_file_path(experiment_name, workspace_dir) + if await storage.exists(new_history_file): + async with await storage.open(new_history_file, "r") as f: + history = json.loads(await f.read()) for item in history: if item.get("id") == image_id: return ImageHistoryItem(**item) @@ -383,10 +395,10 @@ def find_image_by_id( # Search in legacy global path for backward compatibility try: - legacy_history_file = get_history_file_path(None, workspace_dir) # No experiment → legacy path - if storage.exists(legacy_history_file): - with storage.open(legacy_history_file, "r") as f: - history = json.load(f) + legacy_history_file = await get_history_file_path(None, workspace_dir) # No experiment → legacy path + if await storage.exists(legacy_history_file): + async with await storage.open(legacy_history_file, "r") as f: + history = json.loads(await f.read()) for item in history: if item.get("id") == image_id: return ImageHistoryItem(**item) @@ -425,11 +437,14 @@ async def generate_image(experimentId: str, request: DiffusionRequest, http_requ "config": request_dict, } - job_id = job_create(type="DIFFUSION", status="QUEUED", job_data=job_config, experiment_id=experimentId) + job_id = await job_create( + type="DIFFUSION", status="QUEUED", job_data=job_config, experiment_id=experimentId + ) # Get experiment name for experiment-specific paths - images_folder = storage.join(get_images_dir(experimentId), generation_id) - if not images_folder.startswith(get_images_dir(experimentId)): # Validate containment + images_dir = await get_images_dir(experimentId) + images_folder = storage.join(images_dir, generation_id) + if not images_folder.startswith(images_dir): # Validate containment raise HTTPException(status_code=400, detail="Invalid generation_id: Path traversal detected.") tmp_json_path = storage.join(images_folder, "tmp_json.json") @@ -468,8 +483,8 @@ async def is_valid_diffusion(experimentId: str, request: DiffusionRequest): try: from lab.model import Model as ModelService - model_service = ModelService.get(model_id) - model_data = model_service.get_metadata() + model_service = await ModelService.get(model_id) + model_data = await model_service.get_metadata() if model_data and model_data.get("json_data"): json_data = model_data["json_data"] @@ -550,8 +565,8 @@ async def get_history(experimentId: str, limit: int = 50, offset: int = 0, http_ raise HTTPException(status_code=400, detail="Offset must be non-negative") # Get experiment name for experiment-specific paths - workspace_dir = get_workspace_dir() - return load_history(limit=limit, offset=offset, experiment_name=experimentId, workspace_dir=workspace_dir) + workspace_dir = await get_workspace_dir() + return await load_history(limit=limit, offset=offset, experiment_name=experimentId, workspace_dir=workspace_dir) @router.get("/history/{image_id}", summary="Get the actual image by ID") @@ -575,11 +590,11 @@ async def get_image_by_id( mask_image: Whether to return the mask image instead of generated image """ # Get experiment name for experiment-specific paths - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() if step: # If step is requested, we need to check if intermediate images were saved - images_dir = get_images_dir(experimentId, workspace_dir) + images_dir = await get_images_dir(experimentId, workspace_dir) image_dir_based_on_id = storage.join(images_dir, image_id) # Ensure the 
constructed path is within the intended base directory @@ -587,7 +602,7 @@ async def get_image_by_id( raise HTTPException(status_code=400, detail="Invalid image ID or path traversal attempt detected") # Check if the image path is a directory (new format) - if not storage.isdir(image_dir_based_on_id): + if not await storage.isdir(image_dir_based_on_id): raise HTTPException(status_code=404, detail=f"Image path is not a directory for image ID {image_id}") # Construct the path for the step image @@ -595,13 +610,13 @@ async def get_image_by_id( if not step_image_path.startswith(images_dir): raise HTTPException(status_code=400, detail="Invalid path traversal attempt detected") - if not storage.exists(step_image_path): + if not await storage.exists(step_image_path): raise HTTPException(status_code=404, detail=f"Step image file not found at {step_image_path}") return FileResponse(step_image_path) # Use the efficient function to find the specific image - image_item = find_image_by_id(image_id, experimentId, workspace_dir) + image_item = await find_image_by_id(image_id, experimentId, workspace_dir) if not image_item: raise HTTPException(status_code=404, detail=f"Image with ID {image_id} not found") @@ -612,7 +627,7 @@ async def get_image_by_id( if not image_item.mask_image_path or not image_item.mask_image_path.strip(): raise HTTPException(status_code=404, detail=f"No mask image found for image ID {image_id}") image_path = image_item.mask_image_path - if not storage.exists(image_path): + if not await storage.exists(image_path): raise HTTPException(status_code=404, detail=f"Mask image file not found at {image_path}") elif input_image: # Return the input image if requested and available @@ -620,18 +635,18 @@ async def get_image_by_id( raise HTTPException(status_code=404, detail=f"No input image found for image ID {image_id}") image_path = image_item.input_image_path - if not storage.exists(image_path): + if not await storage.exists(image_path): raise HTTPException(status_code=404, detail=f"Input image file not found at {image_path}") elif preprocessed: if not image_item.processed_image or not image_item.processed_image.strip(): raise HTTPException(status_code=404, detail=f"No preprocessed image found for image ID {image_id}") image_path = image_item.processed_image - if not storage.exists(image_path): + if not await storage.exists(image_path): raise HTTPException(status_code=404, detail=f"Preprocessed image file not found at {image_path}") else: # Return the generated output image (default behavior) # Check if image_path is a folder (new format) or a file (old format) - if storage.isdir(image_item.image_path): + if await storage.isdir(image_item.image_path): # New format: folder with numbered images if index < 0 or index >= (image_item.num_images if hasattr(image_item, "num_images") else 1): raise HTTPException( @@ -648,7 +663,7 @@ async def get_image_by_id( raise HTTPException(status_code=404, detail="Only index 0 available for this image set") image_path = image_item.image_path - if not storage.exists(image_path): + if not await storage.exists(image_path): raise HTTPException(status_code=404, detail=f"Image file not found at {image_path}") return FileResponse(image_path) @@ -666,21 +681,21 @@ async def get_image_info_by_id(image_id: str, experimentId: str, http_request: R Image metadata including number of images available """ # Get experiment name for experiment-specific paths - workspace_dir = get_workspace_dir() - image_item = find_image_by_id(image_id, experimentId, workspace_dir) + 
workspace_dir = await get_workspace_dir() + image_item = await find_image_by_id(image_id, experimentId, workspace_dir) if not image_item: raise HTTPException(status_code=404, detail=f"Image with ID {image_id} not found") # Check if image folder/file exists - if not storage.exists(image_item.image_path): + if not await storage.exists(image_item.image_path): raise HTTPException(status_code=404, detail=f"Image path not found at {image_item.image_path}") # Determine number of images available num_images = 1 # Default for old format - if storage.isdir(image_item.image_path): + if await storage.isdir(image_item.image_path): # Count PNG files in the directory - entries = storage.ls(image_item.image_path, detail=False) + entries = await storage.ls(image_item.image_path, detail=False) png_files = [f for f in entries if f.endswith(".png") and f.replace(".png", "").isdigit()] num_images = len(png_files) @@ -703,21 +718,21 @@ async def get_image_count(image_id: str, experimentId: str, http_request: Reques Number of images available """ # Get experiment name for experiment-specific paths - workspace_dir = get_workspace_dir() - image_item = find_image_by_id(image_id, experimentId, workspace_dir) + workspace_dir = await get_workspace_dir() + image_item = await find_image_by_id(image_id, experimentId, workspace_dir) if not image_item: raise HTTPException(status_code=404, detail=f"Image with ID {image_id} not found") # Check if image folder/file exists - if not storage.exists(image_item.image_path): + if not await storage.exists(image_item.image_path): raise HTTPException(status_code=404, detail=f"Image path not found at {image_item.image_path}") # Determine number of images available num_images = 1 # Default for old format - if storage.isdir(image_item.image_path): + if await storage.isdir(image_item.image_path): # Count PNG files in the directory - entries = storage.ls(image_item.image_path, detail=False) + entries = await storage.ls(image_item.image_path, detail=False) png_files = [f for f in entries if f.endswith(".png") and f.replace(".png", "").isdigit()] num_images = len(png_files) @@ -739,14 +754,14 @@ async def get_all_images(image_id: str, experimentId: str, http_request: Request import tempfile # Get experiment name for experiment-specific paths - workspace_dir = get_workspace_dir() - image_item = find_image_by_id(image_id, experimentId, workspace_dir) + workspace_dir = await get_workspace_dir() + image_item = await find_image_by_id(image_id, experimentId, workspace_dir) if not image_item: raise HTTPException(status_code=404, detail=f"Image with ID {image_id} not found") # Check if image folder/file exists - if not storage.exists(image_item.image_path): + if not await storage.exists(image_item.image_path): raise HTTPException(status_code=404, detail=f"Image path not found at {image_item.image_path}") # Create a temporary zip file @@ -755,16 +770,16 @@ async def get_all_images(image_id: str, experimentId: str, http_request: Request try: with zipfile.ZipFile(temp_zip.name, "w", zipfile.ZIP_DEFLATED) as zipf: - if storage.isdir(image_item.image_path): + if await storage.isdir(image_item.image_path): # New format: add all PNG files from the directory - entries = storage.ls(image_item.image_path, detail=False) + entries = await storage.ls(image_item.image_path, detail=False) for entry in entries: filename = entry.rstrip("/").split("/")[-1] if filename.endswith(".png") and filename.replace(".png", "").isdigit(): file_path = entry # Use the full path from storage.ls # Download to temp if remote, then 
add to zip - with storage.open(file_path, "rb") as remote_file: - content = remote_file.read() + async with await storage.open(file_path, "rb") as remote_file: + content = await remote_file.read() zipf.writestr(filename, content) else: # Old format: add the single file @@ -772,8 +787,8 @@ async def get_all_images(image_id: str, experimentId: str, http_request: Request # Extract just the filename for the zip filename = file_path.rstrip("/").split("/")[-1] # Download to temp if remote, then add to zip - with storage.open(file_path, "rb") as remote_file: - content = remote_file.read() + async with await storage.open(file_path, "rb") as remote_file: + content = await remote_file.read() zipf.writestr(filename, content) return FileResponse( @@ -799,16 +814,16 @@ async def delete_image_from_history(experimentId: str, image_id: str, http_reque image_id: The unique ID of the image set to delete """ # Get experiment name for experiment-specific paths - workspace_dir = get_workspace_dir() - history_file = get_history_file_path(experimentId, workspace_dir) + workspace_dir = await get_workspace_dir() + history_file = await get_history_file_path(experimentId, workspace_dir) - if not storage.exists(history_file): + if not await storage.exists(history_file): raise HTTPException(status_code=404, detail="No history found") try: # Load current history - with storage.open(history_file, "r") as f: - history = json.load(f) + async with await storage.open(history_file, "r") as f: + history = json.loads(await f.read()) # Find and remove the item item_to_remove = None @@ -824,24 +839,24 @@ async def delete_image_from_history(experimentId: str, image_id: str, http_reque # Remove image files/folder image_path = item_to_remove["image_path"] - if storage.exists(image_path): - if storage.isdir(image_path): + if await storage.exists(image_path): + if await storage.isdir(image_path): # New format: remove entire folder - storage.rm_tree(image_path) + await storage.rm_tree(image_path) else: # Old format: remove single file - storage.rm(image_path) + await storage.rm(image_path) # Remove input image if it exists - if item_to_remove.get("input_image_path") and storage.exists(item_to_remove["input_image_path"]): - storage.rm(item_to_remove["input_image_path"]) + if item_to_remove.get("input_image_path") and await storage.exists(item_to_remove["input_image_path"]): + await storage.rm(item_to_remove["input_image_path"]) # Remove processed image if it exists - if item_to_remove.get("processed_image") and storage.exists(item_to_remove["processed_image"]): - storage.rm(item_to_remove["processed_image"]) + if item_to_remove.get("processed_image") and await storage.exists(item_to_remove["processed_image"]): + await storage.rm(item_to_remove["processed_image"]) # Save updated history - with storage.open(history_file, "w") as f: - json.dump(updated_history, f, indent=2) + async with await storage.open(history_file, "w") as f: + await f.write(json.dumps(updated_history, indent=2)) return JSONResponse( content={"message": f"Image set {image_id} deleted successfully", "deleted_item": item_to_remove} @@ -857,51 +872,51 @@ async def clear_history(experimentId: str, http_request: Request = None): """ try: # Get experiment name for experiment-specific paths - workspace_dir = get_workspace_dir() - history_file = get_history_file_path(experimentId, workspace_dir) - images_dir = get_images_dir(experimentId, workspace_dir) + workspace_dir = await get_workspace_dir() + history_file = await get_history_file_path(experimentId, workspace_dir) + 
images_dir = await get_images_dir(experimentId, workspace_dir) # Load current history to get image paths deleted_count = 0 - if storage.exists(history_file): - with storage.open(history_file, "r") as f: - history = json.load(f) + if await storage.exists(history_file): + async with await storage.open(history_file, "r") as f: + history = json.loads(await f.read()) # Remove all image files/folders for item in history: image_path = item["image_path"] - if storage.exists(image_path): - if storage.isdir(image_path): + if await storage.exists(image_path): + if await storage.isdir(image_path): # New format: remove folder and count files inside - entries = storage.ls(image_path, detail=False) + entries = await storage.ls(image_path, detail=False) file_count = len([f for f in entries if f.endswith(".png")]) - storage.rm_tree(image_path) + await storage.rm_tree(image_path) deleted_count += file_count else: # Old format: remove single file - storage.rm(image_path) + await storage.rm(image_path) deleted_count += 1 # Remove input image if it exists - if item.get("input_image_path") and storage.exists(item["input_image_path"]): - storage.rm(item["input_image_path"]) + if item.get("input_image_path") and await storage.exists(item["input_image_path"]): + await storage.rm(item["input_image_path"]) # Remove processed image if it exists - if item.get("processed_image") and storage.exists(item["processed_image"]): - storage.rm(item["processed_image"]) + if item.get("processed_image") and await storage.exists(item["processed_image"]): + await storage.rm(item["processed_image"]) # Clear history file - with storage.open(history_file, "w") as f: - json.dump([], f) + async with await storage.open(history_file, "w") as f: + await f.write(json.dumps([])) # Remove any remaining files/folders in images directory - if storage.exists(images_dir): - for entry in storage.ls(images_dir, detail=False): + if await storage.exists(images_dir): + for entry in await storage.ls(images_dir, detail=False): item_name = entry.rstrip("/").split("/")[-1] item_path = storage.join(images_dir, item_name) - if storage.isdir(item_path): - storage.rm_tree(item_path) + if await storage.isdir(item_path): + await storage.rm_tree(item_path) elif item_name.endswith(".png"): - storage.rm(item_path) + await storage.rm(item_path) return JSONResponse( content={ @@ -935,7 +950,7 @@ async def create_dataset_from_history(request: CreateDatasetRequest, experimentI # Check if dataset already exists try: - Dataset.get(dataset_id) + await Dataset.get(dataset_id) # If we get here, the dataset exists raise HTTPException(status_code=400, detail=f"Dataset '{dataset_id}' already exists") except FileNotFoundError: @@ -945,7 +960,7 @@ async def create_dataset_from_history(request: CreateDatasetRequest, experimentI # Find selected images efficiently selected_images = [] for image_id in image_ids: - image_item = find_image_by_id(image_id, experimentId) + image_item = await find_image_by_id(image_id, experimentId) if image_item: selected_images.append(image_item) @@ -955,9 +970,9 @@ async def create_dataset_from_history(request: CreateDatasetRequest, experimentI # Calculate total image count (accounting for multi-image generations) total_image_count = 0 for image_item in selected_images: - if storage.isdir(image_item.image_path): + if await storage.isdir(image_item.image_path): # Count images in folder - entries = storage.ls(image_item.image_path, detail=False) + entries = await storage.ls(image_item.image_path, detail=False) image_files = [f for f in entries if 
f.endswith(".png") and f.replace(".png", "").isdigit()] total_image_count += len(image_files) else: @@ -973,14 +988,14 @@ async def create_dataset_from_history(request: CreateDatasetRequest, experimentI "image_count": total_image_count, "created_from_image_ids": image_ids, } - new_dataset = create_local_dataset(dataset_id, json_data=json_data) + new_dataset = await create_local_dataset(dataset_id, json_data=json_data) except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to create dataset: {str(e)}") # Create dataset directory - dataset_dir = new_dataset.get_dir() + dataset_dir = await new_dataset.get_dir() images_dir = storage.join(dataset_dir, "train") - storage.makedirs(images_dir, exist_ok=True) + await storage.makedirs(images_dir, exist_ok=True) # Prepare dataset metadata and copy images dataset_records = [] @@ -989,10 +1004,10 @@ async def create_dataset_from_history(request: CreateDatasetRequest, experimentI for image_item in selected_images: try: # Check if this is a multi-image generation (folder) or single image - if storage.isdir(image_item.image_path): + if await storage.isdir(image_item.image_path): # Multi-image generation - process each image in the folder image_files = [] - for entry in storage.ls(image_item.image_path, detail=False): + for entry in await storage.ls(image_item.image_path, detail=False): filename = entry.rstrip("/").split("/")[-1] if filename.endswith(".png") and filename.replace(".png", "").isdigit(): image_files.append(filename) @@ -1008,8 +1023,8 @@ async def create_dataset_from_history(request: CreateDatasetRequest, experimentI dest_image_path = storage.join(images_dir, dataset_filename) # Copy image file - if storage.exists(src_image_path): - storage.copy_file(src_image_path, dest_image_path) + if await storage.exists(src_image_path): + await storage.copy_file(src_image_path, dest_image_path) else: log_print(f"Warning: Image file not found at {src_image_path}") continue @@ -1055,8 +1070,8 @@ async def create_dataset_from_history(request: CreateDatasetRequest, experimentI dest_image_path = storage.join(images_dir, dataset_filename) # Copy image file - if storage.exists(image_item.image_path): - storage.copy_file(image_item.image_path, dest_image_path) + if await storage.exists(image_item.image_path): + await storage.copy_file(image_item.image_path, dest_image_path) else: log_print(f"Warning: Image file not found at {image_item.image_path}") continue @@ -1099,20 +1114,20 @@ async def create_dataset_from_history(request: CreateDatasetRequest, experimentI if not dataset_records: # Clean up if no images were successfully processed - new_dataset.delete() + await new_dataset.delete() raise HTTPException(status_code=500, detail="Failed to process any images") # Save dataset as JSONL file try: # Make train directory if it doesn't exist - storage.makedirs(images_dir, exist_ok=True) + await storage.makedirs(images_dir, exist_ok=True) dataset_file = storage.join(dataset_dir, "train", "metadata.jsonl") - with storage.open(dataset_file, "w") as f: + async with await storage.open(dataset_file, "w") as f: for record in dataset_records: - f.write(json.dumps(record) + "\n") + await f.write(json.dumps(record) + "\n") except Exception as e: # Clean up on failure - new_dataset.delete() + await new_dataset.delete() raise HTTPException(status_code=500, detail=f"Failed to save dataset: {str(e)}") return JSONResponse( @@ -1160,10 +1175,10 @@ async def get_new_generation_id(experimentId: str, http_request: Request = None) Returns a new unique generation ID 
and creates the images folder for it. """ generation_id = str(uuid.uuid4()) - workspace_dir = get_workspace_dir() - ensure_directories(experimentId, workspace_dir) - images_folder = storage.join(get_images_dir(experimentId, workspace_dir), generation_id) - storage.makedirs(images_folder, exist_ok=True) + workspace_dir = await get_workspace_dir() + await ensure_directories(experimentId, workspace_dir) + images_folder = storage.join(await get_images_dir(experimentId, workspace_dir), generation_id) + await storage.makedirs(images_folder, exist_ok=True) return {"generation_id": generation_id, "images_folder": images_folder} @@ -1175,18 +1190,18 @@ async def get_file(experimentId: str, generation_id: str): uuid.UUID(sanitized_id) # Validate UUID format except ValueError: raise HTTPException(status_code=400, detail="Invalid generation ID format") - root_dir = get_images_dir(experimentId) + root_dir = await get_images_dir(experimentId) file_path = storage.join(root_dir, sanitized_id, "tmp_json.json") - if not storage.exists(file_path): + if not await storage.exists(file_path): raise HTTPException(status_code=404, detail=f"Output JSON file not found at {file_path}") try: if not file_path.startswith(root_dir): raise HTTPException(status_code=400, detail="Invalid file path") - if not storage.isfile(file_path): + if not await storage.isfile(file_path): raise HTTPException(status_code=404, detail="File not found") - with storage.open(file_path, "r") as f: - data = json.load(f) + async with await storage.open(file_path, "r") as f: + data = json.loads(await f.read()) return JSONResponse(content=data) diff --git a/api/transformerlab/routers/experiment/documents.py b/api/transformerlab/routers/experiment/documents.py index 2d7363778..f1c2f75a0 100644 --- a/api/transformerlab/routers/experiment/documents.py +++ b/api/transformerlab/routers/experiment/documents.py @@ -47,7 +47,7 @@ def is_valid_url(url: str) -> bool: async def document_view(experimentId: str, document_name: str, folder: str = None): try: exp_obj = Experiment(experimentId) - experiment_dir = exp_obj.get_dir() + experiment_dir = await exp_obj.get_dir() document_name = secure_filename(document_name) folder = secure_filename(folder) @@ -69,17 +69,17 @@ async def document_list(experimentId: str, folder: str = None): documents = [] # List the files that are in the experiment//documents directory: exp_obj = Experiment(experimentId) - experiment_dir = exp_obj.get_dir() + experiment_dir = await exp_obj.get_dir() documents_dir = storage.join(experiment_dir, "documents") folder = secure_filename(folder) if folder and folder != "": - if storage.exists(storage.join(documents_dir, folder)): + if await storage.exists(storage.join(documents_dir, folder)): documents_dir = storage.join(documents_dir, folder) else: return {"status": "error", "message": f'Folder "{folder}" not found'} - if storage.exists(documents_dir): + if await storage.exists(documents_dir): try: - entries = storage.ls(documents_dir, detail=True) + entries = await storage.ls(documents_dir, detail=True) except Exception: entries = [] for entry in entries: @@ -93,7 +93,7 @@ async def document_list(experimentId: str, folder: str = None): else: full_path = entry name = os.path.basename(full_path.rstrip("/")) - is_dir = storage.isdir(full_path) + is_dir = await storage.isdir(full_path) size = 0 if is_dir else 0 mtime = None if is_dir: @@ -117,20 +117,20 @@ async def document_new(experimentId: str, dataset_id: str): @router.get("/delete", summary="Delete a document.") async def 
delete_document(experimentId: str, document_name: str, folder: str = None): exp_obj = Experiment(experimentId) - experiment_dir = exp_obj.get_dir() + experiment_dir = await exp_obj.get_dir() document_name = secure_filename(document_name) path = storage.join(experiment_dir, "documents", document_name) - if folder and folder != "" and not storage.isdir(path): + if folder and folder != "" and not await storage.isdir(path): folder = secure_filename(folder) path = storage.join(experiment_dir, "documents", folder, document_name) else: path = storage.join(experiment_dir, "documents", document_name) # first check if it is a directory: - if storage.isdir(path): - storage.rm_tree(path) - elif storage.exists(path): - storage.rm(path) + if await storage.isdir(path): + await storage.rm_tree(path) + elif await storage.exists(path): + await storage.rm(path) return {"status": "success"} @@ -167,27 +167,28 @@ async def document_upload(experimentId: str, folder: str, files: list[UploadFile # raise HTTPException(status_code=403, detail="The file must be a text file, a JSONL file, or a PDF") exp_obj = Experiment(experimentId) - experiment_dir = exp_obj.get_dir() + experiment_dir = await exp_obj.get_dir() documents_dir = storage.join(experiment_dir, "documents") if folder and folder != "": - if storage.exists(storage.join(documents_dir, folder)): - documents_dir = storage.join(documents_dir, folder) + folder_path = storage.join(documents_dir, folder) + if await storage.exists(folder_path): + documents_dir = folder_path else: - print(f"Creating directory as it doesn't exist: {storage.join(documents_dir, folder)}") - storage.makedirs(storage.join(documents_dir, folder), exist_ok=True) - documents_dir = storage.join(documents_dir, folder) + print(f"Creating directory as it doesn't exist: {folder_path}") + await storage.makedirs(folder_path, exist_ok=True) + documents_dir = folder_path markitdown_dir = storage.join(documents_dir, ".tlab_markitdown") - if not storage.exists(markitdown_dir): - storage.makedirs(markitdown_dir, exist_ok=True) + if not await storage.exists(markitdown_dir): + await storage.makedirs(markitdown_dir, exist_ok=True) if not restricted_file_type: # Save the file to the dataset directory try: content = await file.read() - if not storage.exists(documents_dir): + if not await storage.exists(documents_dir): print("Creating directory") - storage.makedirs(documents_dir, exist_ok=True) + await storage.makedirs(documents_dir, exist_ok=True) newfilename = storage.join(documents_dir, str(file_name)) async with aiofiles.open(newfilename, "wb") as out_file: @@ -252,11 +253,11 @@ async def create_folder(experimentId: str, name: str): # Secure folder name name = secure_filename(name) exp_obj = Experiment(experimentId) - experiment_dir = exp_obj.get_dir() + experiment_dir = await exp_obj.get_dir() path = storage.join(experiment_dir, "documents", name) print(f"Creating folder {path}") - if not storage.exists(path): - storage.makedirs(path, exist_ok=True) + if not await storage.exists(path): + await storage.makedirs(path, exist_ok=True) return {"status": "success"} @@ -265,18 +266,19 @@ async def document_upload_links(experimentId: str, folder: str = None, data: dic urls = data.get("urls") folder = secure_filename(folder) exp_obj = Experiment(experimentId) - experiment_dir = exp_obj.get_dir() + experiment_dir = await exp_obj.get_dir() documents_dir = storage.join(experiment_dir, "documents") if folder and folder != "": - if storage.exists(storage.join(documents_dir, folder)): - documents_dir = 
storage.join(documents_dir, folder) + folder_path = storage.join(documents_dir, folder) + if await storage.exists(folder_path): + documents_dir = folder_path else: return {"status": "error", "message": f'Folder "{folder}" not found'} markitdown_dir = storage.join(documents_dir, ".tlab_markitdown") - if not storage.exists(markitdown_dir): - storage.makedirs(markitdown_dir, exist_ok=True) + if not await storage.exists(markitdown_dir): + await storage.makedirs(markitdown_dir, exist_ok=True) # Find the next available number for link_X.md files existing_numbers = set() @@ -331,7 +333,7 @@ async def document_download_zip(experimentId: str, data: dict = Body(...)): raise HTTPException(status_code=400, detail="Invalid or unauthorized URL") exp_obj = Experiment(experimentId) - experiment_dir = exp_obj.get_dir() + experiment_dir = await exp_obj.get_dir() documents_dir = storage.join(experiment_dir, "documents") try: diff --git a/api/transformerlab/routers/experiment/evals.py b/api/transformerlab/routers/experiment/evals.py index 9e19b8834..1a7ee60bf 100644 --- a/api/transformerlab/routers/experiment/evals.py +++ b/api/transformerlab/routers/experiment/evals.py @@ -26,7 +26,7 @@ async def experiment_add_evaluation(experimentId: str, plugin: Any = Body()): directory, we can modify the plugin code for the specific experiment without affecting other experiments that use the same plugin.""" - experiment = experiment_get(experimentId) + experiment = await experiment_get(experimentId) if experiment is None: return {"message": f"Experiment {experimentId} does not exist"} @@ -53,7 +53,7 @@ async def experiment_add_evaluation(experimentId: str, plugin: Any = Body()): evaluations.append(evaluation) - experiment_update_config(experimentId, "evaluations", evaluations) + await experiment_update_config(experimentId, "evaluations", evaluations) return {"message": f"Experiment {experimentId} updated with plugin {plugin_name}"} @@ -62,7 +62,7 @@ async def experiment_add_evaluation(experimentId: str, plugin: Any = Body()): async def experiment_delete_eval(experimentId: str, eval_name: str): """Delete an evaluation from an experiment. 
This will delete the directory in the experiment and remove the global plugin from the specific experiment.""" - experiment = experiment_get(experimentId) + experiment = await experiment_get(experimentId) if experiment is None: return {"message": f"Experiment {experimentId} does not exist"} @@ -77,7 +77,7 @@ async def experiment_delete_eval(experimentId: str, eval_name: str): # remove the evaluation from the list: evaluations = [e for e in evaluations if e["name"] != eval_name] - experiment_update_config(experimentId, "evaluations", evaluations) + await experiment_update_config(experimentId, "evaluations", evaluations) return {"message": f"Evaluation {eval_name} deleted from experiment {experimentId}"} @@ -89,7 +89,7 @@ async def experiment_delete_eval(experimentId: str, eval_name: str): async def edit_evaluation_task(experimentId: str, plugin: Any = Body()): """Get the contents of the evaluation""" try: - experiment = experiment_get(experimentId) + experiment = await experiment_get(experimentId) # if the experiment does not exist, return an error: if experiment is None: @@ -133,7 +133,7 @@ async def edit_evaluation_task(experimentId: str, plugin: Any = Body()): @router.get("/get_evaluation_plugin_file_contents") async def get_evaluation_plugin_file_contents(experimentId: str, plugin_name: str): # first get the experiment name: - data = experiment_get(experimentId) + data = await experiment_get(experimentId) # if the experiment does not exist, return an error: if data is None: @@ -144,7 +144,7 @@ async def get_evaluation_plugin_file_contents(experimentId: str, plugin_name: st # print(f"{EXPERIMENTS_DIR}/{experiment_name}/evals/{eval_name}/main.py") file_name = "main.py" - plugin_path = lab_dirs.plugin_dir_by_name(plugin_name) + plugin_path = await lab_dirs.plugin_dir_by_name(plugin_name) # now get the file contents try: @@ -165,7 +165,7 @@ async def run_evaluation_script( org_id: Optional[str] = None, user_id: Optional[str] = None, ): - job_config_raw = (job_get(job_id))["job_data"] + job_config_raw = (await job_get(job_id))["job_data"] # Ensure job_config is a dict if isinstance(job_config_raw, str): try: @@ -176,7 +176,7 @@ async def run_evaluation_script( job_config = job_config_raw eval_config = job_config.get("config", {}) print(eval_config) - experiment_details = experiment_get(id=experimentId) + experiment_details = await experiment_get(id=experimentId) if experiment_details is None: return {"message": f"Experiment {experimentId} does not exist"} @@ -208,7 +208,8 @@ async def run_evaluation_script( from lab.dirs import get_temp_dir plugin_json_file = "plugin_input_" + secure_filename(str(plugin_name)) + ".json" - input_file = storage.join(get_temp_dir(), plugin_json_file) + temp_dir = await get_temp_dir() + input_file = storage.join(temp_dir, plugin_json_file) # The following two ifs convert nested JSON strings to JSON objects -- this is a hack # and should be done in the API itself @@ -231,14 +232,14 @@ async def run_evaluation_script( job_output_file = await shared.get_job_output_file_name(job_id, plugin_name, experimentId) input_contents = {"experiment": experiment_details, "config": template_config} - with storage.open(input_file, "w") as outfile: - json.dump(input_contents, outfile, indent=4) + async with await storage.open(input_file, "w") as outfile: + await outfile.write(json.dumps(input_contents, indent=4)) # For now, even though we have the file above, we are also going to pass all params # as command line arguments to the script. 
# Create a list of all the parameters: - script_directory = lab_dirs.plugin_dir_by_name(plugin_name) + script_directory = await lab_dirs.plugin_dir_by_name(plugin_name) extra_args = ["--plugin_dir", script_directory] for key in template_config: extra_args.append("--" + key) @@ -298,7 +299,7 @@ async def run_evaluation_script( # Try to get org_id from workspace path from lab.dirs import get_workspace_dir - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() if "/orgs/" in workspace_dir: team_id = workspace_dir.split("/orgs/")[-1].split("/")[0] @@ -317,30 +318,30 @@ async def run_evaluation_script( process_env = os.environ.copy() process_env["_TFL_USER_ID"] = resolved_user_id - with storage.open(job_output_file, "w") as f: + async with await storage.open(job_output_file, "w") as f: process = await asyncio.create_subprocess_exec( *subprocess_command, stdout=f, stderr=subprocess.PIPE, env=process_env ) await process.communicate() - with storage.open(output_file, "w") as f: + async with await storage.open(output_file, "w") as f: # Copy all contents from job_output_file to output_file - with storage.open(job_output_file, "r") as job_output: - for line in job_output: - f.write(line) + async with await storage.open(job_output_file, "r") as job_output: + async for line in job_output: + await f.write(line) @router.get("/get_output") async def get_output(experimentId: str, eval_name: str): """Get the output of an evaluation""" eval_name = secure_filename(eval_name) # sanitize the input - data = experiment_get(experimentId) + data = await experiment_get(experimentId) # if the experiment does not exist, return an error: if data is None: return {"message": f"Experiment {experimentId} does not exist"} eval_output_file = await lab_dirs.eval_output_file(experimentId, eval_name) - if not storage.exists(eval_output_file): + if not await storage.exists(eval_output_file): return {"message": "Output file does not exist"} print(f"Returning output file: {eval_output_file}.") diff --git a/api/transformerlab/routers/experiment/experiment.py b/api/transformerlab/routers/experiment/experiment.py index 62eb22cc4..d0c4d3c9e 100644 --- a/api/transformerlab/routers/experiment/experiment.py +++ b/api/transformerlab/routers/experiment/experiment.py @@ -42,23 +42,23 @@ @router.get("/", summary="Get all Experiments", tags=["experiment"]) -def experiments_get_all(): +async def experiments_get_all(): """Get a list of all experiments""" - return experiment_service.experiment_get_all() + return await experiment_service.experiment_get_all() @router.get("/create", summary="Create Experiment", tags=["experiment"]) -def experiments_create(name: str): +async def experiments_create(name: str): # Apply secure filename validation to the experiment name secure_name = secure_filename(name) - newid = experiment_service.experiment_create(secure_name, {}) + newid = await experiment_service.experiment_create(secure_name, {}) return newid @router.get("/{id}", summary="Get Experiment by ID", tags=["experiment"]) -def experiment_get(id: str): - data = experiment_service.experiment_get(id) +async def experiment_get(id: str): + data = await experiment_service.experiment_get(id) if data is None: return {"status": "error", "message": f"Experiment {id} does not exist"} @@ -68,37 +68,37 @@ def experiment_get(id: str): @router.get("/{id}/delete", tags=["experiment"]) -def experiments_delete(id: str): - experiment_service.experiment_delete(id) +async def experiments_delete(id: str): + await 
experiment_service.experiment_delete(id) return {"message": f"Experiment {id} deleted"} @router.get("/{id}/update", tags=["experiment"]) -def experiments_update(id: str, name: str): - experiment_service.experiment_update(id, name) +async def experiments_update(id: str, name: str): + await experiment_service.experiment_update(id, name) return {"message": f"Experiment {id} updated to {name}"} @router.get("/{id}/update_config", tags=["experiment"]) -def experiments_update_config(id: str, key: str, value: str): - experiment_service.experiment_update_config(id, key, value) +async def experiments_update_config(id: str, key: str, value: str): + await experiment_service.experiment_update_config(id, key, value) return {"message": f"Experiment {id} updated"} @router.post("/{id}/update_configs", tags=["experiment"]) -def experiments_update_configs(id: str, updates: Annotated[dict, Body()]): - experiment_service.experiment_update_configs(id, updates) +async def experiments_update_configs(id: str, updates: Annotated[dict, Body()]): + await experiment_service.experiment_update_configs(id, updates) return {"message": f"Experiment {id} configs updated"} @router.post("/{id}/prompt", tags=["experiment"]) -def experiments_save_prompt_template(id: str, template: Annotated[str, Body()]): - experiment_service.experiment_save_prompt_template(id, template) +async def experiments_save_prompt_template(id: str, template: Annotated[str, Body()]): + await experiment_service.experiment_save_prompt_template(id, template) return {"message": f"Experiment {id} prompt template saved"} @router.post("/{id}/save_file_contents", tags=["experiment"]) -def experiment_save_file_contents(id: str, filename: str, file_contents: Annotated[str, Body()]): +async def experiment_save_file_contents(id: str, filename: str, file_contents: Annotated[str, Body()]): filename = secure_filename(filename) # remove file extension from file: @@ -110,8 +110,8 @@ def experiment_save_file_contents(id: str, filename: str, file_contents: Annotat # clean the file name: filename = shared.slugify(filename) - exp_obj = Experiment.get(id) - experiment_dir = exp_obj.get_dir() + exp_obj = await Experiment.get(id) + experiment_dir = await exp_obj.get_dir() # For remote paths, use storage.join which handles remote URIs properly file_path = storage.join(experiment_dir, f"{filename}{file_ext}") @@ -120,18 +120,18 @@ def experiment_save_file_contents(id: str, filename: str, file_contents: Annotat return {"message": "Invalid file path - path traversal detected"} # Save the file contents securely - with storage.open(file_path, "w", encoding="utf-8") as f: - f.write(file_contents) + async with await storage.open(file_path, "w", encoding="utf-8") as f: + await f.write(file_contents) return {"message": f"{file_path} file contents saved"} @router.get("/{id}/file_contents", tags=["experiment"]) -def experiment_get_file_contents(id: str, filename: str): +async def experiment_get_file_contents(id: str, filename: str): filename = secure_filename(filename) - exp_obj = Experiment.get(id) - experiment_dir = exp_obj.get_dir() + exp_obj = await Experiment.get(id) + experiment_dir = await exp_obj.get_dir() # remove file extension from file: [filename, file_ext] = os.path.splitext(filename) @@ -154,8 +154,8 @@ def experiment_get_file_contents(id: str, filename: str): # now get the file contents try: - with storage.open(final_path, "r") as f: - file_contents = f.read() + async with await storage.open(final_path, "r") as f: + file_contents = await f.read() except FileNotFoundError: 
return "" @@ -163,11 +163,11 @@ def experiment_get_file_contents(id: str, filename: str): @router.get("/{id}/export_to_recipe", summary="Export experiment to recipe format", tags=["experiment"]) -def export_experiment_to_recipe(id: str, request: Request): +async def export_experiment_to_recipe(id: str, request: Request): """Export an experiment to JSON format that matches the recipe gallery structure.""" # Get experiment data - data = experiment_service.experiment_get(id) + data = await experiment_service.experiment_get(id) if data is None: return {"status": "error", "message": f"Experiment {id} does not exist"} @@ -185,12 +185,12 @@ def export_experiment_to_recipe(id: str, request: Request): } # Get the notes content from readme.md if it exists - exp_obj = Experiment.get(id) - experiment_dir = exp_obj.get_dir() + exp_obj = await Experiment.get(id) + experiment_dir = await exp_obj.get_dir() notes_path = storage.join(experiment_dir, "readme.md") try: - with storage.open(notes_path, "r") as f: - export_data["notes"] = f.read() + async with await storage.open(notes_path, "r") as f: + export_data["notes"] = await f.read() except FileNotFoundError: # If no notes file exists, leave it as empty string pass @@ -198,16 +198,17 @@ def export_experiment_to_recipe(id: str, request: Request): # Track unique dependencies to avoid duplicates added_dependencies = set() - def add_dependency(dep_type: str, dep_name: str): + async def add_dependency(dep_type: str, dep_name: str): """Helper function to add a dependency if it's not already added""" dep_key = f"{dep_type}:{dep_name}" if dep_key not in added_dependencies and dep_name: # For datasets, check if it's generated and skip if it is if dep_type == "dataset": try: - dataset_info = Dataset.get(dep_name) + dataset_info = await Dataset.get(dep_name) if dataset_info: - json_data = dataset_info.get("json_data", "{}") + json_data = await dataset_info.get_metadata() + json_data = json_data.get("json_data", {}) if not isinstance(json_data, dict): json_data = json.loads(json_data) if json_data.get("generated", False): @@ -226,7 +227,7 @@ def add_dependency(dep_type: str, dep_name: str): for task_type in task_types: from transformerlab.services.tasks_service import tasks_service - tasks = tasks_service.tasks_get_by_type_in_experiment(task_type, id) + tasks = await tasks_service.tasks_get_by_type_in_experiment(task_type, id) for task in tasks: if not isinstance(task["config"], dict): task_config = json.loads(task["config"]) @@ -240,17 +241,17 @@ def add_dependency(dep_type: str, dep_name: str): else: model_name = task_config.get("model_name") if model_name: - add_dependency("model", model_name) + await add_dependency("model", model_name) # Add dataset dependency from task dataset_name = task_config.get("dataset_name") if dataset_name: - add_dependency("dataset", dataset_name) + await add_dependency("dataset", dataset_name) # Add plugin dependency plugin_name = task_config.get("plugin_name") if plugin_name: - add_dependency("plugin", plugin_name) + await add_dependency("plugin", plugin_name) # Add task to tasks list with its configuration export_data["tasks"].append( @@ -273,9 +274,9 @@ def add_dependency(dep_type: str, dep_name: str): export_data["workflows"] = [] # Write to file in the workspace directory (org-aware via request context) - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() output_file = storage.join(workspace_dir, f"{data['name']}_export.json") - with storage.open(output_file, "w") as f: - json.dump(export_data, f, 
indent=2) + async with await storage.open(output_file, "w") as f: + await f.write(json.dumps(export_data, indent=2)) return FileResponse(output_file, filename=output_file) diff --git a/api/transformerlab/routers/experiment/export.py b/api/transformerlab/routers/experiment/export.py index 63e30bc42..6b3992c05 100644 --- a/api/transformerlab/routers/experiment/export.py +++ b/api/transformerlab/routers/experiment/export.py @@ -39,7 +39,7 @@ async def run_exporter_script( """ # Load experiment details into config - experiment_details = experiment_get(id=id) + experiment_details = await experiment_get(id=id) if experiment_details is None: return {"message": f"Experiment {id} does not exist"} @@ -95,13 +95,14 @@ async def run_exporter_script( output_filename = "." # Figure out plugin and model output directories - script_directory = lab_dirs.plugin_dir_by_name(plugin_name) + script_directory = await lab_dirs.plugin_dir_by_name(plugin_name) output_model_id = secure_filename(output_model_id) from lab.dirs import get_models_dir - output_path = storage.join(get_models_dir(), output_model_id) + models_dir = await get_models_dir() + output_path = storage.join(models_dir, output_model_id) # Create a job in the DB with the details of this export (only if job_id not provided) if job_id is None: @@ -117,7 +118,9 @@ async def run_exporter_script( params=params, ) job_data_json = json.dumps(job_data) - job_id = job_create(type="EXPORT", status="Started", experiment_id=experiment_name, job_data=job_data_json) + job_id = await job_create( + type="EXPORT", status="Started", experiment_id=experiment_name, job_data=job_data_json + ) return job_id # Setup arguments to pass to plugin @@ -155,7 +158,7 @@ async def run_exporter_script( # Try to get org_id from workspace path from lab.dirs import get_workspace_dir - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() if "/orgs/" in workspace_dir: team_id = workspace_dir.split("/orgs/")[-1].split("/")[0] @@ -186,7 +189,7 @@ async def run_exporter_script( job_output_file = await shared.get_job_output_file_name(job_id, experiment_name=experiment_name) # Create the output file and run the process with output redirection - with storage.open(job_output_file, "w") as f: + async with await storage.open(job_output_file, "w") as f: process = await asyncio.create_subprocess_exec( *subprocess_command, stdout=f, stderr=subprocess.PIPE, cwd=script_directory, env=process_env ) @@ -199,10 +202,10 @@ async def run_exporter_script( if stderr_str.strip(): print(f"Error: {stderr_str}") - f.write(f"\nError:\n{stderr_str}") + await f.write(f"\nError:\n{stderr_str}") if process.returncode != 0: - job = job_get(job_id) + job = await job_get(job_id) experiment_id = job["experiment_id"] await job_update_status(job_id=job_id, status="FAILED", experiment_id=experiment_id) return { @@ -212,7 +215,7 @@ async def run_exporter_script( except Exception as e: print(f"Failed to export model. 
Exception: {e}") - job = job_get(job_id) + job = await job_get(job_id) experiment_id = job["experiment_id"] await job_update_status(job_id=job_id, status="FAILED", experiment_id=experiment_id) return {"message": "Failed to export model due to an internal error."} @@ -237,7 +240,7 @@ async def run_exporter_script( }, } model_description_file_path = storage.join(output_path, "index.json") - with storage.open(model_description_file_path, "w") as model_description_file: - json.dump(model_description, model_description_file) + async with await storage.open(model_description_file_path, "w") as model_description_file: + await model_description_file.write(json.dumps(model_description, indent=2)) return {"status": "success", "job_id": job_id} diff --git a/api/transformerlab/routers/experiment/generations.py b/api/transformerlab/routers/experiment/generations.py index 8afd3dcdc..264b306d4 100644 --- a/api/transformerlab/routers/experiment/generations.py +++ b/api/transformerlab/routers/experiment/generations.py @@ -26,7 +26,7 @@ async def experiment_add_generation(experimentId: str, plugin: Any = Body()): directory, we can modify the plugin code for the specific experiment without affecting other experiments that use the same plugin.""" - experiment = experiment_get(experimentId) + experiment = await experiment_get(experimentId) if experiment is None: return {"message": f"Experiment {experimentId} does not exist"} @@ -53,7 +53,7 @@ async def experiment_add_generation(experimentId: str, plugin: Any = Body()): generations.append(generation) - experiment_update_config(experimentId, "generations", generations) + await experiment_update_config(experimentId, "generations", generations) return {"message": f"Experiment {experimentId} updated with plugin {plugin_name}"} @@ -64,7 +64,7 @@ async def experiment_delete_generation(experimentId: str, generation_name: str): and remove the global plugin from the specific experiment.""" try: print("Deleting generation", experimentId, generation_name) - experiment = experiment_get(experimentId) + experiment = await experiment_get(experimentId) if experiment is None: return {"message": f"Experiment {experimentId} does not exist"} @@ -79,7 +79,7 @@ async def experiment_delete_generation(experimentId: str, generation_name: str): # remove the generation from the list: generations = [e for e in generations if e["name"] != generation_name] - experiment_update_config(experimentId, "generations", generations) + await experiment_update_config(experimentId, "generations", generations) return {"message": f"Generation {generations} deleted from experiment {experimentId}"} except Exception as e: @@ -94,7 +94,7 @@ async def experiment_delete_generation(experimentId: str, generation_name: str): async def edit_evaluation_generation(experimentId: str, plugin: Any = Body()): """Get the contents of the generation""" try: - experiment = experiment_get(experimentId) + experiment = await experiment_get(experimentId) # if the experiment does not exist, return an error: if experiment is None: @@ -127,7 +127,7 @@ async def edit_evaluation_generation(experimentId: str, plugin: Any = Body()): generation["script_parameters"] = updated_json generation["name"] = template_name - experiment_update_config(experimentId, "generations", generations) + await experiment_update_config(experimentId, "generations", generations) return {"message": "OK"} except Exception as e: @@ -138,7 +138,7 @@ async def edit_evaluation_generation(experimentId: str, plugin: Any = Body()): 
@router.get("/get_generation_plugin_file_contents") async def get_generation_plugin_file_contents(experimentId: str, plugin_name: str): # first get the experiment name: - data = experiment_get(experimentId) + data = await experiment_get(experimentId) # if the experiment does not exist, return an error: if data is None: @@ -149,7 +149,7 @@ async def get_generation_plugin_file_contents(experimentId: str, plugin_name: st # print(f"{EXPERIMENTS_DIR}/{experiment_name}/generation/{generation_name}/main.py") file_name = "main.py" - plugin_path = lab_dirs.plugin_dir_by_name(plugin_name) + plugin_path = await lab_dirs.plugin_dir_by_name(plugin_name) # now get the file contents try: @@ -170,7 +170,7 @@ async def run_generation_script( org_id: Optional[str] = None, user_id: Optional[str] = None, ): - job_config_raw = (job_get(job_id))["job_data"] + job_config_raw = (await job_get(job_id))["job_data"] # Ensure job_config is a dict if isinstance(job_config_raw, str): try: @@ -179,12 +179,13 @@ async def run_generation_script( job_config = {} else: job_config = job_config_raw + generation_config = job_config.get("config", {}) print(generation_config) plugin_name = secure_filename(plugin_name) generation_name = secure_filename(generation_name) - experiment_details = experiment_get(id=experimentId) + experiment_details = await experiment_get(id=experimentId) if experiment_details is None: return {"message": f"Experiment {experimentId} does not exist"} @@ -213,7 +214,8 @@ async def run_generation_script( # Create the input file for the script: from lab.dirs import get_temp_dir - input_file = storage.join(get_temp_dir(), "plugin_input_" + str(plugin_name) + ".json") + temp_dir = await get_temp_dir() + input_file = storage.join(temp_dir, "plugin_input_" + str(plugin_name) + ".json") # The following two ifs convert nested JSON strings to JSON objects -- this is a hack # and should be done in the API itself @@ -236,14 +238,14 @@ async def run_generation_script( job_output_file = await shared.get_job_output_file_name(job_id, plugin_name, experimentId) input_contents = {"experiment": experiment_details, "config": template_config} - with storage.open(input_file, "w") as outfile: - json.dump(input_contents, outfile, indent=4) + async with await storage.open(input_file, "w") as outfile: + await outfile.write(json.dumps(input_contents, indent=4)) # For now, even though we have the file above, we are also going to pass all params # as command line arguments to the script. 
# Create a list of all the parameters: - script_directory = lab_dirs.plugin_dir_by_name(plugin_name) + script_directory = await lab_dirs.plugin_dir_by_name(plugin_name) extra_args = ["--plugin_dir", script_directory] for key in template_config: extra_args.append("--" + key) @@ -297,7 +299,7 @@ async def run_generation_script( # Try to get org_id from workspace path from lab.dirs import get_workspace_dir - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() if "/orgs/" in workspace_dir: team_id = workspace_dir.split("/orgs/")[-1].split("/")[0] @@ -316,17 +318,17 @@ async def run_generation_script( process_env = os.environ.copy() process_env["_TFL_USER_ID"] = resolved_user_id - with storage.open(job_output_file, "w") as f: + async with await storage.open(job_output_file, "w") as f: process = await asyncio.create_subprocess_exec( *subprocess_command, stdout=f, stderr=subprocess.PIPE, env=process_env ) await process.communicate() - with storage.open(output_file, "w") as f: + async with await storage.open(output_file, "w") as f: # Copy all contents from job_output_file to output_file - with storage.open(job_output_file, "r") as job_output: - for line in job_output: - f.write(line) + async with await storage.open(job_output_file, "r") as job_output: + async for line in job_output: + await f.write(line) @router.get("/get_output") @@ -337,7 +339,7 @@ async def get_output(experimentId: str, generation_name: str): generation_name = urllib.parse.unquote(generation_name) generation_output_file = await lab_dirs.generation_output_file(experimentId, generation_name) - if not storage.exists(generation_output_file): + if not await storage.exists(generation_output_file): return {"message": "Output file does not exist"} print(f"Returning output file: {generation_output_file}.") diff --git a/api/transformerlab/routers/experiment/jobs.py b/api/transformerlab/routers/experiment/jobs.py index 3f52fc29c..fa0f7090e 100644 --- a/api/transformerlab/routers/experiment/jobs.py +++ b/api/transformerlab/routers/experiment/jobs.py @@ -36,7 +36,7 @@ @router.get("/list") async def jobs_get_all(experimentId: str, type: str = "", status: str = "", subtype: str = ""): - jobs = job_service.jobs_get_all(type=type, status=status, experiment_id=experimentId) + jobs = await job_service.jobs_get_all(type=type, status=status, experiment_id=experimentId) # Optional filter by job_data.subtype if subtype: @@ -57,7 +57,7 @@ async def jobs_get_all(experimentId: str, type: str = "", status: str = "", subt @router.get("/delete/{job_id}") async def job_delete(job_id: str, experimentId: str): - job_service.job_delete(job_id, experiment_id=experimentId) + await job_service.job_delete(job_id, experiment_id=experimentId) return {"message": "OK"} @@ -68,12 +68,14 @@ async def job_create( status: str = "CREATED", data: str = "{}", ): - jobid = job_service.job_create(type=type, status=status, job_data=data, experiment_id=experimentId) + jobid = await job_service.job_create(type=type, status=status, job_data=data, experiment_id=experimentId) return jobid async def job_create_task(script: str, job_data: str = "{}", experimentId: str = None): - jobid = job_service.job_create(type="UNDEFINED", status="CREATED", job_data=job_data, experiment_id=experimentId) + jobid = await job_service.job_create( + type="UNDEFINED", status="CREATED", job_data=job_data, experiment_id=experimentId + ) return jobid @@ -85,12 +87,12 @@ async def job_update(job_id: str, status: str, experimentId: str): async def start_next_job(): # Count running 
jobs across all organizations - num_running_jobs = job_service.job_count_running_across_all_orgs() + num_running_jobs = await job_service.job_count_running_across_all_orgs() if num_running_jobs > 0: return {"message": "A job is already running"} # Get next queued job across all organizations - nextjob, org_id = job_service.jobs_get_next_queued_job_across_all_orgs() + nextjob, org_id = await job_service.jobs_get_next_queued_job_across_all_orgs() if nextjob: print(f"Starting Next Job in Queue: {nextjob}") @@ -133,19 +135,19 @@ async def start_next_job(): async def stop_job(job_id: str, experimentId: str): # The way a job is stopped is simply by adding "stop: true" to the job_data # This will be checked by the plugin as it runs - job_service.job_stop(job_id, experiment_id=experimentId) + await job_service.job_stop(job_id, experiment_id=experimentId) return {"message": "OK"} @router.get("/delete_all") async def job_delete_all(experimentId: str): - job_service.job_delete_all(experiment_id=experimentId) + await job_service.job_delete_all(experiment_id=experimentId) return {"message": "OK"} @router.get("/{job_id}") async def get_training_job(job_id: str): - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) if job is None: return Response("Job not found", status_code=404) return job @@ -158,7 +160,7 @@ async def get_tasks_job_output(job_id: str, sweeps: bool = False): Uses the same logic as stream_job_output but returns content directly. """ try: - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) if job is None: return "Job not found" @@ -175,7 +177,7 @@ async def get_tasks_job_output(job_id: str, sweeps: bool = False): # Handle sweeps case first if sweeps: output_file = job_data.get("sweep_output_file", None) - if output_file is not None and storage.exists(output_file): + if output_file is not None and await storage.exists(output_file): output_file_name = output_file else: # Fall back to regular output file logic @@ -185,10 +187,10 @@ async def get_tasks_job_output(job_id: str, sweeps: bool = False): output_file_name = await shared.get_job_output_file_name(job_id) # Read and return the file content as JSON array of lines - if storage.exists(output_file_name): + if await storage.exists(output_file_name): lines = [] - with storage.open(output_file_name, "r") as f: - for line in f: + async with await storage.open(output_file_name, "r") as f: + async for line in f: lines.append(line.rstrip("\n")) # Remove trailing newline return lines else: @@ -202,10 +204,10 @@ async def get_tasks_job_output(job_id: str, sweeps: bool = False): await asyncio.sleep(4) try: output_file_name = await shared.get_job_output_file_name(job_id) - if storage.exists(output_file_name): + if await storage.exists(output_file_name): lines = [] - with storage.open(output_file_name, "r") as f: - for line in f: + async with await storage.open(output_file_name, "r") as f: + async for line in f: lines.append(line.rstrip("\n")) # Remove trailing newline return lines else: @@ -215,12 +217,12 @@ async def get_tasks_job_output(job_id: str, sweeps: bool = False): print(f"Still no output file found for job {job_id} after retry, creating empty file: {retry_e}") # Use the Job class to get the proper directory and create the file job_obj = Job(job_id) - output_file_name = job_obj.get_log_path() + output_file_name = await job_obj.get_log_path() # Get directory by removing filename from path using storage.join output_dir = storage.join(*output_file_name.split("/")[:-1]) if "/" in 
output_file_name else "." - storage.makedirs(output_dir, exist_ok=True) - with storage.open(output_file_name, "w") as f: - f.write("") + await storage.makedirs(output_dir, exist_ok=True) + async with await storage.open(output_file_name, "w") as f: + await f.write("") return [] else: print(f"ValueError in get_tasks_job_output: {e}") @@ -243,7 +245,7 @@ async def get_provider_job_logs( Fetch the raw job logs directly from the underlying compute provider for a REMOTE job. """ - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) if not job or str(job.get("experiment_id")) != str(experimentId): raise HTTPException(status_code=404, detail="Job not found") @@ -475,7 +477,7 @@ async def stream_job_output(job_id: str, sweeps: bool = False): Enhanced version combining the best of both train and jobs routers. """ try: - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) job_data = job.get("job_data", {}) @@ -490,7 +492,7 @@ async def stream_job_output(job_id: str, sweeps: bool = False): # Handle sweeps case first if sweeps: output_file = job_data.get("sweep_output_file", None) - if output_file is not None and storage.exists(output_file): + if output_file is not None and await storage.exists(output_file): output_file_name = output_file else: # Fall back to regular output file logic @@ -512,12 +514,12 @@ async def stream_job_output(job_id: str, sweeps: bool = False): print(f"Still no output file found for job {job_id} after retry, creating empty file: {retry_e}") # Use the Job class to get the proper directory and create the file job_obj = Job(job_id) - output_file_name = job_obj.get_log_path() + output_file_name = await job_obj.get_log_path() # Get directory by removing filename from path using storage.join output_dir = storage.join(*output_file_name.split("/")[:-1]) if "/" in output_file_name else "." - storage.makedirs(output_dir, exist_ok=True) - with storage.open(output_file_name, "w") as f: - f.write("") + await storage.makedirs(output_dir, exist_ok=True) + async with await storage.open(output_file_name, "w") as f: + await f.write("") else: print(f"ValueError in stream_job_output: {e}") return StreamingResponse( @@ -544,7 +546,7 @@ async def stream_job_output(job_id: str, sweeps: bool = False): @router.get("/{job_id}/stream_detailed_json_report") async def stream_detailed_json_report(job_id: str, file_name: str): - if not storage.exists(file_name): + if not await storage.exists(file_name): print(f"File not found: {file_name}") return "File not found", 404 @@ -558,7 +560,7 @@ async def stream_detailed_json_report(job_id: str, file_name: str): @router.get("/{job_id}/get_additional_details") async def stream_job_additional_details(job_id: str, task: str = "view"): - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) if job is None: return Response("Job not found", status_code=404) job_data = job["job_data"] @@ -574,13 +576,14 @@ async def stream_job_additional_details(job_id: str, task: str = "view"): if task == "download": return FileResponse(file_path, filename=filename, media_type=file_format) - if not storage.exists(file_path): + if not await storage.exists(file_path): return Response("No additional details found for this evaluation", media_type="text/csv") # convert csv to JSON, but do not assume that \n marks the end of a row as cells can # contain fields that start and end with " and contain \n. Use a CSV parser instead. 
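The comment above is the motivation for the CSV change that follows: quoted cells may themselves contain commas and newlines, so the text cannot be split on raw line breaks. A stdlib-only sketch that feeds the in-memory text to csv.reader (io.StringIO is one way to do that) and produces the same header/body shape used by these endpoints:

import csv
import io

# A cell that contains both a comma and a newline, wrapped in quotes.
raw = 'metric,notes\naccuracy,"0.91, up from 0.88\n(second run)"\nf1,0.87\n'


def csv_text_to_json(text: str) -> dict:
    """Convert CSV text into the {"header": [...], "body": [...]} shape."""
    reader = csv.reader(io.StringIO(text), delimiter=",", quotechar='"')
    rows = list(reader)
    if not rows:
        return {"header": [], "body": []}
    return {"header": rows[0], "body": rows[1:]}


result = csv_text_to_json(raw)
assert result["header"] == ["metric", "notes"]
assert result["body"][0][1] == "0.91, up from 0.88\n(second run)"  # quoted newline preserved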
- with storage.open(file_path, "r") as csvfile: - contents = csv.reader(csvfile, delimiter=",", quotechar='"') + async with await storage.open(file_path, "r") as csvfile: + contents_text = await csvfile.read() + contents = csv.reader(contents_text.splitlines(), delimiter=",", quotechar='"') # convert the csv to a JSON object csv_content = {"header": [], "body": []} for i, row in enumerate(contents): @@ -593,23 +596,24 @@ async def stream_job_additional_details(job_id: str, task: str = "view"): @router.get("/{job_id}/get_figure_json") async def get_figure_path(job_id: str): - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) if job is None: return Response("Job not found", status_code=404) job_data = job["job_data"] file_path = job_data.get("plot_data_path", None) - if file_path is None or not storage.exists(file_path): + if file_path is None or not await storage.exists(file_path): return Response("No plot data found for this evaluation", media_type="text/csv") - with storage.open(file_path, "r") as f: - content = json.loads(f.read()) + async with await storage.open(file_path, "r") as f: + content_str = await f.read() + content = json.loads(content_str) return content @router.get("/{job_id}/get_generated_dataset") async def get_generated_dataset(job_id: str): - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) if job is None: return Response("Job not found", status_code=404) # Get experiment name @@ -621,11 +625,12 @@ async def get_generated_dataset(job_id: str): else: return Response("No dataset found for this evaluation", media_type="text/csv") - if not storage.exists(json_file_path): + if not await storage.exists(json_file_path): return Response("No dataset found for this evaluation", media_type="text/csv") else: - with storage.open(json_file_path, "r") as f: - json_content = json.loads(f.read()) + async with await storage.open(json_file_path, "r") as f: + json_content_str = await f.read() + json_content = json.loads(json_content_str) df = pd.DataFrame(json_content) @@ -637,7 +642,7 @@ async def get_generated_dataset(job_id: str): @router.get("/{job_id}/get_eval_results") async def get_eval_results(job_id: str, task: str = "view", file_index: int = 0): """Get evaluation results for a job""" - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) if job is None: return Response("Job not found", status_code=404) job_data = job["job_data"] @@ -655,7 +660,7 @@ async def get_eval_results(job_id: str, task: str = "view", file_index: int = 0) file_index = 0 file_path = eval_results_list[file_index] - if not storage.exists(file_path): + if not await storage.exists(file_path): return Response("Evaluation results file not found", media_type="text/csv") # Determine file format @@ -671,10 +676,10 @@ async def get_eval_results(job_id: str, task: str = "view", file_index: int = 0) if task == "download": # Use StreamingResponse to support both local and remote files - def generate(): - with storage.open(file_path, "rb") as f: + async def generate(): + async with await storage.open(file_path, "rb") as f: while True: - chunk = f.read(8192) # Read in 8KB chunks + chunk = await f.read(8192) # Read in 8KB chunks if not chunk: break yield chunk @@ -690,8 +695,9 @@ def generate(): # For view, convert CSV to JSON format if file_path.endswith(".csv"): - with storage.open(file_path, "r") as csvfile: - contents = csv.reader(csvfile, delimiter=",", quotechar='"') + async with await storage.open(file_path, "r") as csvfile: + content_str = 
await csvfile.read() + contents = csv.reader(content_str.splitlines(), delimiter=",", quotechar='"') csv_content = {"header": [], "body": []} for i, row in enumerate(contents): if i == 0: @@ -700,8 +706,9 @@ def generate(): csv_content["body"].append(row) return csv_content elif file_path.endswith(".json"): - with storage.open(file_path, "r") as jsonfile: - content = json.load(jsonfile) + async with await storage.open(file_path, "r") as jsonfile: + content_str = await jsonfile.read() + content = json.loads(content_str) # If it's a list of records, convert to header/body format if isinstance(content, list) and len(content) > 0: if isinstance(content[0], dict): @@ -711,14 +718,14 @@ def generate(): return content else: # For other file types, just return as text - with storage.open(file_path, "r") as f: - return f.read() + async with await storage.open(file_path, "r") as f: + return await f.read() @router.get("/{job_id}/get_eval_images") async def get_eval_images(job_id: str): """Get list of evaluation images for a job""" - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) if job is None: return Response("Job not found", status_code=404) job_data = job["job_data"] @@ -729,7 +736,7 @@ async def get_eval_images(job_id: str): images_dir = job_data["eval_images_dir"] - if not storage.exists(images_dir): + if not await storage.exists(images_dir): return {"images": []} # Supported image extensions @@ -737,7 +744,7 @@ async def get_eval_images(job_id: str): images = [] try: # Use storage.ls to list directory contents - items = storage.ls(images_dir, detail=True) + items = await storage.ls(images_dir, detail=True) for item in items: # Handle both dict (detail=True) and string (detail=False) formats if isinstance(item, dict): @@ -758,13 +765,13 @@ async def get_eval_images(job_id: str): else: # Fallback for string format - check if it's a file file_path = item if isinstance(item, str) else str(item) - if storage.isfile(file_path): + if await storage.isfile(file_path): filename = file_path.split("/")[-1] if "/" in file_path else file_path _, ext = os.path.splitext(filename.lower()) if ext in image_extensions: # Try to get file info - for remote storage, stats might not be available try: - items_detail = storage.ls(file_path, detail=True) + items_detail = await storage.ls(file_path, detail=True) if items_detail and isinstance(items_detail[0], dict): file_info = items_detail[0] images.append( @@ -805,7 +812,7 @@ async def get_eval_images(job_id: str): @router.get("/{job_id}/image/{filename}") async def get_eval_image(job_id: str, filename: str): """Serve individual evaluation image files""" - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) if job is None: return Response("Job not found", status_code=404) job_data = job["job_data"] @@ -816,7 +823,7 @@ async def get_eval_image(job_id: str, filename: str): images_dir = job_data["eval_images_dir"] - if not storage.exists(images_dir): + if not await storage.exists(images_dir): return Response("Images directory not found", status_code=404) # Secure the filename to prevent directory traversal @@ -824,7 +831,7 @@ async def get_eval_image(job_id: str, filename: str): file_path = storage.join(images_dir, filename) # Ensure the file exists - if not storage.exists(file_path): + if not await storage.exists(file_path): return Response("Image not found", status_code=404) # For security, verify the file path is within the images directory @@ -864,7 +871,7 @@ async def get_checkpoints(job_id: str, request: Request): 
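The download branches above stream results in 8 KB chunks through an async generator so both local and remote (e.g. s3://) files can be served without loading them into memory. A runnable sketch of the same shape using FastAPI's StreamingResponse; a plain local file plus asyncio.to_thread stands in for the project's async storage handle, and the route path and filename are illustrative:

import asyncio

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()
CHUNK_SIZE = 8192  # mirrors the 8 KB chunks used above


@app.get("/download_results")
async def download_results():
    file_path = "eval_results.csv"  # assumed local path for the sketch

    async def generate():
        f = await asyncio.to_thread(open, file_path, "rb")
        try:
            while True:
                chunk = await asyncio.to_thread(f.read, CHUNK_SIZE)
                if not chunk:
                    break
                yield chunk
        finally:
            await asyncio.to_thread(f.close)

    headers = {"Content-Disposition": 'attachment; filename="eval_results.csv"'}
    return StreamingResponse(generate(), media_type="text/csv", headers=headers)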
return {"checkpoints": []} """Get list of checkpoints for a job""" - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) if job is None: return {"checkpoints": []} @@ -881,7 +888,7 @@ async def get_checkpoints(job_id: str, request: Request): checkpoints = [] for checkpoint_path in checkpoint_paths: try: - if storage.isdir(checkpoint_path): + if await storage.isdir(checkpoint_path): # Don't set formatted_time and filesize for directories formatted_time = None filesize = None @@ -889,7 +896,7 @@ async def get_checkpoints(job_id: str, request: Request): # Try to get file info from storage try: # Use storage.ls to get file details if available - file_info_list = storage.ls(checkpoint_path, detail=True) + file_info_list = await storage.ls(checkpoint_path, detail=True) if file_info_list and isinstance(file_info_list, dict): file_info = file_info_list.get(checkpoint_path, {}) filesize = file_info.get("size", 0) @@ -958,7 +965,7 @@ async def get_checkpoints(job_id: str, request: Request): config = {} model_name = config.get("model_name", "") adaptor_name = config.get("adaptor_name", "adaptor") - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() default_adaptor_dir = storage.join(workspace_dir, "adaptors", secure_filename(model_name), adaptor_name) # Get job directory @@ -966,13 +973,13 @@ async def get_checkpoints(job_id: str, request: Request): if not checkpoints_dir: from lab.dirs import get_job_checkpoints_dir - checkpoints_dir = get_job_checkpoints_dir(job_id) - if not checkpoints_dir or not storage.exists(checkpoints_dir): + checkpoints_dir = await get_job_checkpoints_dir(job_id) + if not checkpoints_dir or not await storage.exists(checkpoints_dir): return {"checkpoints": []} - elif storage.isdir(checkpoints_dir): + elif await storage.isdir(checkpoints_dir): checkpoints = [] try: - items = storage.ls(checkpoints_dir, detail=False) + items = await storage.ls(checkpoints_dir, detail=False) for item in items: file_path = item if isinstance(item, str) else str(item) filename = file_path.split("/")[-1] if "/" in file_path else file_path @@ -980,7 +987,7 @@ async def get_checkpoints(job_id: str, request: Request): if fnmatch(filename, "*_adapters.safetensors"): # Try to get file info try: - file_info_list = storage.ls(file_path, detail=True) + file_info_list = await storage.ls(file_path, detail=True) if file_info_list and isinstance(file_info_list, dict): file_info = file_info_list.get(file_path, {}) filesize = file_info.get("size", 0) @@ -1005,7 +1012,7 @@ async def get_checkpoints(job_id: str, request: Request): print(f"Error getting file info for {file_path}: {e}") checkpoints.append({"filename": filename, "date": None, "size": None}) # allow directories too - elif storage.isdir(file_path): + elif await storage.isdir(file_path): checkpoints.append({"filename": filename, "date": None, "size": None}) if checkpoints: return {"checkpoints": checkpoints} @@ -1020,7 +1027,7 @@ async def get_checkpoints(job_id: str, request: Request): checkpoints = [] try: - items = storage.ls(checkpoints_dir, detail=False) + items = await storage.ls(checkpoints_dir, detail=False) for item in items: file_path = item if isinstance(item, str) else str(item) filename = file_path.split("/")[-1] if "/" in file_path else file_path @@ -1028,7 +1035,7 @@ async def get_checkpoints(job_id: str, request: Request): if fnmatch(filename, checkpoints_file_filter): try: # Try to get file info from storage - file_info_list = storage.ls(file_path, detail=True) + file_info_list = 
await storage.ls(file_path, detail=True) if file_info_list and isinstance(file_info_list, dict): file_info = file_info_list.get(file_path, {}) filesize = file_info.get("size", 0) @@ -1081,17 +1088,17 @@ async def get_artifacts(job_id: str, request: Request): if not job_id or job_id in ("", "-1"): return {"artifacts": []} - # Get job data - job = job_service.job_get(job_id) - if not job: + """Get list of artifacts for a job""" + job = await job_service.job_get(job_id) + if job is None: return {"artifacts": []} # Use get_job_artifacts_dir to get the artifacts directory directly try: from lab.dirs import get_job_artifacts_dir - artifacts_dir = get_job_artifacts_dir(job_id) - artifacts = get_artifacts_from_directory(artifacts_dir, storage) + artifacts_dir = await get_job_artifacts_dir(job_id) + artifacts = await get_artifacts_from_directory(artifacts_dir, storage) except Exception as e: print(f"Error getting artifacts for job {job_id}: {e}") artifacts = [] @@ -1108,14 +1115,14 @@ async def download_all_artifacts(job_id: str): Download a zip file containing all artifacts for a job. """ # 1. Gather all artifact file paths using service - all_file_paths = job_service.get_all_artifact_paths(job_id, storage) + all_file_paths = await job_service.get_all_artifact_paths(job_id, storage) if not all_file_paths: return Response("No artifacts found for this job", status_code=404) # 2. Create Zip File in memory try: - zip_buffer = zip_utils.create_zip_from_storage(all_file_paths, storage) + zip_buffer = await zip_utils.create_zip_from_storage(all_file_paths, storage) filename = f"artifacts_{job_id}.zip" headers = { @@ -1174,7 +1181,7 @@ async def get_artifact(job_id: str, filename: str, task: str = "view"): artifacts_dir = job_data["artifacts_dir"] - if not storage.exists(artifacts_dir): + if not await storage.exists(artifacts_dir): return Response("Artifacts directory not found", status_code=404) # Secure the filename to prevent directory traversal @@ -1182,7 +1189,7 @@ async def get_artifact(job_id: str, filename: str, task: str = "view"): artifact_file_path = storage.join(artifacts_dir, filename_secure) # Ensure the file exists - if not storage.exists(artifact_file_path): + if not await storage.exists(artifact_file_path): return Response("Artifact not found", status_code=404) # Determine media type based on file extension @@ -1223,8 +1230,9 @@ async def get_artifact(job_id: str, filename: str, task: str = "view"): # For JSON files in view mode, return the parsed content if task == "view" and ext == ".json": try: - with storage.open(artifact_file_path, "r") as f: - content = json.load(f) + async with await storage.open(artifact_file_path, "r") as f: + content_str = await f.read() + content = json.loads(content_str) return content except Exception as e: print(f"Error reading JSON file: {e}") @@ -1232,10 +1240,10 @@ async def get_artifact(job_id: str, filename: str, task: str = "view"): # For download or other file types, stream the file # Use StreamingResponse to support both local and remote files (e.g., s3://) - def generate(): - with storage.open(artifact_file_path, "rb") as f: + async def generate(): + async with await storage.open(artifact_file_path, "rb") as f: while True: - chunk = f.read(8192) # Read in 8KB chunks + chunk = await f.read(8192) # Read in 8KB chunks if not chunk: break yield chunk @@ -1259,13 +1267,13 @@ def generate(): @router.get("/{job_id}") async def get_training_job_by_path(job_id: str): - return job_service.job_get(job_id) + return await job_service.job_get(job_id) 
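get_checkpoints above walks a directory listing and keeps entries matching *_adapters.safetensors (plus sub-directories), attaching size and date only when the backend can provide them. A local-filesystem sketch of that filter; the real code goes through storage.ls, so the same logic also covers remote URIs, and the function name here is illustrative:

from datetime import datetime
from fnmatch import fnmatch
from pathlib import Path


def list_checkpoints(checkpoints_dir: str, pattern: str = "*_adapters.safetensors") -> list[dict]:
    """Collect checkpoint entries matching the pattern, with size/date when available."""
    checkpoints = []
    for path in sorted(Path(checkpoints_dir).iterdir()):
        if path.is_dir():
            # Directories are allowed too, but without size or date metadata.
            checkpoints.append({"filename": path.name, "date": None, "size": None})
        elif fnmatch(path.name, pattern):
            stat = path.stat()
            checkpoints.append(
                {
                    "filename": path.name,
                    "date": datetime.fromtimestamp(stat.st_mtime).isoformat(),
                    "size": stat.st_size,
                }
            )
    return checkpoints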
@router.get("/{job_id}/output") async def get_training_job_output_jobpath(job_id: str, sweeps: bool = False): try: - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) if job is None: return "Job not found" @@ -1281,9 +1289,9 @@ async def get_training_job_output_jobpath(job_id: str, sweeps: bool = False): if sweeps: output_file = job_data.get("sweep_output_file", None) - if output_file is not None and storage.exists(output_file): - with storage.open(output_file, "r") as f: - output = f.read() + if output_file is not None and await storage.exists(output_file): + async with await storage.open(output_file, "r") as f: + output = await f.read() return output else: # Fall back to regular output file logic @@ -1294,9 +1302,9 @@ async def get_training_job_output_jobpath(job_id: str, sweeps: bool = False): experiment_id = job["experiment_id"] output_file_name = await shared.get_job_output_file_name(job_id, experiment_name=experiment_id) - if storage.exists(output_file_name): - with storage.open(output_file_name, "r") as f: - output = f.read() + if await storage.exists(output_file_name): + async with await storage.open(output_file_name, "r") as f: + output = await f.read() return output else: return "Output file not found" @@ -1313,7 +1321,7 @@ async def get_training_job_output_jobpath(job_id: str, sweeps: bool = False): @router.get("/{job_id}/sweep_results") async def sweep_results(job_id: str): try: - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) if job is None: return {"status": "error", "message": "Job not found."} @@ -1328,10 +1336,11 @@ async def sweep_results(job_id: str): job_data = {} output_file = job_data.get("sweep_results_file", None) - if output_file and storage.exists(output_file): + if output_file and await storage.exists(output_file): try: - with storage.open(output_file, "r") as f: - output = json.load(f) + async with await storage.open(output_file, "r") as f: + content_str = await f.read() + output = json.loads(content_str) return {"status": "success", "data": output} except json.JSONDecodeError as e: print(f"JSON decode error for job {job_id}: {e}") diff --git a/api/transformerlab/routers/experiment/plugins.py b/api/transformerlab/routers/experiment/plugins.py index b729af217..6ba351aa0 100644 --- a/api/transformerlab/routers/experiment/plugins.py +++ b/api/transformerlab/routers/experiment/plugins.py @@ -24,7 +24,7 @@ async def experiment_list_scripts(id: str, type: str = None, filter: str = None): """List all the scripts in the experiment""" # first get the experiment name: - data = experiment_get(id) + data = await experiment_get(id) # if the experiment does not exist, return an error: if data is None: @@ -45,7 +45,7 @@ async def experiment_list_scripts(id: str, type: str = None, filter: str = None) from lab.dirs import get_plugin_dir - scripts_dir = get_plugin_dir() + scripts_dir = await get_plugin_dir() # now get a list of all the directories in the scripts directory: scripts_full_json = [] @@ -142,7 +142,7 @@ async def plugin_download(id: int, plugin_slug: str): response = await client.get(url + file) file_contents = response.text # Save each file to workspace/plugins// - p = lab_dirs.plugin_dir_by_name(plugin_slug) + p = await lab_dirs.plugin_dir_by_name(plugin_slug) os.makedirs(p, mode=0o755, exist_ok=True) with open(f"{p}/{file}", "w") as f: f.write(file_contents) @@ -167,7 +167,7 @@ async def plugin_save_file_contents(id: str, pluginId: str, filename: str, file_ filename = secure_filename(filename) - data = 
experiment_get(id) + data = await experiment_get(id) # if the experiment does not exist, return an error: if data is None: return {"message": f"Experiment {id} does not exist"} @@ -184,7 +184,7 @@ async def plugin_save_file_contents(id: str, pluginId: str, filename: str, file_ filename = shared.slugify(filename) pluginId = shared.slugify(pluginId) - script_path = lab_dirs.plugin_dir_by_name(pluginId) + script_path = await lab_dirs.plugin_dir_by_name(pluginId) # make directory if it does not exist: if not os.path.exists(f"{script_path}"): @@ -204,7 +204,7 @@ async def plugin_get_file_contents(id: str, pluginId: str, filename: str): filename = secure_filename(filename) - data = experiment_get(id) + data = await experiment_get(id) # if the experiment does not exist, return an error: if data is None: return {"message": f"Experiment {id} does not exist"} @@ -218,7 +218,7 @@ async def plugin_get_file_contents(id: str, pluginId: str, filename: str): return {"message": f"File extension {file_ext} for {filename} not supported"} # The following prevents path traversal attacks: - plugin_dir = lab_dirs.plugin_dir_by_name((pluginId)) + plugin_dir = await lab_dirs.plugin_dir_by_name((pluginId)) final_path = Path(plugin_dir).joinpath(filename + file_ext).resolve().relative_to(plugin_dir) final_path = plugin_dir + "/" + str(final_path) @@ -237,13 +237,13 @@ async def plugin_get_file_contents(id: str, pluginId: str, filename: str): async def plugin_list_files(id: str, pluginId: str): global allowed_extensions - data = experiment_get(id) + data = await experiment_get(id) # if the experiment does not exist, return an error: if data is None: return {"message": f"Experiment {id} does not exist"} # experiment_name = data["name"] - scripts_dir = lab_dirs.plugin_dir_by_name(pluginId) + scripts_dir = await lab_dirs.plugin_dir_by_name(pluginId) # check if directory exists: if not os.path.exists(scripts_dir): @@ -265,7 +265,7 @@ async def plugin_create_new_file(id: str, pluginId: str, filename: str): filename = secure_filename(filename) - data = experiment_get(id) + data = await experiment_get(id) # if the experiment does not exist, return an error: if data is None: return {"message": f"Experiment {id} does not exist"} @@ -285,7 +285,7 @@ async def plugin_create_new_file(id: str, pluginId: str, filename: str): filename = shared.slugify(filename) pluginId = shared.slugify(pluginId) - script_path = lab_dirs.plugin_dir_by_name(pluginId) + script_path = await lab_dirs.plugin_dir_by_name(pluginId) # make directory if it does not exist: if not os.path.exists(f"{script_path}"): @@ -305,7 +305,7 @@ async def plugin_delete_file(id: str, pluginId: str, filename: str): filename = secure_filename(filename) - data = experiment_get(id) + data = await experiment_get(id) # if the experiment does not exist, return an error: if data is None: return {"message": f"Experiment {id} does not exist"} @@ -325,7 +325,7 @@ async def plugin_delete_file(id: str, pluginId: str, filename: str): filename = shared.slugify(filename) pluginId = shared.slugify(pluginId) - script_path = lab_dirs.plugin_dir_by_name(pluginId) + script_path = await lab_dirs.plugin_dir_by_name(pluginId) # make directory if it does not exist: if not os.path.exists(f"{script_path}"): @@ -341,7 +341,7 @@ async def plugin_delete_file(id: str, pluginId: str, filename: str): async def plugin_new_plugin_directory(id: str, pluginId: str): global allowed_extensions - data = experiment_get(id) + data = await experiment_get(id) # if the experiment does not exist, return an error: if 
data is None: return {"message": f"Experiment {id} does not exist"} @@ -351,7 +351,7 @@ async def plugin_new_plugin_directory(id: str, pluginId: str): # clean the file name: pluginId = shared.slugify(value=pluginId) - script_path = lab_dirs.plugin_dir_by_name(pluginId) + script_path = await lab_dirs.plugin_dir_by_name(pluginId) # make directory if it does not exist: if not os.path.exists(f"{script_path}"): diff --git a/api/transformerlab/routers/experiment/rag.py b/api/transformerlab/routers/experiment/rag.py index 66dd30b12..8eaee2881 100644 --- a/api/transformerlab/routers/experiment/rag.py +++ b/api/transformerlab/routers/experiment/rag.py @@ -37,15 +37,15 @@ async def query( """Query the RAG engine""" exp_obj = Experiment(experimentId) - experiment_dir = exp_obj.get_dir() + experiment_dir = await exp_obj.get_dir() documents_dir = storage.join(experiment_dir, "documents") documents_dir = storage.join(documents_dir, rag_folder) # Basic traversal protection for posix paths if not documents_dir.startswith(experiment_dir.rstrip("/") + "/") and documents_dir != experiment_dir: return "Error: Invalid RAG folder path" - if not storage.exists(documents_dir): + if not await storage.exists(documents_dir): return "Error: The RAG folder does not exist in the documents directory" - experiment_details = experiment_get(id=experimentId) + experiment_details = await experiment_get(id=experimentId) experiment_config = ( experiment_details["config"] if isinstance(experiment_details["config"], dict) @@ -80,7 +80,8 @@ async def query( # Check if it exists in workspace/plugins: from lab.dirs import get_plugin_dir - plugin_path = os.path.join(get_plugin_dir(), plugin) + plugin_dir = await get_plugin_dir() + plugin_path = os.path.join(plugin_dir, plugin) if not os.path.exists(plugin_path): return f"Plugin {plugin} does not exist on the filesystem -- you must install or reinstall this plugin." @@ -160,13 +161,13 @@ async def reindex( """Reindex the RAG engine""" exp_obj = Experiment(experimentId) - experiment_dir = exp_obj.get_dir() + experiment_dir = await exp_obj.get_dir() documents_dir = storage.join(experiment_dir, "documents") documents_dir = storage.join(documents_dir, secure_filename(rag_folder)) - if not storage.exists(documents_dir): + if not await storage.exists(documents_dir): return "Error: The RAG folder does not exist in the documents directory." - experiment_details = experiment_get(id=experimentId) + experiment_details = await experiment_get(id=experimentId) experiment_config = ( experiment_details["config"] if isinstance(experiment_details["config"], dict) @@ -192,7 +193,8 @@ async def reindex( # Check if it exists in workspace/plugins: from lab.dirs import get_plugin_dir - plugin_path = os.path.join(get_plugin_dir(), plugin) + plugin_dir = await get_plugin_dir() + plugin_path = os.path.join(plugin_dir, plugin) if not os.path.exists(plugin_path): return f"Plugin {plugin} does not exist on the filesystem -- you must install or reinstall this plugin." 
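The RAG endpoints above guard against escaping the experiment directory with a plain prefix check on posix-style storage paths (startswith(experiment_dir.rstrip("/") + "/")), since remote URIs cannot be normalized with os.path.realpath. A tiny sketch of that check; the helper name is illustrative:

def is_within(base_dir: str, candidate: str) -> bool:
    """Prefix check for posix-style storage paths (local paths or remote URIs)."""
    base = base_dir.rstrip("/")
    return candidate == base or candidate.startswith(base + "/")


assert is_within("s3://bucket/exp1", "s3://bucket/exp1/documents/rag")
# The rstrip + "/" suffix prevents prefix collisions with sibling directories.
assert not is_within("s3://bucket/exp1", "s3://bucket/exp10/documents")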
@@ -264,7 +266,7 @@ async def embed_text(request: EmbedRequest): """Embed text using the embedding model using sentence transformers""" from sentence_transformers import SentenceTransformer - experiment_details = experiment_get(id=request.experiment_id) + experiment_details = await experiment_get(id=request.experiment_id) experiment_config = ( experiment_details["config"] if isinstance(experiment_details["config"], dict) diff --git a/api/transformerlab/routers/experiment/task.py b/api/transformerlab/routers/experiment/task.py index b4c23bb59..f59177571 100644 --- a/api/transformerlab/routers/experiment/task.py +++ b/api/transformerlab/routers/experiment/task.py @@ -81,13 +81,13 @@ def process_env_parameters_to_env_vars(config: dict) -> dict: @router.get("/list", summary="Returns all the tasks") async def task_get_all(): - tasks = task_service.task_get_all() + tasks = await task_service.task_get_all() return tasks @router.get("/{task_id}/get", summary="Gets all the data for a single task") async def task_get_by_id(task_id: str): - task = task_service.task_get_by_id(task_id) + task = await task_service.task_get_by_id(task_id) if task is None: return {"message": "NOT FOUND"} return task @@ -95,7 +95,7 @@ async def task_get_by_id(task_id: str): @router.get("/list_by_type", summary="Returns all the tasks of a certain type, e.g TRAIN") async def task_get_by_type(type: str): - tasks = task_service.task_get_by_type(type) + tasks = await task_service.task_get_by_type(type) return tasks @@ -104,7 +104,7 @@ async def task_get_by_type(type: str): summary="Returns all the tasks of a certain type in a certain experiment, e.g TRAIN", ) async def task_get_by_type_in_experiment(experimentId: str, type: str): - tasks = task_service.task_get_by_type_in_experiment(type, experimentId) + tasks = await task_service.task_get_by_type_in_experiment(type, experimentId) return tasks @@ -117,7 +117,7 @@ async def task_get_by_subtype_in_experiment( subtype: str, type: Optional[str] = Query(None, description="Optional task type filter (e.g., REMOTE)"), ): - tasks = task_service.task_get_by_subtype_in_experiment(experimentId, subtype, type) + tasks = await task_service.task_get_by_subtype_in_experiment(experimentId, subtype, type) return tasks @@ -126,7 +126,7 @@ async def update_task(task_id: str, new_task: dict = Body()): # Perform secure_filename before updating the task if "name" in new_task: new_task["name"] = secure_filename(new_task["name"]) - success = task_service.update_task(task_id, new_task) + success = await task_service.update_task(task_id, new_task) if success: return {"message": "OK"} else: @@ -135,7 +135,7 @@ async def update_task(task_id: str, new_task: dict = Body()): @router.get("/{task_id}/delete", summary="Deletes a task") async def delete_task(task_id: str): - success = task_service.delete_task(task_id) + success = await task_service.delete_task(task_id) if success: return {"message": "OK"} else: @@ -294,16 +294,16 @@ async def _store_zip_file(zip_file: UploadFile, task_id: str) -> str: Store a zip file locally for a task. Returns the stored path that should be mapped to ~/src in file_mounts. 
""" - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() if not workspace_dir: raise HTTPException(status_code=500, detail="Workspace directory is not configured") # Create uploads/task/{task_id} directory uploads_root = storage.join(workspace_dir, "uploads", "task") - storage.makedirs(uploads_root, exist_ok=True) + await storage.makedirs(uploads_root, exist_ok=True) task_dir = storage.join(uploads_root, str(task_id)) - storage.makedirs(task_dir, exist_ok=True) + await storage.makedirs(task_dir, exist_ok=True) # Generate a safe filename for the zip file import uuid @@ -330,8 +330,8 @@ async def _store_zip_file(zip_file: UploadFile, task_id: str) -> str: zip_ref.testzip() # Store the zip file - with storage.open(stored_path, "wb") as f: - f.write(zip_content) + async with await storage.open(stored_path, "wb") as f: + await f.write(zip_content) # Clean up temp file os.remove(temp_zip_path) @@ -441,14 +441,14 @@ async def add_task( new_task["name"] = secure_filename(new_task["name"]) # All fields are stored directly in the JSON (not nested in inputs/outputs/config) - task_id = task_service.add_task(new_task) + task_id = await task_service.add_task(new_task) # Handle zip file if provided (for JSON requests, zip_file would come from multipart) if zip_file and zip_file.filename: try: zip_path = await _store_zip_file(zip_file, task_id) # Update task with file_mounts - map ~/src to the stored zip file - task_service.update_task(task_id, {"file_mounts": {"~/src": zip_path}}) + await task_service.update_task(task_id, {"file_mounts": {"~/src": zip_path}}) except Exception as e: # Log error but don't fail task creation print(f"Warning: Failed to process zip file: {e}") @@ -470,14 +470,14 @@ async def add_task( # Handle provider matching await _resolve_provider(task_data, user_and_team, session) - task_id = task_service.add_task(task_data) + task_id = await task_service.add_task(task_data) # Handle zip file if provided if zip_file and zip_file.filename: try: zip_path = await _store_zip_file(zip_file, task_id) # Update task with file_mounts - map ~/src to the stored zip file - task_service.update_task(task_id, {"file_mounts": {"~/src": zip_path}}) + await task_service.update_task(task_id, {"file_mounts": {"~/src": zip_path}}) except Exception as e: # Log error but don't fail task creation print(f"Warning: Failed to process zip file: {e}") @@ -503,14 +503,14 @@ async def add_task( if "name" in task_data: task_data["name"] = secure_filename(task_data["name"]) - task_id = task_service.add_task(task_data) + task_id = await task_service.add_task(task_data) # Handle zip file if provided if zip_file and zip_file.filename: try: zip_path = await _store_zip_file(zip_file, task_id) # Update task with file_mounts - map ~/src to the stored zip file - task_service.update_task(task_id, {"file_mounts": {"~/src": zip_path}}) + await task_service.update_task(task_id, {"file_mounts": {"~/src": zip_path}}) except Exception as e: # Log error but don't fail task creation print(f"Warning: Failed to process zip file: {e}") @@ -525,21 +525,21 @@ async def add_task( @router.get("/delete_all", summary="Wipe all tasks") async def task_delete_all(): - task_service.task_delete_all() + await task_service.task_delete_all() return {"message": "OK"} @router.get("/gallery", summary="List all tasks from the tasks gallery") async def task_gallery(): """Get the tasks gallery from the JSON file (same as tasks gallery)""" - gallery = galleries.get_tasks_gallery() + gallery = await galleries.get_tasks_gallery() 
return {"status": "success", "data": gallery} @router.get("/gallery/interactive", summary="List all interactive task templates") async def interactive_gallery(): """Get the interactive tasks gallery (vscode, jupyter, vllm, ssh templates)""" - gallery = galleries.get_interactive_gallery() + gallery = await galleries.get_interactive_gallery() return {"status": "success", "data": gallery} @@ -554,7 +554,7 @@ async def import_task_from_gallery( Creates a new task using the gallery entry's config and GitHub info. Uses the team's GitHub PAT if available. """ - gallery = galleries.get_tasks_gallery() + gallery = await galleries.get_tasks_gallery() # Find the gallery entry by index or ID try: @@ -633,7 +633,7 @@ async def import_task_from_gallery( # Perform secure_filename before adding the task new_task["name"] = secure_filename(new_task["name"]) - task_service.add_task(new_task) + await task_service.add_task(new_task) return {"status": "success", "message": f"Task '{task_name}' imported successfully"} @@ -641,7 +641,7 @@ async def import_task_from_gallery( @router.get("/gallery/team", summary="List team-specific tasks from the team gallery") async def team_task_gallery(): """Get the team-specific tasks gallery stored in workspace_dir (same as tasks gallery)""" - gallery = galleries.get_team_tasks_gallery() + gallery = await galleries.get_team_tasks_gallery() return {"status": "success", "data": gallery} @@ -654,7 +654,7 @@ async def import_task_from_team_gallery( """ Import a task from the team-specific tasks gallery (workspace_dir/team_specific_tasks.json). """ - gallery = galleries.get_team_tasks_gallery() + gallery = await galleries.get_team_tasks_gallery() # Find the gallery entry by index or ID try: @@ -730,7 +730,7 @@ async def import_task_from_team_gallery( # Perform secure_filename before adding the task new_task["name"] = secure_filename(new_task["name"]) - task_service.add_task(new_task) + await task_service.add_task(new_task) return {"status": "success", "message": f"Task '{task_name}' imported successfully"} @@ -744,7 +744,7 @@ async def export_task_to_team_gallery( Export a task into the team-specific gallery stored in workspace_dir. Tasks store all fields directly (not nested in config). """ - task = task_service.task_get_by_id(request.task_id) + task = await task_service.task_get_by_id(request.task_id) if not task: raise HTTPException(status_code=404, detail="Task not found") @@ -788,7 +788,7 @@ async def export_task_to_team_gallery( "github_repo_dir": task.get("github_directory"), } - galleries.add_team_task_to_gallery(gallery_entry) + await galleries.add_team_task_to_gallery(gallery_entry) return { "status": "success", @@ -805,7 +805,7 @@ async def delete_team_task_from_gallery( """ Delete a task from the team-specific gallery stored in workspace_dir. 
""" - success = galleries.delete_team_task_from_gallery(request.task_id) + success = await galleries.delete_team_task_from_gallery(request.task_id) if success: return { "status": "success", diff --git a/api/transformerlab/routers/experiment/workflows.py b/api/transformerlab/routers/experiment/workflows.py index 71efb040b..250e5a527 100644 --- a/api/transformerlab/routers/experiment/workflows.py +++ b/api/transformerlab/routers/experiment/workflows.py @@ -432,7 +432,7 @@ async def workflow_runs_get_by_id(workflow_run_id: str, experimentId: str): job_ids = workflow_run.get("job_ids", []) for job_id in job_ids: - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) if not job: continue @@ -490,7 +490,7 @@ async def cancel_workflow_run(workflow_run_id: str, experimentId: str): cancelled_jobs = [] for job_id in current_job_ids: - job_service.job_stop(job_id, experimentId) # This sets stop=True in job_data + await job_service.job_stop(job_id, experimentId) # This sets stop=True in job_data cancelled_jobs.append(job_id) # The workflow execution engine will automatically detect the stopped jobs @@ -558,7 +558,7 @@ async def check_current_jobs_status(workflow_run_id, current_job_ids): return None for job_id in current_job_ids: - current_job = job_service.job_get(job_id) + current_job = await job_service.job_get(job_id) if not current_job: await workflow_run_update_status(workflow_run_id, "FAILED") return f"Could not find job with ID {job_id}" @@ -686,12 +686,12 @@ async def find_previous_node_and_job(current_node, workflow_run, workflow_config ran_jobs = json.loads(ran_jobs_str) if previous_node["id"] in ran_nodes: previous_job_ID = ran_jobs[ran_nodes.index(previous_node["id"])] - previous_job = job_service.job_get(previous_job_ID) + previous_job = await job_service.job_get(previous_job_ID) return previous_job -def extract_previous_job_outputs(previous_job): +async def extract_previous_job_outputs(previous_job): """Extracts relevant output information from a completed job.""" outputs = {} if previous_job is None or "job_data" not in previous_job or not previous_job["job_data"]: @@ -704,7 +704,7 @@ def extract_previous_job_outputs(previous_job): try: from lab.dirs import get_models_dir - fuse_pretext = get_models_dir() + "/" + fuse_pretext = (await get_models_dir()) + "/" except Exception: fuse_pretext = "" @@ -829,7 +829,7 @@ async def queue_job_for_node(node: dict, workflow_run: dict, workflow_config: di previous_job = await find_previous_node_and_job(node, workflow_run, workflow_config) # Extract outputs from the previous job - previous_outputs = extract_previous_job_outputs(previous_job) + previous_outputs = await extract_previous_job_outputs(previous_job) # Prepare inputs and outputs for the new job inputs_json, outputs_json = prepare_next_task_io(task_def, previous_outputs) diff --git a/api/transformerlab/routers/model.py b/api/transformerlab/routers/model.py index b26868dd1..b136fe022 100644 --- a/api/transformerlab/routers/model.py +++ b/api/transformerlab/routers/model.py @@ -51,7 +51,7 @@ def _parse_gallery_added_date(value) -> date: return date(2024, 2, 1) -def get_model_dir(model_id: str): +async def get_model_dir(model_id: str): """ Helper function gets the directory for a model ID model_id may be in Hugging Face format @@ -59,10 +59,11 @@ def get_model_dir(model_id: str): model_id_without_author = model_id.split("/")[-1] from lab.dirs import get_models_dir - return storage.join(get_models_dir(), model_id_without_author) + models_dir = await get_models_dir() + 
return storage.join(models_dir, model_id_without_author) -def get_current_org_id() -> str | None: +async def get_current_org_id() -> str | None: """ Resolve the current organization id from workspace path when multitenant is enabled. Returns None if multitenancy is disabled or org id cannot be determined. @@ -70,7 +71,7 @@ def get_current_org_id() -> str | None: try: from lab.dirs import get_workspace_dir - ws = get_workspace_dir() + ws = await get_workspace_dir() if "/orgs/" in ws: return ws.split("/orgs/")[-1].split("/")[0] except Exception: @@ -78,12 +79,12 @@ def get_current_org_id() -> str | None: return None -def get_model_details_from_gallery(model_id: str): +async def get_model_details_from_gallery(model_id: str): """ Given a model ID this returns the associated data from the model gallery file. Returns None if no such value found. """ - gallery = galleries.get_models_gallery() + gallery = await galleries.get_models_gallery() result = None @@ -97,7 +98,7 @@ def get_model_details_from_gallery(model_id: str): @router.get("/model/gallery") async def model_gallery_list_all(): - gallery = galleries.get_models_gallery() + gallery = await galleries.get_models_gallery() # Get a list of local models to determine what has been downloaded already local_models = await model_helper.list_installed_models() @@ -129,7 +130,7 @@ async def model_gallery_list_all(): @router.get("/model/model_groups_list", summary="Returns the grouped model gallery from model-group-gallery.json.") async def model_groups_list_all(): - gallery = galleries.get_model_groups_gallery() + gallery = await galleries.get_model_groups_gallery() # Get list of locally installed models local_models = await model_helper.list_installed_models() @@ -212,7 +213,7 @@ async def model_gallery(model_id: str): # convert "~~~"" in string to "/": model_id = model_id.replace("~~~", "/") - return get_model_details_from_gallery(model_id) + return await get_model_details_from_gallery(model_id) # Should this be a POST request? @@ -225,7 +226,7 @@ async def upload_model_to_huggingface( """ Given a model ID, upload it to Hugging Face. """ - model_directory = get_model_dir(model_id) + model_directory = await get_model_dir(model_id) api = HfApi() try: # Using HF API to check user info and use it for the model creation @@ -297,14 +298,14 @@ async def model_details_from_filesystem(model_id: str): # TODO: Refactor this code with models/list function # see if the model exists locally - model_path = get_model_dir(model_id) - if storage.isdir(model_path): + model_path = await get_model_dir(model_id) + if await storage.isdir(model_path): # Look for model information using SDK methods try: from lab.model import Model as ModelService model_service = ModelService(model_id) - filedata = model_service.get_metadata() + filedata = await model_service.get_metadata() # Some models are a single file (possibly of many in a directory, e.g. 
GGUF) # For models that have model_filename set we should link directly to that specific file @@ -523,7 +524,7 @@ async def download_huggingface_model( - message: error message if status is "error" """ if job_id is None: - job_id = job_service.job_create( + job_id = await job_service.job_create( type="DOWNLOAD_MODEL", status="STARTED", experiment_id=experiment_id, job_data="{}" ) else: @@ -559,7 +560,7 @@ async def download_huggingface_model( workspace_dir = storage.join(HOME_DIR, "orgs", organization_id, "workspace") else: # Use default workspace path - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() args += ["--workspace_dir", workspace_dir] except Exception as e: @@ -590,7 +591,7 @@ async def download_huggingface_model( if exitcode == 77: # This means we got a GatedRepoError # The user needs to agree to terms on HuggingFace to download - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) error_msg = None if job and job.get("job_data"): error_msg = job["job_data"].get("error_msg") @@ -598,7 +599,7 @@ async def download_huggingface_model( return {"status": "unauthorized", "message": error_msg} elif exitcode != 0: - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) error_msg = None if job and job.get("job_data"): error_msg = job["job_data"].get("error_msg") @@ -624,12 +625,12 @@ async def download_huggingface_model( if hugging_face_filename is None: # only save to local filesystem if we are downloading the whole repo try: - model_service = ModelService.create(hugging_face_id) - model_service.set_metadata(model_id=hugging_face_id, name=name, json_data=model_details) + model_service = await ModelService.create(hugging_face_id) + await model_service.set_metadata(model_id=hugging_face_id, name=name, json_data=model_details) except FileExistsError: # Model already exists, update it - model_service = ModelService.get(hugging_face_id) - model_service.set_metadata(model_id=hugging_face_id, name=name, json_data=model_details) + model_service = await ModelService.get(hugging_face_id) + await model_service.set_metadata(model_id=hugging_face_id, name=name, json_data=model_details) return {"status": "success", "message": "success", "model": model_details, "job_id": job_id} @@ -708,7 +709,7 @@ async def download_model_by_huggingface_id(model: str, job_id: int | None = None if is_sd: model_details["allow_patterns"] = sd_patterns - org_id = get_current_org_id() + org_id = await get_current_org_id() print("🔵 CURRENT ORG ID: ", org_id) return await download_huggingface_model(model, model_details, job_id, experiment_id, org_id, user_id=None) @@ -754,7 +755,7 @@ async def download_gguf_file_from_repo(model: str, filename: str, job_id: int | except Exception: pass # Use existing size if we can't get specific file size - org_id = get_current_org_id() + org_id = await get_current_org_id() return await download_huggingface_model(model, model_details, job_id, experiment_id, org_id, user_id=None) @@ -768,7 +769,7 @@ async def download_model_from_gallery(gallery_id: str, job_id: int | None = None # Get model details from the gallery # If None then return an error - gallery_entry = get_model_details_from_gallery(gallery_id) + gallery_entry = await get_model_details_from_gallery(gallery_id) if gallery_entry is None: return {"status": "error", "message": "Model not found in gallery"} @@ -779,8 +780,8 @@ async def download_model_from_gallery(gallery_id: str, job_id: int | None = None if "pipeline_tag" not in gallery_entry: # First try to 
get from filesystem try: - model_service = ModelService.get(huggingface_id) - model_data = model_service.get_metadata() + model_service = await ModelService.get(huggingface_id) + model_data = await model_service.get_metadata() if model_data and model_data.get("json_data") and "pipeline_tag" in model_data["json_data"]: gallery_entry["pipeline_tag"] = model_data["json_data"]["pipeline_tag"] else: @@ -804,8 +805,7 @@ async def download_model_from_gallery(gallery_id: str, job_id: int | None = None print(f"Error fetching pipeline tag for {huggingface_id}: {type(e).__name__}: {e}") gallery_entry["pipeline_tag"] = "text-generation" - org_id = get_current_org_id() - print("🔵 CURRENT ORG ID: ", org_id) + org_id = await get_current_org_id() return await download_huggingface_model(huggingface_id, gallery_entry, job_id, experiment_id, org_id) @@ -836,7 +836,7 @@ async def model_provenance(model_id: str): async def model_count_downloaded(): # Currently used to determine if user has any downloaded models # Use filesystem instead of database - models = ModelService.list_all() + models = await ModelService.list_all() count = len(models) return {"status": "success", "data": count} @@ -845,8 +845,8 @@ async def model_count_downloaded(): async def model_local_create(id: str, name: str, json_data={}): # Use filesystem instead of database try: - model_service = ModelService.create(id) - model_service.set_metadata(model_id=id, name=name, json_data=json_data) + model_service = await ModelService.create(id) + await model_service.set_metadata(model_id=id, name=name, json_data=json_data) return {"message": "model created"} except FileExistsError: return {"status": "error", "message": f"Model {id} already exists"} @@ -859,11 +859,11 @@ async def model_local_create(id: str, name: str, json_data={}): async def model_local_delete(model_id: str, delete_from_cache: bool = False): # Try to delete from filesystem first using SDK try: - model_service = ModelService.get(model_id) + model_service = await ModelService.get(model_id) # Delete the entire directory - model_dir = model_service.get_dir() - if storage.exists(model_dir): - storage.rm_tree(model_dir) + model_dir = await model_service.get_dir() + if await storage.exists(model_dir): + await storage.rm_tree(model_dir) print(f"Deleted filesystem model: {model_id}") except FileNotFoundError: # Model not found in filesystem, continue with other deletion methods @@ -874,7 +874,7 @@ async def model_local_delete(model_id: str, delete_from_cache: bool = False): # Also try the legacy method for backward compatibility from lab.dirs import get_models_dir - root_models_dir = get_models_dir() + root_models_dir = await get_models_dir() # Sanitize and validate model_dir unsafe_model_dir = model_id.rsplit("/", 1)[-1] @@ -883,16 +883,16 @@ async def model_local_delete(model_id: str, delete_from_cache: bool = False): candidate_index_file = storage.join(root_models_dir, model_dir, "index.json") # For fsspec, validate paths are within root_models_dir by checking they start with it - if not storage.exists(candidate_index_file): + if not await storage.exists(candidate_index_file): pass # File doesn't exist, skip legacy deletion elif not candidate_index_file.startswith(root_models_dir): print("ERROR: Invalid index file path") - elif storage.isfile(candidate_index_file): + elif await storage.isfile(candidate_index_file): model_path = storage.join(root_models_dir, model_dir) if not model_path.startswith(root_models_dir): print("ERROR: Invalid directory structure") print(f"Deleteing 
{model_path}") - storage.rm_tree(model_path) + await storage.rm_tree(model_path) else: if delete_from_cache: @@ -917,21 +917,21 @@ async def model_local_delete(model_id: str, delete_from_cache: bool = False): async def model_gets_pefts( model_id: Annotated[str, Body()], ): - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() model_id = secure_filename(model_id) adaptors_dir = storage.join(workspace_dir, "adaptors", model_id) - if not storage.exists(adaptors_dir): + if not await storage.exists(adaptors_dir): return [] # Use storage.ls to list directory contents try: - all_items = storage.ls(adaptors_dir, detail=False) + all_items = await storage.ls(adaptors_dir, detail=False) adaptors = [] for item_path in all_items: # Extract just the name from full path (works for both local and remote) name = item_path.split("/")[-1].split("\\")[-1] # Handle both / and \ separators - if not name.startswith(".") and storage.isdir(item_path): + if not name.startswith(".") and await storage.isdir(item_path): adaptors.append(name) except Exception: # Fallback to empty list if listing fails @@ -941,17 +941,17 @@ async def model_gets_pefts( @router.get("/model/delete_peft") async def model_delete_peft(model_id: str, peft: str): - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() secure_model_id = secure_filename(model_id) adaptors_dir = storage.join(workspace_dir, "adaptors", secure_model_id) # Check if the peft exists - if storage.exists(adaptors_dir): + if await storage.exists(adaptors_dir): peft_path = storage.join(adaptors_dir, peft) else: # Assume the adapter is stored in the older naming convention format peft_path = storage.join(workspace_dir, "adaptors", model_id, peft) - storage.rm_tree(peft_path) + await storage.rm_tree(peft_path) return {"message": "success"} @@ -1044,7 +1044,7 @@ def compare_field(a_cfg, b_cfg, key, fallback_keys=None): print(f"Model Details: {model_details}") # Create or update job if job_id is None: - job_id = job_service.job_create( + job_id = await job_service.job_create( type="DOWNLOAD_MODEL", status="STARTED", experiment_id=experiment_id, job_data="{}" ) else: @@ -1074,7 +1074,7 @@ def compare_field(a_cfg, b_cfg, key, fallback_keys=None): try: from lab.dirs import get_workspace_dir - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() args += ["--workspace_dir", workspace_dir] except Exception: pass @@ -1110,8 +1110,8 @@ async def get_local_hfconfig(model_id: str): async def get_model_from_db(model_id: str): # Get model from filesystem - model_service = ModelService.get(model_id) - return model_service.get_metadata() + model_service = await ModelService.get(model_id) + return await model_service.get_metadata() @router.get("/model/list_local_uninstalled") @@ -1197,7 +1197,7 @@ async def model_import_local_path(model_path: str): """ # Restrict to workspace directory only - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() # Normalize both workspace and input paths abs_workspace_dir = os.path.abspath(os.path.normpath(workspace_dir)) abs_model_path = os.path.abspath(os.path.normpath(model_path)) @@ -1279,8 +1279,8 @@ async def get_pipeline_tag(model_name: str): """ # First try to get from filesystem try: - model_service = ModelService.get(model_name) - model_data = model_service.get_metadata() + model_service = await ModelService.get(model_name) + model_data = await model_service.get_metadata() if model_data and model_data.get("json_data") and "pipeline_tag" in 
model_data["json_data"]: pipeline_tag = model_data["json_data"]["pipeline_tag"] return {"status": "success", "data": pipeline_tag, "model_id": model_name} diff --git a/api/transformerlab/routers/plugins.py b/api/transformerlab/routers/plugins.py index 4581f4b45..67d5e6321 100644 --- a/api/transformerlab/routers/plugins.py +++ b/api/transformerlab/routers/plugins.py @@ -61,7 +61,7 @@ async def plugin_gallery(): # Now get a list of the plugins that are already installed: from lab.dirs import get_plugin_dir - local_workspace_gallery_directory = get_plugin_dir() + local_workspace_gallery_directory = await get_plugin_dir() installed_plugins = [] if os.path.exists(local_workspace_gallery_directory): for lp in os.listdir(local_workspace_gallery_directory): @@ -165,14 +165,14 @@ async def copy_plugin_files_to_workspace(plugin_id: str): # create the directory if it doesn't exist from lab.dirs import get_plugin_dir - new_directory = os.path.join(get_plugin_dir(), plugin_id) + new_directory = os.path.join(await get_plugin_dir(), plugin_id) if not os.path.exists(plugin_path): print(f"Plugin {plugin_path} not found in gallery.") return if not os.path.exists(new_directory): os.makedirs(new_directory) # Now copy it to the workspace: - copy_tree(plugin_path, lab_dirs.plugin_dir_by_name(plugin_id)) + copy_tree(plugin_path, await lab_dirs.plugin_dir_by_name(plugin_id)) async def delete_plugin_files_from_workspace(plugin_id: str): @@ -180,7 +180,7 @@ async def delete_plugin_files_from_workspace(plugin_id: str): from lab.dirs import get_plugin_dir - plugin_path = os.path.join(get_plugin_dir(), plugin_id) + plugin_path = os.path.join(await get_plugin_dir(), plugin_id) # return if the directory doesn't exist if not os.path.exists(plugin_path): print(f"Plugin {plugin_path} not found in workspace.") @@ -193,7 +193,7 @@ async def run_installer_for_plugin(plugin_id: str, log_file): plugin_id = secure_filename(plugin_id) from lab.dirs import get_plugin_dir - new_directory = os.path.join(get_plugin_dir(), plugin_id) + new_directory = os.path.join(await get_plugin_dir(), plugin_id) venv_path = os.path.join(new_directory, "venv") plugin_path = os.path.join(dirs.PLUGIN_PRELOADED_GALLERY, plugin_id) @@ -278,7 +278,7 @@ async def run_installer_for_plugin(plugin_id: str, log_file): @router.get(path="/delete_plugin") async def delete_plugin(plugin_name: str): - final_path = lab_dirs.plugin_dir_by_name(plugin_name) + final_path = await lab_dirs.plugin_dir_by_name(plugin_name) remove_tree(final_path) return {"message": f"Plugin {plugin_name} deleted successfully."} @@ -300,12 +300,12 @@ async def install_plugin(plugin_id: str): await copy_plugin_files_to_workspace(plugin_id) - new_directory = os.path.join(lab_dirs.get_plugin_dir(), plugin_id) + new_directory = os.path.join(await lab_dirs.get_plugin_dir(), plugin_id) venv_path = os.path.join(new_directory, "venv") from lab.dirs import get_global_log_path - global_log_file_name = get_global_log_path() + global_log_file_name = await get_global_log_path() async with aiofiles.open(global_log_file_name, "a") as log_file: # Create virtual environment using uv print("Creating virtual environment for plugin...") @@ -419,7 +419,7 @@ async def install_plugin(plugin_id: str): async def run_installer_script(plugin_id: str): from lab.dirs import get_global_log_path - global_log_file_name = get_global_log_path() + global_log_file_name = await get_global_log_path() async with aiofiles.open(global_log_file_name, "a") as log_file: return await run_installer_for_plugin(plugin_id, log_file) 
return {"status": "error", "message": f"Failed to open log file: {global_log_file_name}"} @@ -431,7 +431,7 @@ async def list_plugins() -> list[object]: from lab.dirs import get_plugin_dir - local_workspace_gallery_directory = get_plugin_dir() + local_workspace_gallery_directory = await get_plugin_dir() # now get the local workspace gallery workspace_gallery = [] diff --git a/api/transformerlab/routers/prompts.py b/api/transformerlab/routers/prompts.py index 132df57fb..cf3fdef93 100644 --- a/api/transformerlab/routers/prompts.py +++ b/api/transformerlab/routers/prompts.py @@ -20,7 +20,7 @@ async def list_prompts(): prompt_gallery = json.load(f) prompt_templates = [] - prompts_dir = lab_dirs.get_prompt_templates_dir() + prompts_dir = await lab_dirs.get_prompt_templates_dir() for file in os.listdir(prompts_dir): if file.endswith(".json"): with open(os.path.join(prompts_dir, file), "r") as f: @@ -42,7 +42,7 @@ async def new_prompt(title: Annotated[str, Body()], text: Annotated[str, Body()] return {"status": "error", "message": "The text must include the placeholder {text}"} slug = slugify(title) - prompts_dir = lab_dirs.get_prompt_templates_dir() + prompts_dir = await lab_dirs.get_prompt_templates_dir() prompt_file = os.path.join(prompts_dir, f"{slug}.json") @@ -60,7 +60,7 @@ async def new_prompt(title: Annotated[str, Body()], text: Annotated[str, Body()] async def delete_prompt(prompt_id: str): """Delete a prompt template""" - prompts_dir = lab_dirs.get_prompt_templates_dir() + prompts_dir = await lab_dirs.get_prompt_templates_dir() prompt_file = os.path.join(prompts_dir, f"{prompt_id}.json") if os.path.exists(prompt_file): diff --git a/api/transformerlab/routers/recipes.py b/api/transformerlab/routers/recipes.py index cc6e72150..36537ee60 100644 --- a/api/transformerlab/routers/recipes.py +++ b/api/transformerlab/routers/recipes.py @@ -15,14 +15,14 @@ @router.get("/list") async def list_recipes(): """List all recipes for a given experiment name.""" - recipes_gallery = galleries.get_exp_recipe_gallery() + recipes_gallery = await galleries.get_exp_recipe_gallery() return recipes_gallery @router.get("/{id}") async def get_recipe_by_id(id: str): """Fetch a recipe by its ID from the experiment recipe gallery.""" - recipes_gallery = galleries.get_exp_recipe_gallery() + recipes_gallery = await galleries.get_exp_recipe_gallery() for recipe in recipes_gallery: if recipe.get("id") == id: return recipe @@ -33,7 +33,7 @@ async def get_recipe_by_id(id: str): async def check_recipe_dependencies(id: str): """Check if the dependencies for a recipe are installed for a given environment.""" # Get the recipe - recipes_gallery = galleries.get_exp_recipe_gallery() + recipes_gallery = await galleries.get_exp_recipe_gallery() recipe = next((r for r in recipes_gallery if r.get("id") == id), None) if not recipe: return {"error": f"Recipe with id {id} not found."} @@ -44,7 +44,7 @@ async def check_recipe_dependencies(id: str): # Get local models and datasets local_models = await model_helper.list_installed_models() local_model_names = set(model["model_id"] for model in local_models) - local_datasets = Dataset.list_all() + local_datasets = await Dataset.list_all() local_dataset_ids = set(ds["dataset_id"] for ds in local_datasets) # Get installed plugins using the same logic as /plugins/gallery @@ -77,10 +77,10 @@ async def _install_recipe_dependencies_job(job_id, id): from transformerlab.routers import plugins as plugins_router try: - job = job_service.job_get(job_id) - experiment_id = job["experiment_id"] + job = 
await job_service.job_get(job_id) + experiment_id = job.get("experiment_id") if job else None await job_update_status(job_id, "RUNNING", experiment_id=experiment_id) - recipes_gallery = galleries.get_exp_recipe_gallery() + recipes_gallery = await galleries.get_exp_recipe_gallery() recipe = next((r for r in recipes_gallery if r.get("id") == id), None) if not recipe: await job_update_status( @@ -95,7 +95,7 @@ async def _install_recipe_dependencies_job(job_id, id): await job_update_status(job_id, "COMPLETE", experiment_id=experiment_id) return - local_datasets = Dataset.list_all() + local_datasets = await Dataset.list_all() local_dataset_ids = set(ds["dataset_id"] for ds in local_datasets) total = len(non_model_deps) progress = 0 @@ -126,8 +126,8 @@ async def _install_recipe_dependencies_job(job_id, id): result["status"] = str(e) results.append(result) progress += 1 - job_service.job_update_progress(job_id, int(progress * 100 / total), experiment_id) - job_service.job_update_job_data_insert_key_value(job_id, "results", results, experiment_id) + await job_service.job_update_progress(job_id, int(progress * 100 / total), experiment_id) + await job_service.job_update_job_data_insert_key_value(job_id, "results", results, experiment_id) await job_update_status(job_id, "COMPLETE", experiment_id=experiment_id) except Exception as e: await job_update_status(job_id, "FAILED", experiment_id=experiment_id, error_msg=str(e)) @@ -140,7 +140,7 @@ async def install_recipe_model_dependencies(id: str): import asyncio # Get the recipe - recipes_gallery = galleries.get_exp_recipe_gallery() + recipes_gallery = await galleries.get_exp_recipe_gallery() recipe = next((r for r in recipes_gallery if r.get("id") == id), None) if not recipe: return {"error": f"Recipe with id {id} not found."} @@ -155,7 +155,7 @@ async def install_recipe_model_dependencies(id: str): dep_name = dep.get("name") if dep_name not in local_model_names: # Create a DOWNLOAD_MODEL job for this model - job_id = job_service.job_create( + job_id = await job_service.job_create( type="DOWNLOAD_MODEL", status="QUEUED", job_data=json.dumps({"model_id": dep_name}), @@ -194,7 +194,7 @@ async def install_recipe_dependencies(id: str, background_tasks: BackgroundTasks return model_result # Install other dependencies as a background job - job_id = job_service.job_create( + job_id = await job_service.job_create( type="INSTALL_RECIPE_DEPS", status="QUEUED", job_data=json.dumps({"recipe_id": id, "results": [], "progress": 0}), @@ -232,7 +232,7 @@ async def install_recipe_dependencies(id: str, background_tasks: BackgroundTasks @router.get("/jobs/{job_id}/status") async def get_install_job_status(job_id: int): """Get the status and progress of a dependency installation job.""" - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) if not job: return {"error": f"Job {job_id} not found."} return { @@ -250,14 +250,14 @@ async def create_experiment_for_recipe(id: str, experiment_name: str): from transformerlab.routers.experiment import experiment as experiment_router # Check if experiment already exists - existing = experiment_service.experiment_get(experiment_name) + existing = await experiment_service.experiment_get(experiment_name) if existing: return {"status": "error", "message": f"Experiment '{experiment_name}' already exists.", "data": {}} # Create experiment with blank config - experiment_id = experiment_service.experiment_create(name=experiment_name, config={}) + experiment_id = await 
experiment_service.experiment_create(name=experiment_name, config={}) # Get the recipe - recipes_gallery = galleries.get_exp_recipe_gallery() + recipes_gallery = await galleries.get_exp_recipe_gallery() recipe = next((r for r in recipes_gallery if r.get("id") == id), None) if not recipe: return {"status": "error", "message": f"Recipe with id {id} not found.", "data": {}} diff --git a/api/transformerlab/routers/serverinfo.py b/api/transformerlab/routers/serverinfo.py index 231592f88..9d8a8f1c4 100644 --- a/api/transformerlab/routers/serverinfo.py +++ b/api/transformerlab/routers/serverinfo.py @@ -321,15 +321,15 @@ async def watch_s3_file( print(f"👀 Watching S3 file: {filename}") # create the file if it doesn't already exist: - if not storage.exists(filename): - with storage.open(filename, "w") as f: - f.write("") + if not await storage.exists(filename): + async with await storage.open(filename, "w") as f: + await f.write("") last_content = "" if start_from_beginning: try: - with storage.open(filename, "r") as f: - last_content = f.read() + async with await storage.open(filename, "r") as f: + last_content = await f.read() if last_content: lines = last_content.splitlines(keepends=True) yield (f"data: {json.dumps(lines)}\n\n") @@ -339,8 +339,8 @@ async def watch_s3_file( else: # Start from current end of file try: - with storage.open(filename, "r") as f: - last_content = f.read() + async with await storage.open(filename, "r") as f: + last_content = await f.read() except Exception as e: print(f"Error reading S3 file: {e}") last_content = "" @@ -349,8 +349,8 @@ async def watch_s3_file( while True: await asyncio.sleep(poll_interval_ms / 1000.0) try: - with storage.open(filename, "r") as f: - current_content = f.read() + async with await storage.open(filename, "r") as f: + current_content = await f.read() # Check if file has grown if len(current_content) > len(last_content): @@ -375,46 +375,50 @@ async def watch_file(filename: str, start_from_beginning=False, force_polling=Tr print(f"👀 Watching file: {filename}") # create the file if it doesn't already exist: - if not storage.exists(filename): - with storage.open(filename, "w") as f: - f.write("") + if not await storage.exists(filename): + async with await storage.open(filename, "w") as f: + await f.write("") last_position = 0 if start_from_beginning: last_position = 0 - with storage.open(filename, "r") as f: - f.seek(last_position) - new_lines = f.readlines() + async with await storage.open(filename, "r") as f: + await f.seek(last_position) + new_lines = await f.readlines() yield (f"data: {json.dumps(new_lines)}\n\n") - last_position = f.tell() + last_position = await f.tell() else: try: - with storage.open(filename, "r") as f: - f.seek(0, os.SEEK_END) - last_position = f.tell() + async with await storage.open(filename, "r") as f: + await f.seek(0, os.SEEK_END) + last_position = await f.tell() except Exception as e: print(f"Error seeking to end of file: {e}") async for changes in awatch(filename, force_polling=force_polling, poll_delay_ms=100): - with storage.open(filename, "r") as f: - f.seek(last_position) - new_lines = f.readlines() + async with await storage.open(filename, "r") as f: + await f.seek(last_position) + new_lines = await f.readlines() yield (f"data: {json.dumps(new_lines)}\n\n") - last_position = f.tell() + last_position = await f.tell() @router.get("/stream_log") async def watch_log(): - global_log_path = get_global_log_path() + global_log_path = await get_global_log_path() - if not storage.exists(global_log_path): - # Create the 
file - with storage.open(global_log_path, "w") as f: - f.write("") - try: - # Check if the path is an S3 or other remote filesystem path - is_remote_path = global_log_path.startswith(("s3://", "gs://", "abfs://", "gcs://")) + # Check if the path is an S3 or other remote filesystem path + is_remote_path = global_log_path.startswith(("s3://", "gs://", "abfs://", "gcs://")) + if not await storage.exists(global_log_path): + # Create the file using appropriate method + if is_remote_path: + async with await storage.open(global_log_path, "w") as f: + await f.write("") + else: + async with await storage.open(global_log_path, "w") as f: + await f.write("") + try: if is_remote_path: # Use S3 polling watcher for remote filesystems return StreamingResponse( @@ -453,13 +457,13 @@ async def download_logs(): local_server_log_path = storage.join(HOME_DIR, "local_server.log") # Path to transformerlab.log in workspace_dir - transformerlab_log_path = get_global_log_path() + transformerlab_log_path = await get_global_log_path() # Check which files exist and add them to the list - if storage.exists(local_server_log_path): + if await storage.exists(local_server_log_path): log_files.append(("local_server.log", local_server_log_path)) - if storage.exists(transformerlab_log_path): + if await storage.exists(transformerlab_log_path): log_files.append(("transformerlab.log", transformerlab_log_path)) # If no files exist, return an error @@ -475,8 +479,8 @@ async def download_logs(): for filename, file_path in log_files: try: # Read file content (works with both local and remote storage) - with storage.open(file_path, "rb") as log_file: - content = log_file.read() + async with await storage.open(file_path, "rb") as log_file: + content = await log_file.read() zipf.writestr(filename, content) except Exception as e: # If we can't read a file, log it but continue with others diff --git a/api/transformerlab/routers/tasks.py b/api/transformerlab/routers/tasks.py index ced3d51db..608fd9f69 100644 --- a/api/transformerlab/routers/tasks.py +++ b/api/transformerlab/routers/tasks.py @@ -22,13 +22,13 @@ class DeleteTeamTaskFromGalleryRequest(BaseModel): @router.get("/list", summary="Returns all the tasks") async def tasks_get_all(): - tasks = tasks_service.tasks_get_all() + tasks = await tasks_service.tasks_get_all() return tasks @router.get("/{task_id}/get", summary="Gets all the data for a single task") async def tasks_get_by_id(task_id: str): - task = tasks_service.tasks_get_by_id(task_id) + task = await tasks_service.tasks_get_by_id(task_id) if task is None: return {"message": "NOT FOUND"} return task @@ -36,7 +36,7 @@ async def tasks_get_by_id(task_id: str): @router.get("/list_by_type", summary="Returns all the tasks of a certain type, e.g TRAIN") async def tasks_get_by_type(type: str): - tasks = tasks_service.tasks_get_by_type(type) + tasks = await tasks_service.tasks_get_by_type(type) return tasks @@ -44,7 +44,7 @@ async def tasks_get_by_type(type: str): "/list_by_type_in_experiment", summary="Returns all the tasks of a certain type in a certain experiment, e.g TRAIN" ) async def tasks_get_by_type_in_experiment(type: str, experiment_id: str): - tasks = tasks_service.tasks_get_by_type_in_experiment(type, experiment_id) + tasks = await tasks_service.tasks_get_by_type_in_experiment(type, experiment_id) return tasks @@ -57,7 +57,7 @@ async def tasks_get_by_subtype_in_experiment( subtype: str, remote_task: Optional[bool] = None, ): - tasks = tasks_service.tasks_get_by_experiment(experiment_id) + tasks = await 
tasks_service.tasks_get_by_experiment(experiment_id) filtered = [] for t in tasks: cfg = t.get("config", {}) @@ -79,7 +79,7 @@ async def update_task(task_id: str, new_task: dict = Body()): # Perform secure_filename before updating the task if "name" in new_task: new_task["name"] = secure_filename(new_task["name"]) - success = tasks_service.update_task(task_id, new_task) + success = await tasks_service.update_task(task_id, new_task) if success: return {"message": "OK"} else: @@ -88,7 +88,7 @@ async def update_task(task_id: str, new_task: dict = Body()): @router.get("/{task_id}/delete", summary="Deletes a task") async def delete_task(task_id: str): - success = tasks_service.delete_task(task_id) + success = await tasks_service.delete_task(task_id) if success: return {"message": "OK"} else: @@ -106,7 +106,7 @@ async def add_task(new_task: dict = Body()): except Exception: remote_task_flag = False - tasks_service.add_task( + await tasks_service.add_task( new_task["name"], new_task["type"], new_task["inputs"], @@ -137,7 +137,7 @@ async def add_task(new_task: dict = Body()): # Repeat for dataset dataset_downloaded = False - local_datasets = Dataset.list_all() + local_datasets = await Dataset.list_all() for dataset in local_datasets: if dataset["dataset_id"] == datasets: dataset_downloaded = True @@ -164,7 +164,7 @@ async def add_task(new_task: dict = Body()): @router.get("/delete_all", summary="Wipe the task table") async def tasks_delete_all(): - tasks_service.tasks_delete_all() + await tasks_service.tasks_delete_all() return {"message": "OK"} @@ -175,7 +175,7 @@ async def queue_task( output_override: str = "{}", user_and_team=Depends(get_user_and_team), ): - task_to_queue = tasks_service.tasks_get_by_id(task_id) + task_to_queue = await tasks_service.tasks_get_by_id(task_id) if task_to_queue is None: return {"message": "TASK NOT FOUND"} @@ -275,7 +275,7 @@ async def queue_task( if user: job_data["user_id"] = str(user.id) - job_id = job_create( + job_id = await job_create( type=job_type, status=job_status, experiment_id=task_to_queue["experiment_id"], @@ -287,7 +287,7 @@ async def queue_task( @router.get("/gallery", summary="List all tasks from the tasks gallery") async def tasks_gallery(): """Get the tasks gallery from the JSON file""" - gallery = galleries.get_tasks_gallery() + gallery = await galleries.get_tasks_gallery() return {"status": "success", "data": gallery} @@ -327,7 +327,7 @@ async def import_task_from_gallery( Creates a new task using the gallery entry's config and GitHub info. Uses the team's GitHub PAT if available. """ - gallery = galleries.get_tasks_gallery() + gallery = await galleries.get_tasks_gallery() # Find the gallery entry by index or ID try: @@ -409,7 +409,7 @@ async def import_task_from_gallery( # Perform secure_filename before adding the task new_task["name"] = secure_filename(new_task["name"]) - tasks_service.add_task( + await tasks_service.add_task( new_task["name"], new_task["type"], new_task["inputs"], @@ -426,7 +426,7 @@ async def import_task_from_gallery( @router.get("/gallery/team", summary="List team-specific tasks from the team gallery") async def team_tasks_gallery(): """Get the team-specific tasks gallery stored in workspace_dir""" - gallery = galleries.get_team_tasks_gallery() + gallery = await galleries.get_team_tasks_gallery() return {"status": "success", "data": gallery} @@ -438,7 +438,7 @@ async def import_task_from_team_gallery( """ Import a task from the team-specific tasks gallery (workspace_dir/team_specific_tasks.json). 
""" - gallery = galleries.get_team_tasks_gallery() + gallery = await galleries.get_team_tasks_gallery() # Find the gallery entry by index or ID try: @@ -517,7 +517,7 @@ async def import_task_from_team_gallery( # Perform secure_filename before adding the task new_task["name"] = secure_filename(new_task["name"]) - tasks_service.add_task( + await tasks_service.add_task( new_task["name"], new_task["type"], new_task["inputs"], @@ -539,7 +539,7 @@ async def export_task_to_team_gallery( """ Export a task template into the team-specific gallery stored in workspace_dir. """ - task = tasks_service.tasks_get_by_id(request.task_id) + task = await tasks_service.tasks_get_by_id(request.task_id) if not task: raise HTTPException(status_code=404, detail="Task not found") @@ -559,7 +559,7 @@ async def export_task_to_team_gallery( "github_repo_dir": config.get("github_directory") or config.get("github_repo_dir"), } - galleries.add_team_task_to_gallery(gallery_entry) + await galleries.add_team_task_to_gallery(gallery_entry) return { "status": "success", @@ -606,7 +606,7 @@ async def add_team_task_to_gallery( "github_repo_dir": request.github_repo_dir, } - galleries.add_team_task_to_gallery(gallery_entry) + await galleries.add_team_task_to_gallery(gallery_entry) return { "status": "success", @@ -623,7 +623,7 @@ async def delete_team_task_from_gallery( """ Delete a task from the team-specific gallery stored in workspace_dir. """ - success = galleries.delete_team_task_from_gallery(request.task_id) + success = await galleries.delete_team_task_from_gallery(request.task_id) if success: return { "status": "success", diff --git a/api/transformerlab/routers/teams.py b/api/transformerlab/routers/teams.py index 1f4ac943d..b202dd825 100644 --- a/api/transformerlab/routers/teams.py +++ b/api/transformerlab/routers/teams.py @@ -109,12 +109,12 @@ async def create_team( set_organization_id(team.id) # Create the default experiment - Experiment("alpha", create_new=True) + await Experiment.create_or_get("alpha", create_new=True) # Save logo if provided if logo: try: - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() logo_path = storage.join(workspace_dir, "logo.png") # Validate content type @@ -174,7 +174,7 @@ async def create_team( image = image.convert("RGB") # Save as PNG - with storage.open(logo_path, "wb") as f: + async with await storage.open(logo_path, "wb") as f: image.save(f, format="PNG") except HTTPException: # Re-raise HTTPExceptions (validation errors) @@ -829,21 +829,22 @@ async def get_github_pat( if team_id != user_and_team["team_id"]: raise HTTPException(status_code=400, detail="Team ID mismatch") - pat_path = storage.join(get_workspace_dir(), "github_pat.txt") + workspace_dir = await get_workspace_dir() + pat_path = storage.join(workspace_dir, "github_pat.txt") - try: - with storage.open(pat_path, "r") as f: - raw_pat = f.read().rstrip("\n") - except FileNotFoundError: - return {"status": "success", "pat_exists": False} - except Exception: - return {"status": "error", "message": "Failed to read GitHub PAT"} + if await storage.exists(pat_path): + try: + async with await storage.open(pat_path, "r") as f: + pat = (await f.read()).strip() + if pat: + # Return masked version for security (only show last 4 chars) + masked_pat = mask_key(pat) + return {"status": "success", "pat_exists": True, "masked_pat": masked_pat} + except Exception as e: + print(f"Error reading GitHub PAT: {e}") + return {"status": "error", "message": "Failed to read GitHub PAT"} - return { - "status": "success", - 
"pat_exists": True, - "masked_pat": mask_key(raw_pat), - } + return {"status": "error", "message": "GitHub PAT not found"} @router.put("/teams/{team_id}/github_pat") @@ -861,20 +862,20 @@ async def set_github_pat( if team_id != owner_info["team_id"]: raise HTTPException(status_code=400, detail="Team ID mismatch") - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() pat_path = storage.join(workspace_dir, "github_pat.txt") try: pat = pat_data.pat if pat and pat.strip(): # Store the PAT - with storage.open(pat_path, "w") as f: - f.write(pat.strip()) + async with await storage.open(pat_path, "w") as f: + await f.write(pat.strip()) return {"status": "success", "message": "GitHub PAT saved successfully"} else: # Remove the PAT if empty string is provided - if storage.exists(pat_path): - storage.rm(pat_path) + if await storage.exists(pat_path): + await storage.rm(pat_path) return {"status": "success", "message": "GitHub PAT removed successfully"} except Exception as e: print(f"Error saving GitHub PAT: {e}") @@ -894,10 +895,10 @@ async def get_team_logo( if team_id != user_and_team["team_id"]: raise HTTPException(status_code=400, detail="Team ID mismatch") - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() logo_path = storage.join(workspace_dir, "logo.png") - if not storage.exists(logo_path): + if not await storage.exists(logo_path): raise HTTPException(status_code=404, detail="Team logo not found") try: @@ -906,8 +907,8 @@ async def get_team_logo( return FileResponse(logo_path, media_type="image/png") else: # For remote storage, read and return as bytes - with storage.open(logo_path, "rb") as f: - return Response(content=f.read(), media_type="image/png") + async with await storage.open(logo_path, "rb") as f: + return Response(content=await f.read(), media_type="image/png") except Exception as e: print(f"Error reading team logo: {e}") raise HTTPException(status_code=500, detail="Failed to read team logo") @@ -927,7 +928,7 @@ async def set_team_logo( if team_id != owner_info["team_id"]: raise HTTPException(status_code=400, detail="Team ID mismatch") - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() logo_path = storage.join(workspace_dir, "logo.png") try: @@ -987,7 +988,7 @@ async def set_team_logo( image = image.convert("RGB") # Save as PNG - with storage.open(logo_path, "wb") as f: + async with await storage.open(logo_path, "wb") as f: image.save(f, format="PNG") return {"status": "success", "message": "Team logo saved successfully"} @@ -1008,12 +1009,12 @@ async def delete_team_logo( if team_id != owner_info["team_id"]: raise HTTPException(status_code=400, detail="Team ID mismatch") - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() logo_path = storage.join(workspace_dir, "logo.png") try: - if storage.exists(logo_path): - storage.rm(logo_path) + if await storage.exists(logo_path): + await storage.rm(logo_path) return {"status": "success", "message": "Team logo deleted successfully"} except Exception as e: print(f"Error deleting team logo: {e}") diff --git a/api/transformerlab/routers/train.py b/api/transformerlab/routers/train.py index 835dcff10..3ef62b811 100644 --- a/api/transformerlab/routers/train.py +++ b/api/transformerlab/routers/train.py @@ -38,11 +38,11 @@ async def spawn_tensorboard(job_id: str): print("Starting tensorboard") - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) # First get the experiment name from the job experiment_id = 
job["experiment_id"] exp_obj = Experiment(experiment_id) - experiment_dir = exp_obj.get_dir() + experiment_dir = await exp_obj.get_dir() job_data = job["job_data"] if "template_name" not in job_data.keys(): @@ -52,6 +52,6 @@ async def spawn_tensorboard(job_id: str): template_name = secure_filename(template_name) logdir = storage.join(experiment_dir, "tensorboards", template_name) - storage.makedirs(logdir, exist_ok=True) + await storage.makedirs(logdir, exist_ok=True) tensorboard_process = subprocess.Popen(["tensorboard", "--logdir", logdir, "--host", "0.0.0.0"]) diff --git a/api/transformerlab/services/dataset_service.py b/api/transformerlab/services/dataset_service.py index 40bb7ff16..885199104 100644 --- a/api/transformerlab/services/dataset_service.py +++ b/api/transformerlab/services/dataset_service.py @@ -5,19 +5,19 @@ from datasets import load_dataset, load_from_disk, Dataset as HFDataset, DatasetDict -def create_local_dataset(dataset_id, json_data=None): +async def create_local_dataset(dataset_id, json_data=None): # Create a new dataset - new_dataset = Dataset.create(dataset_id) + new_dataset = await Dataset.create(dataset_id) # Take description from json_data if it exists description = json_data.get("description", "") if isinstance(json_data, dict) else "" - new_dataset.set_metadata( + await new_dataset.set_metadata( location="local", description=description, size=-1, json_data=json_data if json_data is not None else {} ) return new_dataset -def load_local_dataset(dataset_dir, data_files=None, streaming=False): +async def load_local_dataset(dataset_dir, data_files=None, streaming=False): """ Load a local dataset, excluding index.json from the data files. This prevents accidental ingestion of repository index/metadata files @@ -46,7 +46,7 @@ def load_local_dataset(dataset_dir, data_files=None, streaming=False): if data_files is None: try: # Use storage.ls() which works with both local and remote (S3) paths - entries = storage.ls(dataset_dir, detail=False) + entries = await storage.ls(dataset_dir, detail=False) except Exception: entries = [] @@ -66,7 +66,7 @@ def load_local_dataset(dataset_dir, data_files=None, streaming=False): continue # Check if it's a file (entry_path is already the full path from storage.ls) - if storage.isfile(entry_path): + if await storage.isfile(entry_path): filtered_files.append(name) data_files = filtered_files @@ -81,16 +81,17 @@ def load_local_dataset(dataset_dir, data_files=None, streaming=False): # Read JSON files from remote storage and create dataset all_data = [] for json_file_path in data_file_paths: - with storage.open(json_file_path, "r", encoding="utf-8") as f: + async with await storage.open(json_file_path, "r", encoding="utf-8") as f: + content = await f.read() if json_file_path.endswith(".jsonl"): # JSONL format (one JSON object per line) - for line in f: + for line in content.splitlines(): line = line.strip() if line: all_data.append(json.loads(line)) else: # Regular JSON format - file_data = json.load(f) + file_data = json.loads(content) if isinstance(file_data, list): all_data.extend(file_data) elif isinstance(file_data, dict): diff --git a/api/transformerlab/services/experiment_init.py b/api/transformerlab/services/experiment_init.py index b15160fb4..232613f48 100644 --- a/api/transformerlab/services/experiment_init.py +++ b/api/transformerlab/services/experiment_init.py @@ -227,8 +227,8 @@ async def migrate_workspace_to_org(team_id: str): pass # Recreate workspace directory (default sdk behaviour is to create this directory again when 
auth isnt done -- which will happen at startup) - if not storage.exists(old_workspace): - storage.makedirs(old_workspace, exist_ok=True) + if not await storage.exists(old_workspace): + await storage.makedirs(old_workspace, exist_ok=True) # Add a text file in the old workspace saying where the migration happened with open(os.path.join(old_workspace, "migration.txt"), "w") as f: @@ -338,11 +338,11 @@ def update_diffusion_history_paths(old_workspace: str, new_workspace: str): traceback.print_exc() -def seed_default_experiments(): +async def seed_default_experiments(): """Create a few default experiments if they do not exist (filesystem-backed).""" # Only seed default experiments if there are no experiments at all try: - existing_experiments = Experiment.get_all() + existing_experiments = await Experiment.get_all() if len(existing_experiments) > 0: return except Exception as e: @@ -351,7 +351,7 @@ def seed_default_experiments(): for name in ["alpha", "beta", "gamma"]: try: - exp = Experiment(name, create_new=True) + exp = await Experiment.create_or_get(name, create_new=True) # Sanity check to make sure nothing went wrong or no Exception was silently passed if exp.id != name: raise Exception(f"Error creating experiment {name}: {exp.id} != {name}") @@ -361,7 +361,7 @@ def seed_default_experiments(): pass -def cancel_in_progress_jobs(): +async def cancel_in_progress_jobs(): """On startup, mark any RUNNING jobs as CANCELLED in the filesystem job store across all organizations. REMOTE jobs are excluded from this cancellation as they run on external compute providers.""" # Get HOME_DIR @@ -372,34 +372,34 @@ def cancel_in_progress_jobs(): # Check all org directories orgs_dir = storage.join(home_dir, "orgs") - if storage.exists(orgs_dir) and storage.isdir(orgs_dir): + if await storage.exists(orgs_dir) and await storage.isdir(orgs_dir): try: - org_entries = storage.ls(orgs_dir, detail=False) + org_entries = await storage.ls(orgs_dir, detail=False) for org_path in org_entries: - if storage.isdir(org_path): + if await storage.isdir(org_path): org_id = org_path.rstrip("/").split("/")[-1] # Set org context to check jobs for this org lab_dirs.set_organization_id(org_id) try: - jobs_dir = get_jobs_dir() - if storage.exists(jobs_dir): - entries = storage.ls(jobs_dir, detail=False) + jobs_dir = await get_jobs_dir() + if await storage.exists(jobs_dir): + entries = await storage.ls(jobs_dir, detail=False) for entry_path in entries: - if storage.isdir(entry_path): + if await storage.isdir(entry_path): try: # Extract the job ID from the path job_id = entry_path.rstrip("/").split("/")[-1] - job = Job.get(job_id) - if job.get_status() == "RUNNING": + job = await Job.get(job_id) + if await job.get_status() == "RUNNING": # Skip REMOTE jobs - they should not be cancelled on startup - job_data = job.get_json_data(uncached=True) + job_data = await job.get_json_data(uncached=True) job_type = job_data.get("type", "") if job_type == "REMOTE": print(f"Skipping REMOTE job: {job_id} (org: {org_id})") else: - job.update_status("CANCELLED") + await job.update_status("CANCELLED") print(f"Cancelled running job: {job_id} (org: {org_id})") except Exception: # If we can't access the job, continue to the next one diff --git a/api/transformerlab/services/experiment_service.py b/api/transformerlab/services/experiment_service.py index efdab0043..9ddde9e99 100644 --- a/api/transformerlab/services/experiment_service.py +++ b/api/transformerlab/services/experiment_service.py @@ -5,22 +5,22 @@ from lab import storage -def 
experiment_get_all(): +async def experiment_get_all(): experiments = [] - experiments_dir = lab_dirs.get_experiments_dir() - if storage.exists(experiments_dir): + experiments_dir = await lab_dirs.get_experiments_dir() + if await storage.exists(experiments_dir): try: - exp_dirs = storage.ls(experiments_dir, detail=False) + exp_dirs = await storage.ls(experiments_dir, detail=False) # Sort the directories exp_dirs = sorted(exp_dirs) for exp_path in exp_dirs: # Skip if this is the experiments directory itself (shouldn't happen but safety check) if exp_path.rstrip("/") == experiments_dir.rstrip("/"): continue - if storage.isdir(exp_path): + if await storage.isdir(exp_path): # Check if this directory is actually a valid experiment by checking for index.json index_file = storage.join(exp_path, "index.json") - if not storage.exists(index_file): + if not await storage.exists(index_file): # Skip directories that don't have index.json (not valid experiments) continue # Extract the directory name from the path @@ -28,7 +28,7 @@ def experiment_get_all(): # Skip if the extracted name is the experiments directory itself (shouldn't happen but safety check) if exp_dir == "experiments": continue - exp_dict = experiment_get(exp_dir) + exp_dict = await experiment_get(exp_dir) if exp_dict: experiments.append(exp_dict) except Exception: @@ -36,15 +36,15 @@ def experiment_get_all(): return experiments -def experiment_create(name: str, config: dict) -> str: - Experiment.create_with_config(name, config) +async def experiment_create(name: str, config: dict) -> str: + await Experiment.create_with_config(name, config) return name -def experiment_get(id): +async def experiment_get(id): try: - exp = Experiment.get(id) - data = exp.get_json_data() + exp = await Experiment.get(id) + data = await exp.get_json_data() # Parse config field from JSON string to dict if needed config = data.get("config", {}) if isinstance(config, str): @@ -61,50 +61,50 @@ def experiment_get(id): return None -def experiment_delete(id): +async def experiment_delete(id): try: - exp = Experiment.get(id) - exp.delete() + exp = await Experiment.get(id) + await exp.delete() except FileNotFoundError: print(f"Experiment with id '{id}' not found") except Exception as e: print(f"Error deleting experiment {id}: {e}") -def experiment_update(id, config): +async def experiment_update(id, config): try: - exp = Experiment.get(id) - exp.update_config(config) + exp = await Experiment.get(id) + await exp.update_config(config) except FileNotFoundError: print(f"Experiment with id '{id}' not found") except Exception as e: print(f"Error updating experiment {id}: {e}") -def experiment_update_config(id, key, value): +async def experiment_update_config(id, key, value): try: - exp = Experiment.get(id) - exp.update_config_field(key, value) + exp = await Experiment.get(id) + await exp.update_config_field(key, value) except FileNotFoundError: print(f"Experiment with id '{id}' not found") except Exception as e: print(f"Error updating experiment config key {key}: {e}") -def experiment_save_prompt_template(id, template): +async def experiment_save_prompt_template(id, template): try: - exp_obj = Experiment.get(id) - exp_obj.update_config_field("prompt_template", template) + exp_obj = await Experiment.get(id) + await exp_obj.update_config_field("prompt_template", template) except FileNotFoundError: print(f"Experiment with id '{id}' not found") except Exception as e: print(f"Error saving prompt template: {e}") -def experiment_update_configs(id, updates: dict): +async def 
experiment_update_configs(id, updates: dict): try: - exp_obj = Experiment.get(id) - exp_obj.update_config(updates) + exp_obj = await Experiment.get(id) + await exp_obj.update_config(updates) except FileNotFoundError: print(f"Experiment with id '{id}' not found") except Exception as e: diff --git a/api/transformerlab/services/job_service.py b/api/transformerlab/services/job_service.py index dbda49472..f27337055 100644 --- a/api/transformerlab/services/job_service.py +++ b/api/transformerlab/services/job_service.py @@ -1,3 +1,4 @@ +import asyncio import datetime import json import os @@ -27,7 +28,7 @@ SUPPORTED_WORKFLOW_TRIGGERS = ["TRAIN", "LOAD_MODEL", "EXPORT", "EVAL", "GENERATE", "DOWNLOAD_MODEL"] -def job_create(type, status, experiment_id, job_data="{}"): +async def job_create(type, status, experiment_id, job_data="{}"): # check if type is allowed if type not in ALLOWED_JOB_TYPES: raise ValueError(f"Job type {type} is not allowed") @@ -43,42 +44,43 @@ def job_create(type, status, experiment_id, job_data="{}"): exp = Experiment(experiment_id) # Create job through experiment - job = exp.create_job() - job.set_type(type) - job.update_status(status) - job.set_job_data(job_data) + job = await exp.create_job() + await job.set_experiment(experiment_id) # Set the experiment_id on the job + await job.set_type(type) + await job.update_status(status) + await job.set_job_data(job_data) return job.id -def jobs_get_all(experiment_id, type="", status=""): +async def jobs_get_all(experiment_id, type="", status=""): exp_obj = Experiment(experiment_id) - return exp_obj.get_jobs(type, status) + return await exp_obj.get_jobs(type, status) -def jobs_get_all_by_experiment_and_type(experiment_id, job_type): - return jobs_get_all(experiment_id, job_type) +async def jobs_get_all_by_experiment_and_type(experiment_id, job_type): + return await jobs_get_all(experiment_id, job_type) -def jobs_get_by_experiment(experiment_id): +async def jobs_get_by_experiment(experiment_id): """Get all jobs for a specific experiment""" - return jobs_get_all(experiment_id) + return await jobs_get_all(experiment_id) -def job_get(job_id): +async def job_get(job_id): try: - job = Job.get(job_id) - return job.get_json_data(uncached=True) + job = await Job.get(job_id) + return await job.get_json_data(uncached=True) except Exception as e: print("Error getting job data", e) return None -def job_count_running(): - return Job.count_running_jobs() +async def job_count_running(): + return await Job.count_running_jobs() -def _find_org_id_for_job(job_id: str) -> Optional[str]: +async def _find_org_id_for_job(job_id: str) -> Optional[str]: """ Find which organization a job belongs to by searching all org directories. Returns the org_id if found, None otherwise. 
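Note: after this change every helper in job_service is a coroutine, so call sites must themselves be async and await the result; forgetting the await yields a coroutine object instead of the job dict. A minimal sketch of the updated calling convention, assuming the same import style the routers above already use (the route path and handler name are illustrative, not part of this patch):

from fastapi import APIRouter

from transformerlab.services import job_service

router = APIRouter()


@router.get("/example/jobs/{job_id}")
async def example_read_job(job_id: str):
    # job_get now returns a coroutine and must be awaited; it returns None on failure.
    job = await job_service.job_get(job_id)
    if job is None:
        return {"status": "error", "message": f"Job {job_id} not found"}
    return {"status": "success", "data": job}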
@@ -92,25 +94,25 @@ def _find_org_id_for_job(job_id: str) -> Optional[str]: # Check if context is set correctly already from lab.dirs import get_workspace_dir - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() if "/orgs/" in workspace_dir: return workspace_dir.split("/orgs/")[-1].split("/")[0] # Check all org directories orgs_dir = storage.join(home_dir, "orgs") - if storage.exists(orgs_dir) and storage.isdir(orgs_dir): + if await storage.exists(orgs_dir) and await storage.isdir(orgs_dir): try: - org_entries = storage.ls(orgs_dir, detail=False) + org_entries = await storage.ls(orgs_dir, detail=False) for org_path in org_entries: - if storage.isdir(org_path): + if await storage.isdir(org_path): org_id = org_path.rstrip("/").split("/")[-1] # Set org context and check if job exists lab_dirs.set_organization_id(org_id) try: - jobs_dir = lab_dirs.get_jobs_dir() + jobs_dir = await lab_dirs.get_jobs_dir() job_path = storage.join(jobs_dir, job_id) - if storage.exists(job_path) and storage.isdir(job_path): + if await storage.exists(job_path) and await storage.isdir(job_path): # Job found in this org lab_dirs.set_organization_id(None) return org_id @@ -124,7 +126,7 @@ def _find_org_id_for_job(job_id: str) -> Optional[str]: return None -def job_count_running_across_all_orgs() -> int: +async def job_count_running_across_all_orgs() -> int: """ Count running jobs across all organizations. Returns the total count of jobs with status "RUNNING" across all orgs. @@ -139,15 +141,15 @@ def job_count_running_across_all_orgs() -> int: # Check all org directories orgs_dir = storage.join(home_dir, "orgs") - if storage.exists(orgs_dir) and storage.isdir(orgs_dir): + if await storage.exists(orgs_dir) and await storage.isdir(orgs_dir): try: - org_entries = storage.ls(orgs_dir, detail=False) + org_entries = await storage.ls(orgs_dir, detail=False) for org_path in org_entries: - if storage.isdir(org_path): + if await storage.isdir(org_path): org_id = org_path.rstrip("/").split("/")[-1] lab_dirs.set_organization_id(org_id) try: - count += Job.count_running_jobs() + count += await Job.count_running_jobs() except Exception: continue except Exception: @@ -159,11 +161,11 @@ def job_count_running_across_all_orgs() -> int: return count -def jobs_get_next_queued_job(): - return Job.get_next_queued_job() +async def jobs_get_next_queued_job(): + return await Job.get_next_queued_job() -def jobs_get_next_queued_job_across_all_orgs() -> Tuple[Optional[dict], Optional[str]]: +async def jobs_get_next_queued_job_across_all_orgs() -> Tuple[Optional[dict], Optional[str]]: """ Get the next queued job across all organizations. Returns a tuple of (job_data_dict, organization_id) or (None, None) if no queued jobs found. 
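Note: this hunk, like the rest of the patch, assumes that lab.storage exposes awaitable filesystem helpers (exists, isdir, ls, rm, rm_tree, makedirs) and that storage.open must first be awaited and then entered as an async context manager. A minimal sketch of that usage pattern under those assumptions (the path and function name are illustrative):

from lab import storage


async def append_log_line(log_path: str, line: str) -> None:
    # Create the file if it does not exist yet, then append one line.
    # Note the two-step open: "await storage.open(...)" returns the file-like
    # object, which is then entered with "async with" and written to with
    # awaited calls, matching the pattern used throughout this diff.
    if not await storage.exists(log_path):
        async with await storage.open(log_path, "w") as f:
            await f.write("")
    async with await storage.open(log_path, "a") as f:
        await f.write(line + "\n")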
@@ -184,11 +186,11 @@ def jobs_get_next_queued_job_across_all_orgs() -> Tuple[Optional[dict], Optional orgs_dir = storage.join(home_dir, "orgs") # Check all org directories - if storage.exists(orgs_dir) and storage.isdir(orgs_dir): + if await storage.exists(orgs_dir) and await storage.isdir(orgs_dir): try: - org_entries = storage.ls(orgs_dir, detail=False) + org_entries = await storage.ls(orgs_dir, detail=False) for org_path in org_entries: - if storage.isdir(org_path): + if await storage.isdir(org_path): org_id = org_path.rstrip("/").split("/")[-1] # Set org context to get jobs for this org @@ -196,16 +198,16 @@ def jobs_get_next_queued_job_across_all_orgs() -> Tuple[Optional[dict], Optional try: # Get jobs directory for this org - jobs_dir = lab_dirs.get_jobs_dir() - if storage.exists(jobs_dir) and storage.isdir(jobs_dir): - entries = storage.ls(jobs_dir, detail=False) + jobs_dir = await lab_dirs.get_jobs_dir() + if await storage.exists(jobs_dir) and await storage.isdir(jobs_dir): + entries = await storage.ls(jobs_dir, detail=False) for job_path in entries: - if storage.isdir(job_path): + if await storage.isdir(job_path): job_id_str = job_path.rstrip("/").split("/")[-1] try: job_id = int(job_id_str) if job_id_str.isdigit() else 0 - job = Job.get(job_id_str) - job_data = job.get_json_data(uncached=True) + job = await Job.get(job_id_str) + job_data = await job.get_json_data(uncached=True) if job_data.get("status") == "QUEUED": queued_jobs.append((job_id, job_data, org_id)) except Exception: @@ -227,75 +229,81 @@ def jobs_get_next_queued_job_across_all_orgs() -> Tuple[Optional[dict], Optional return (None, None) -def job_delete_all(experiment_id): +async def job_delete_all(experiment_id): if experiment_id is not None: experiment = Experiment(experiment_id) - experiment.delete_all_jobs() + await experiment.delete_all_jobs() -def job_delete(job_id, experiment_id): +async def job_delete(job_id, experiment_id): try: - job = Job.get(job_id) - if experiment_id is not None and job.get_experiment_id() != experiment_id: + job = await Job.get(job_id) + exp_id = await job.get_experiment_id() + if experiment_id is not None and exp_id != experiment_id: return - job.delete() + await job.delete() except Exception as e: print(f"Error deleting job {job_id}: {e}") -def job_update_job_data_insert_key_value(job_id, key, value, experiment_id): +async def job_update_job_data_insert_key_value(job_id, key, value, experiment_id): try: - job = Job.get(job_id) - if experiment_id is not None and job.get_experiment_id() != experiment_id: + job = await Job.get(job_id) + exp_id = await job.get_experiment_id() + if experiment_id is not None and exp_id != experiment_id: return - job.update_job_data_field(key, value) + await job.update_job_data_field(key, value) except Exception as e: print(f"Error updating job {job_id}: {e}") -def job_stop(job_id, experiment_id): +async def job_stop(job_id, experiment_id): print("Stopping job: " + str(job_id)) - job_update_job_data_insert_key_value(job_id, "stop", True, experiment_id) + await job_update_job_data_insert_key_value(job_id, "stop", True, experiment_id) -def job_update_progress(job_id, progress, experiment_id): +async def job_update_progress(job_id, progress, experiment_id): """ Update the percent complete for this job. 
progress: int representing percent complete """ try: - job = Job.get(job_id) - if experiment_id is not None and job.get_experiment_id() != experiment_id: + job = await Job.get(job_id) + exp_id = await job.get_experiment_id() + if experiment_id is not None and exp_id != experiment_id: return - job.update_progress(progress) + await job.update_progress(progress) except Exception as e: print(f"Error updating job {job_id}: {e}") -def job_update_sweep_progress(job_id, value, experiment_id): +async def job_update_sweep_progress(job_id, value, experiment_id): """ Update the 'sweep_progress' key in the job_data JSON column for a given job. """ try: - job = Job.get(job_id) - if experiment_id is not None and job.get_experiment_id() != experiment_id: + job = await Job.get(job_id) + exp_id = await job.get_experiment_id() + if experiment_id is not None and exp_id != experiment_id: return - job.update_sweep_progress(value) + await job.update_sweep_progress(value) except Exception as e: print(f"Error updating sweep job {job_id}: {e}") -def jobs_get_sweep_children(parent_job_id, experiment_id=None): +async def jobs_get_sweep_children(parent_job_id, experiment_id=None): """ Get all child jobs that belong to a sweep parent job. """ try: - parent_job = Job.get(parent_job_id) - if experiment_id is not None and parent_job.get_experiment_id() != experiment_id: - return [] + parent_job = await Job.get(parent_job_id) + if experiment_id is not None: + exp_id = await parent_job.get_experiment_id() + if exp_id != experiment_id: + return [] - job_data = parent_job.get_job_data() + job_data = await parent_job.get_job_data() if not isinstance(job_data, dict): return [] @@ -307,9 +315,9 @@ def jobs_get_sweep_children(parent_job_id, experiment_id=None): child_jobs = [] for child_job_id in sweep_job_ids: try: - child_job = Job.get(child_job_id) + child_job = await Job.get(child_job_id) # Get full job data (including type, status, etc.) - job_json = child_job.get_json_data() + job_json = await child_job.get_json_data() child_jobs.append(job_json) except Exception: # Skip if job doesn't exist @@ -321,17 +329,19 @@ def jobs_get_sweep_children(parent_job_id, experiment_id=None): return [] -def job_get_sweep_parent(child_job_id, experiment_id=None): +async def job_get_sweep_parent(child_job_id, experiment_id=None): """ Get the parent sweep job for a child job. Returns None if the job is not a sweep child. 
""" try: - child_job = Job.get(child_job_id) - if experiment_id is not None and child_job.get_experiment_id() != experiment_id: - return None + child_job = await Job.get(child_job_id) + if experiment_id is not None: + exp_id = await child_job.get_experiment_id() + if exp_id != experiment_id: + return None - job_data = child_job.get_job_data() + job_data = await child_job.get_job_data() if not isinstance(job_data, dict): return None @@ -339,8 +349,8 @@ def job_get_sweep_parent(child_job_id, experiment_id=None): if not parent_job_id: return None - parent_job = Job.get(parent_job_id) - return parent_job.get_json_data() + parent_job = await Job.get(parent_job_id) + return await parent_job.get_json_data() except Exception as e: print(f"Error getting sweep parent for job {child_job_id}: {e}") return None @@ -482,7 +492,7 @@ async def _trigger_workflows_on_job_completion(job_id: str): """ try: # Get the job details - job = job_get(job_id) + job = await job_get(job_id) if not job: return @@ -532,12 +542,13 @@ async def job_update_status( """ # Update the job status using SDK Job class try: - job = Job.get(job_id) - if experiment_id is not None and job.get_experiment_id() != experiment_id: + job = await Job.get(job_id) + exp_id = await job.get_experiment_id() + if experiment_id is not None and exp_id != experiment_id: return - job.update_status(status) + await job.update_status(status) if error_msg: - job.set_error_message(error_msg) + await job.set_error_message(error_msg) except Exception as e: print(f"Error updating job {job_id}: {e}") @@ -546,7 +557,7 @@ async def job_update_status( # Track quota for REMOTE jobs when they transition to terminal states if status in ("COMPLETE", "STOPPED", "FAILED", "DELETED"): try: - job_dict = job.get_json_data() if job else {} + job_dict = await job.get_json_data() if job else {} if job_dict.get("type") == "REMOTE": # If session is provided, await quota tracking in the same transaction # Otherwise, run it as a background task @@ -580,11 +591,12 @@ async def job_update(job_id: str, type: str, status: str, experiment_id: Optiona """ # Update the job in the database using SDK Job class try: - job = Job.get(job_id) - if experiment_id is not None and job.get_experiment_id() != experiment_id: + job = await Job.get(job_id) + exp_id = await job.get_experiment_id() + if experiment_id is not None and exp_id != experiment_id: return - job.set_type(type) - job.update_status(status) + await job.set_type(type) + await job.update_status(status) except Exception as e: print(f"Error updating job {job_id}: {e}") pass @@ -609,19 +621,20 @@ def job_update_status_sync( # Update the job status using SDK Job class try: # Find which org this job belongs to (in case we're called from a callback without org context) - org_id = _find_org_id_for_job(str(job_id)) + org_id = asyncio.run(_find_org_id_for_job(str(job_id))) # Set org context before accessing the job if org_id: lab_dirs.set_organization_id(org_id) try: - job = Job.get(str(job_id)) - if experiment_id is not None and job.get_experiment_id() != experiment_id: + job = asyncio.run(Job.get(str(job_id))) + exp_id = asyncio.run(job.get_experiment_id()) + if experiment_id is not None and exp_id != experiment_id: return - job.update_status(status) + asyncio.run(job.update_status(status)) if error_msg: - job.set_error_message(error_msg) + asyncio.run(job.set_error_message(error_msg)) finally: # Clear org context if org_id: @@ -651,10 +664,11 @@ def job_update_sync(job_id: str, status: str, experiment_id: Optional[str] = Non """ # Update 
the job in the database using SDK Job class try: - job = Job.get(job_id) - if experiment_id is not None and job.get_experiment_id() != experiment_id: + job = asyncio.run(Job.get(job_id)) + exp_id = asyncio.run(job.get_experiment_id()) + if experiment_id is not None and exp_id != experiment_id: return - job.update_status(status) + asyncio.run(job.update_status(status)) except Exception as e: print(f"Error updating job {job_id}: {e}") pass @@ -675,11 +689,12 @@ def job_update_type_and_status_sync(job_id: str, job_type: str, status: str, exp experiment_id: The experiment ID (required for most operations, optional for backward compatibility) """ try: - job = Job.get(job_id) - if experiment_id is not None and job.get_experiment_id() != experiment_id: + job = asyncio.run(Job.get(job_id)) + exp_id = asyncio.run(job.get_experiment_id()) + if experiment_id is not None and exp_id != experiment_id: return - job.set_type(job_type) - job.update_status(status) + asyncio.run(job.set_type(job_type)) + asyncio.run(job.update_status(status)) # Trigger workflows if job status is COMPLETE # if status == "COMPLETE": @@ -698,10 +713,10 @@ def _trigger_workflows_on_job_completion_sync(job_id: str): """ try: # 1. Get job details using SDK - job = Job.get(job_id) - job_type = job.get_type() + job = asyncio.run(Job.get(job_id)) + job_type = asyncio.run(job.get_type()) # Get experiment_id from job data to match the type expected by workflow functions - experiment_id = job.get_experiment_id() + experiment_id = asyncio.run(job.get_experiment_id()) if not experiment_id: return @@ -715,9 +730,6 @@ def _trigger_workflows_on_job_completion_sync(job_id: str): # Note: This is a limitation - we can't easily do async operations in a sync context # For now, we'll import the async function and call it - # This is not ideal but necessary for now - import asyncio - try: loop = asyncio.get_event_loop() if loop.is_running(): @@ -756,19 +768,21 @@ def job_mark_as_complete_if_running(job_id: int, experiment_id: int) -> None: """Service wrapper: mark job as complete if running and then trigger workflows.""" try: # Find which org this job belongs to - org_id = _find_org_id_for_job(str(job_id)) + org_id = asyncio.run(_find_org_id_for_job(str(job_id))) # Set org context before accessing the job if org_id: lab_dirs.set_organization_id(org_id) try: - job = Job.get(str(job_id)) - if experiment_id is not None and job.get_experiment_id() != experiment_id: + job = asyncio.run(Job.get(str(job_id))) + exp_id = asyncio.run(job.get_experiment_id()) + if experiment_id is not None and exp_id != experiment_id: return # Only update if currently running - if job.get_status() == "RUNNING": - job.update_status("COMPLETE") + status = asyncio.run(job.get_status()) + if status == "RUNNING": + asyncio.run(job.update_status("COMPLETE")) # _trigger_workflows_on_job_completion_sync(job_id) finally: # Clear org context @@ -783,14 +797,14 @@ def job_mark_as_complete_if_running(job_id: int, experiment_id: int) -> None: pass -def get_file_metadata(file_path: str, storage) -> Dict[str, any]: +async def get_file_metadata(file_path: str, storage) -> Dict[str, any]: """ Extract file metadata (size, modified time) from storage or filesystem. Returns dict with 'size' and 'mtime' keys, or empty values if unavailable. 
""" try: # Try storage.ls with detail=True first (works for S3 and local) - file_info_list = storage.ls(file_path, detail=True) + file_info_list = await storage.ls(file_path, detail=True) # Handle dict response (some storage backends) if isinstance(file_info_list, dict): @@ -815,14 +829,14 @@ def get_file_metadata(file_path: str, storage) -> Dict[str, any]: return {"size": None, "mtime": None} -def format_artifact(file_path: str, storage) -> Optional[Dict[str, any]]: +async def format_artifact(file_path: str, storage) -> Optional[Dict[str, any]]: """ Format a single artifact file into the response structure. Returns None if the artifact can't be processed. """ try: filename = file_path.split("/")[-1] if "/" in file_path else file_path - metadata = get_file_metadata(file_path, storage) + metadata = await get_file_metadata(file_path, storage) artifact = {"filename": filename, "full_path": file_path} @@ -838,7 +852,7 @@ def format_artifact(file_path: str, storage) -> Optional[Dict[str, any]]: return None -def get_artifacts_from_sdk(job_id: str, storage) -> Optional[List[Dict]]: +async def get_artifacts_from_sdk(job_id: str, storage) -> Optional[List[Dict]]: """ Get artifacts using the SDK method. Returns list of artifacts or None if SDK method fails. @@ -854,7 +868,7 @@ def get_artifacts_from_sdk(job_id: str, storage) -> Optional[List[Dict]]: artifacts = [] for artifact_path in artifact_paths: - artifact = format_artifact(artifact_path, storage) + artifact = await format_artifact(artifact_path, storage) if artifact: artifacts.append(artifact) @@ -864,17 +878,17 @@ def get_artifacts_from_sdk(job_id: str, storage) -> Optional[List[Dict]]: return None -def get_artifacts_from_directory(artifacts_dir: str, storage) -> List[Dict]: +async def get_artifacts_from_directory(artifacts_dir: str, storage) -> List[Dict]: """ Get artifacts by listing files in the artifacts directory. Returns list of artifacts (empty if directory can't be read). """ - if not artifacts_dir or not storage.exists(artifacts_dir): + if not artifacts_dir or not await storage.exists(artifacts_dir): return [] artifacts = [] try: - items = storage.ls(artifacts_dir, detail=False) + items = await storage.ls(artifacts_dir, detail=False) for item in items: # Handle both string paths and dict responses from storage.ls @@ -887,9 +901,8 @@ def get_artifacts_from_directory(artifacts_dir: str, storage) -> List[Dict]: else: file_path = str(item) - # Ensure we have a valid file path (not a directory) - if file_path and storage.isfile(file_path): - artifact = format_artifact(file_path, storage) + if await storage.isfile(file_path): + artifact = await format_artifact(file_path, storage) if artifact: artifacts.append(artifact) except Exception as e: @@ -898,18 +911,18 @@ def get_artifacts_from_directory(artifacts_dir: str, storage) -> List[Dict]: return artifacts -def get_all_artifact_paths(job_id: str, storage) -> List[str]: +async def get_all_artifact_paths(job_id: str, storage) -> List[str]: """ Get all artifact file paths for a job. Uses get_artifacts_from_sdk and get_artifacts_from_directory to retrieve paths. """ # 1. Try SDK method - sdk_artifacts = get_artifacts_from_sdk(job_id, storage) + sdk_artifacts = await get_artifacts_from_sdk(job_id, storage) if sdk_artifacts: return [a.get("full_path") for a in sdk_artifacts if a.get("full_path")] # 2. 
Fallback to artifacts directory - job = job_get(job_id) + job = await job_get(job_id) if job: job_data = job.get("job_data", {}) artifacts_dir = job_data.get("artifacts_dir") @@ -918,12 +931,12 @@ def get_all_artifact_paths(job_id: str, storage) -> List[str]: try: from lab.dirs import get_job_artifacts_dir - artifacts_dir = get_job_artifacts_dir(job_id) + artifacts_dir = await get_job_artifacts_dir(job_id) except Exception: pass if artifacts_dir: - dir_artifacts = get_artifacts_from_directory(artifacts_dir, storage) + dir_artifacts = await get_artifacts_from_directory(artifacts_dir, storage) if dir_artifacts: return [a.get("full_path") for a in dir_artifacts if a.get("full_path")] diff --git a/api/transformerlab/services/task_service.py b/api/transformerlab/services/task_service.py index 7049e3d91..539a24227 100644 --- a/api/transformerlab/services/task_service.py +++ b/api/transformerlab/services/task_service.py @@ -14,53 +14,53 @@ class TaskService: def __init__(self): self.task_service = TaskTemplateService - def task_get_all(self) -> List[Dict[str, Any]]: + async def task_get_all(self) -> List[Dict[str, Any]]: """Get all tasks from filesystem""" - return self.task_service.list_all() + return await self.task_service.list_all() - def task_get_by_id(self, task_id: str) -> Optional[Dict[str, Any]]: + async def task_get_by_id(self, task_id: str) -> Optional[Dict[str, Any]]: """Get a specific task by ID""" - return self.task_service.get_by_id(task_id) + return await self.task_service.get_by_id(task_id) - def task_get_by_type(self, task_type: str) -> List[Dict[str, Any]]: + async def task_get_by_type(self, task_type: str) -> List[Dict[str, Any]]: """Get all tasks of a specific type""" - return self.task_service.list_by_type(task_type) + return await self.task_service.list_by_type(task_type) - def task_get_by_experiment(self, experiment_id: str) -> List[Dict[str, Any]]: + async def task_get_by_experiment(self, experiment_id: str) -> List[Dict[str, Any]]: """Get all tasks for a specific experiment""" - return self.task_service.list_by_experiment(experiment_id) + return await self.task_service.list_by_experiment(experiment_id) - def task_get_by_type_in_experiment(self, task_type: str, experiment_id: str) -> List[Dict[str, Any]]: + async def task_get_by_type_in_experiment(self, task_type: str, experiment_id: str) -> List[Dict[str, Any]]: """Get all tasks of a specific type in a specific experiment""" - return self.task_service.list_by_type_in_experiment(task_type, experiment_id) + return await self.task_service.list_by_type_in_experiment(task_type, experiment_id) - def task_get_by_subtype_in_experiment( + async def task_get_by_subtype_in_experiment( self, experiment_id: str, subtype: str, task_type: Optional[str] = None ) -> List[Dict[str, Any]]: """Get all tasks for a specific experiment filtered by subtype and optionally by type""" - return self.task_service.list_by_subtype_in_experiment(experiment_id, subtype, task_type) + return await self.task_service.list_by_subtype_in_experiment(experiment_id, subtype, task_type) - def add_task(self, task_data: Dict[str, Any]) -> str: + async def add_task(self, task_data: Dict[str, Any]) -> str: """Create a new task - all fields stored directly in JSON""" # Generate a unique ID for the task task_id = str(uuid.uuid4()) try: - task = self.task_service.create(task_id) + task = await self.task_service.create(task_id) # Store all fields directly (not nested) - task.set_metadata(**task_data) + await task.set_metadata(**task_data) return task_id except 
FileExistsError: # If task already exists, generate a new ID task_id = str(uuid.uuid4()) - task = self.task_service.create(task_id) - task.set_metadata(**task_data) + task = await self.task_service.create(task_id) + await task.set_metadata(**task_data) return task_id - def update_task(self, task_id: str, new_task_data: Dict[str, Any]) -> bool: + async def update_task(self, task_id: str, new_task_data: Dict[str, Any]) -> bool: """Update an existing task""" try: - task = self.task_service.get(str(task_id)) + task = await self.task_service.get(str(task_id)) # Update only the fields that are provided update_data = {} @@ -69,23 +69,23 @@ def update_task(self, task_id: str, new_task_data: Dict[str, Any]) -> bool: update_data[key] = value if update_data: - task.set_metadata(**update_data) + await task.set_metadata(**update_data) return True except FileNotFoundError: return False - def delete_task(self, task_id: str) -> bool: + async def delete_task(self, task_id: str) -> bool: """Delete a task""" try: - task = self.task_service.get(str(task_id)) - task.delete() + task = await self.task_service.get(str(task_id)) + await task.delete() return True except FileNotFoundError: return False - def task_delete_all(self) -> None: + async def task_delete_all(self) -> None: """Delete all tasks""" - self.task_service.delete_all() + await self.task_service.delete_all() # Create a singleton instance diff --git a/api/transformerlab/services/tasks_service.py b/api/transformerlab/services/tasks_service.py index 932ebefa0..3b28fef33 100644 --- a/api/transformerlab/services/tasks_service.py +++ b/api/transformerlab/services/tasks_service.py @@ -14,27 +14,27 @@ class TasksService: def __init__(self): self.task_service = TaskService - def tasks_get_all(self) -> List[Dict[str, Any]]: + async def tasks_get_all(self) -> List[Dict[str, Any]]: """Get all tasks from filesystem""" - return self.task_service.list_all() + return await self.task_service.list_all() - def tasks_get_by_id(self, task_id: str) -> Optional[Dict[str, Any]]: + async def tasks_get_by_id(self, task_id: str) -> Optional[Dict[str, Any]]: """Get a specific task by ID""" - return self.task_service.get_by_id(task_id) + return await self.task_service.get_by_id(task_id) - def tasks_get_by_type(self, task_type: str) -> List[Dict[str, Any]]: + async def tasks_get_by_type(self, task_type: str) -> List[Dict[str, Any]]: """Get all tasks of a specific type""" - return self.task_service.list_by_type(task_type) + return await self.task_service.list_by_type(task_type) - def tasks_get_by_experiment(self, experiment_id: str) -> List[Dict[str, Any]]: + async def tasks_get_by_experiment(self, experiment_id: str) -> List[Dict[str, Any]]: """Get all tasks for a specific experiment""" - return self.task_service.list_by_experiment(experiment_id) + return await self.task_service.list_by_experiment(experiment_id) - def tasks_get_by_type_in_experiment(self, task_type: str, experiment_id: str) -> List[Dict[str, Any]]: + async def tasks_get_by_type_in_experiment(self, task_type: str, experiment_id: str) -> List[Dict[str, Any]]: """Get all tasks of a specific type in a specific experiment""" - return self.task_service.list_by_type_in_experiment(task_type, experiment_id) + return await self.task_service.list_by_type_in_experiment(task_type, experiment_id) - def add_task( + async def add_task( self, name: str, task_type: str, @@ -50,8 +50,8 @@ def add_task( task_id = str(uuid.uuid4()) try: - task = self.task_service.create(task_id) - task.set_metadata( + task = await 
self.task_service.create(task_id) + await task.set_metadata( name=name, type=task_type, inputs=inputs, @@ -65,8 +65,8 @@ def add_task( except FileExistsError: # If task already exists, generate a new ID task_id = str(uuid.uuid4()) - task = self.task_service.create(task_id) - task.set_metadata( + task = await self.task_service.create(task_id) + await task.set_metadata( name=name, type=task_type, inputs=inputs, @@ -78,10 +78,10 @@ def add_task( ) return task_id - def update_task(self, task_id: str, new_task_data: Dict[str, Any]) -> bool: + async def update_task(self, task_id: str, new_task_data: Dict[str, Any]) -> bool: """Update an existing task""" try: - task = self.task_service.get(str(task_id)) + task = await self.task_service.get(str(task_id)) # Update only the fields that are provided update_data = {} @@ -95,23 +95,23 @@ def update_task(self, task_id: str, new_task_data: Dict[str, Any]) -> bool: update_data["outputs"] = new_task_data["outputs"] if update_data: - task.set_metadata(**update_data) + await task.set_metadata(**update_data) return True except FileNotFoundError: return False - def delete_task(self, task_id: str) -> bool: + async def delete_task(self, task_id: str) -> bool: """Delete a task""" try: - task = self.task_service.get(str(task_id)) - task.delete() + task = await self.task_service.get(str(task_id)) + await task.delete() return True except FileNotFoundError: return False - def tasks_delete_all(self) -> None: + async def tasks_delete_all(self) -> None: """Delete all tasks""" - self.task_service.delete_all() + await self.task_service.delete_all() # Create a singleton instance diff --git a/api/transformerlab/shared/dirs.py b/api/transformerlab/shared/dirs.py index bde6e2c4e..f638dfb9f 100644 --- a/api/transformerlab/shared/dirs.py +++ b/api/transformerlab/shared/dirs.py @@ -16,24 +16,63 @@ ROOT_DIR is a legacy variable that we should replace with the above, eventually. """ -FASTCHAT_LOGS_DIR = storage.join(HOME_DIR, "logs") -if not storage.exists(FASTCHAT_LOGS_DIR): - storage.makedirs(FASTCHAT_LOGS_DIR, exist_ok=True) +# Note: These path constants are computed synchronously at module load. +# Directory creation is deferred to async init function called at app startup. +# We use None as sentinel to indicate these need to be computed lazily +_fastchat_logs_dir = None +_static_files_dir = None -# TFL_STATIC_FILES_DIR is TFL_HOME_DIR/webapp -STATIC_FILES_DIR = storage.join(HOME_DIR, "webapp") -storage.makedirs(STATIC_FILES_DIR, exist_ok=True) -# if there is no index.html file in the static directory, create blank one -if not storage.exists(storage.join(STATIC_FILES_DIR, "index.html")): - with storage.open(storage.join(STATIC_FILES_DIR, "index.html"), "w") as f: - f.write( - "

Transformer Lab Cloud App Files Missing. Run curl https://raw.githubusercontent.com/transformerlab/transformerlab-app/main/api/install.sh | bash to install.
" - ) +def get_fastchat_logs_dir_sync(): + """Get FASTCHAT_LOGS_DIR synchronously - for module-level code only.""" + global _fastchat_logs_dir + if _fastchat_logs_dir is None: + # This is a hack for module initialization - we'll set it properly in async init + # For now, just compute the path without async storage operations + _fastchat_logs_dir = os.path.join(HOME_DIR, "logs") + return _fastchat_logs_dir + + +def get_static_files_dir_sync(): + """Get STATIC_FILES_DIR synchronously - for module-level code only.""" + global _static_files_dir + if _static_files_dir is None: + _static_files_dir = os.path.join(HOME_DIR, "webapp") + return _static_files_dir + + +# Export as module-level constants for backward compatibility +FASTCHAT_LOGS_DIR = get_fastchat_logs_dir_sync() +STATIC_FILES_DIR = get_static_files_dir_sync() + + +async def initialize_dirs(): + """Initialize directories asynchronously. Should be called at app startup.""" + global FASTCHAT_LOGS_DIR, STATIC_FILES_DIR + from lab.dirs import get_workspace_dir + + # Compute FASTCHAT_LOGS_DIR using async storage + workspace_dir = await get_workspace_dir() + FASTCHAT_LOGS_DIR = storage.join(workspace_dir, "logs") + if not await storage.exists(FASTCHAT_LOGS_DIR): + await storage.makedirs(FASTCHAT_LOGS_DIR, exist_ok=True) + + # Ensure STATIC_FILES_DIR exists + STATIC_FILES_DIR = storage.join(HOME_DIR, "webapp") + await storage.makedirs(STATIC_FILES_DIR, exist_ok=True) + + # Create default index.html if missing + index_html_path = storage.join(STATIC_FILES_DIR, "index.html") + if not await storage.exists(index_html_path): + async with await storage.open(index_html_path, "w") as f: + await f.write( + "

Transformer Lab Cloud App Files Missing. Run curl https://raw.githubusercontent.com/transformerlab/transformerlab-app/main/api/install.sh | bash to install.
" + ) + # TFL_SOURCE_CODE_DIR api_py_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -if api_py_dir != storage.join(HOME_DIR, "src"): +if api_py_dir != os.path.join(HOME_DIR, "src"): print(f"We are working from {api_py_dir} which is not {os.path.join(HOME_DIR, 'src')}") print( "That means you are probably developing in a different location so we will set source dir to the current directory" @@ -41,21 +80,21 @@ TFL_SOURCE_CODE_DIR = api_py_dir else: print(f"Source code directory is set to: {os.path.join(HOME_DIR, 'src')}") - TFL_SOURCE_CODE_DIR = storage.join(HOME_DIR, "src") + TFL_SOURCE_CODE_DIR = os.path.join(HOME_DIR, "src") # ROOT_DIR (deprecate later) ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -# PLUGIN_PRELOADED_GALLERY -PLUGIN_PRELOADED_GALLERY = storage.join(TFL_SOURCE_CODE_DIR, "transformerlab", "plugins") +# PLUGIN_PRELOADED_GALLERY - use os.path.join for module-level paths +PLUGIN_PRELOADED_GALLERY = os.path.join(TFL_SOURCE_CODE_DIR, "transformerlab", "plugins") -PLUGIN_SDK_DIR = storage.join(TFL_SOURCE_CODE_DIR, "transformerlab", "plugin_sdk") -PLUGIN_HARNESS = storage.join(PLUGIN_SDK_DIR, "plugin_harness.py") +PLUGIN_SDK_DIR = os.path.join(TFL_SOURCE_CODE_DIR, "transformerlab", "plugin_sdk") +PLUGIN_HARNESS = os.path.join(PLUGIN_SDK_DIR, "plugin_harness.py") # Galleries cache directory -GALLERIES_LOCAL_FALLBACK_DIR = storage.join(TFL_SOURCE_CODE_DIR, "transformerlab/galleries/") +GALLERIES_LOCAL_FALLBACK_DIR = os.path.join(TFL_SOURCE_CODE_DIR, "transformerlab/galleries/") # TEMPORARY: We want to move jobs back into the root directory instead of under experiment diff --git a/api/transformerlab/shared/download_huggingface_model.py b/api/transformerlab/shared/download_huggingface_model.py index f9c2eddca..2149590ce 100644 --- a/api/transformerlab/shared/download_huggingface_model.py +++ b/api/transformerlab/shared/download_huggingface_model.py @@ -1,3 +1,4 @@ +import asyncio import json import sqlite3 import time @@ -200,12 +201,12 @@ def get_downloaded_files_from_cache(repo_id, file_metadata): def update_job_progress(job_id, model_name, downloaded_bytes, total_bytes, files_downloaded=None, files_total=None): """Update progress in the database""" try: - job = Job.get(job_id) + job = asyncio.run(Job.get(job_id)) downloaded_mb = downloaded_bytes / 1024 / 1024 total_mb = total_bytes / 1024 / 1024 progress_pct = (downloaded_bytes / total_bytes * 100) if total_bytes > 0 else 0 - job.update_progress(progress_pct) + asyncio.run(job.update_progress(progress_pct)) # Set more data in job_data job_data = { @@ -224,7 +225,7 @@ def update_job_progress(job_id, model_name, downloaded_bytes, total_bytes, files job_data["files_downloaded"] = files_downloaded job_data["files_total"] = files_total - job.set_job_data(job_data) + asyncio.run(job.set_job_data(job_data)) if files_downloaded is not None and files_total is not None: print( @@ -352,7 +353,7 @@ def check_model_gated(repo_id): if not os.path.commonpath([target_dir, WORKSPACE_DIR]) == os.path.abspath(WORKSPACE_DIR): raise ValueError("Invalid path after sanitization. 
Potential security risk.") print(f"DOWNLOADING TO: {target_dir}") - storage.makedirs(target_dir, exist_ok=True) + asyncio.run(storage.makedirs(target_dir, exist_ok=True)) print(f"Downloading adaptor {peft} with job_id {job_id}") @@ -397,12 +398,12 @@ def cancel_check(job_id, org_id): if org_id: set_organization_id(org_id) try: - job = Job.get(job_id) + job = asyncio.run(Job.get(job_id)) - if job.get_status() == "cancelled": + if asyncio.run(job.get_status()) == "cancelled": return True - job_data = job.get_job_data() + job_data = asyncio.run(job.get_job_data()) if job_data.get("stop") is True: return True @@ -502,9 +503,9 @@ def download_blocking(model_is_downloaded, org_id): print(job_data) # Initialize job data using SDK (context is now set in this thread) - job = Job.get(job_id) - job.update_progress(0) - job.set_job_data(job_data) + job = asyncio.run(Job.get(job_id)) + asyncio.run(job.update_progress(0)) + asyncio.run(job.set_job_data(job_data)) # Check if model is gated before starting download if mode == "adaptor": @@ -529,9 +530,9 @@ def download_blocking(model_is_downloaded, org_id): file_metadata, actual_total_size = get_repo_file_metadata(peft) # Update job_data with files_total - job_data = job.get_job_data() or {} + job_data = asyncio.run(job.get_job_data()) or {} job_data["files_total"] = len(file_metadata) - job.set_job_data(job_data) + asyncio.run(job.set_job_data(job_data)) # Start progress monitoring thread # Pass org_id so thread can set context @@ -571,7 +572,7 @@ def download_blocking(model_is_downloaded, org_id): print("downloading model to workspace/models using filename mode") # Use the model ID (repo name) as the directory name, not the filename location = storage.join(WORKSPACE_DIR, "models", secure_filename(model)) - storage.makedirs(location, exist_ok=True) + asyncio.run(storage.makedirs(location, exist_ok=True)) # Get metadata for single file try: fs = HfFileSystem() @@ -583,9 +584,9 @@ def download_blocking(model_is_downloaded, org_id): file_size = total_size_of_model_in_mb * 1024 * 1024 # Update job_data with files_total (1 file for single file downloads) - job_data = job.get_job_data() or {} + job_data = asyncio.run(job.get_job_data()) or {} job_data["files_total"] = 1 - job.set_job_data(job_data) + asyncio.run(job.set_job_data(job_data)) # Start progress monitoring thread # Pass org_id so thread can set context @@ -605,28 +606,30 @@ def download_blocking(model_is_downloaded, org_id): try: from lab.model import Model as ModelService - model_service = ModelService.create(model) - model_service.set_metadata( - model_id=model, - name=model, - json_data={ - "uniqueId": f"gguf/{model}", - "name": model, - "description": "A GGUF model downloaded from the HuggingFace Hub", - "source": "huggingface", - "source_id_or_path": model, - "huggingface_repo": model, - "model_filename": model_filename - if model_filename - else "", # Use specific filename for GGUF - "architecture": "GGUF", - "private": False, - "gated": False, - "model_type": "", - "library_name": "", - "formats": ["GGUF"], - "logo": "https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png", - }, + model_service = asyncio.run(ModelService.create(model)) + asyncio.run( + model_service.set_metadata( + model_id=model, + name=model, + json_data={ + "uniqueId": f"gguf/{model}", + "name": model, + "description": "A GGUF model downloaded from the HuggingFace Hub", + "source": "huggingface", + "source_id_or_path": model, + "huggingface_repo": model, + "model_filename": 
model_filename + if model_filename + else "", # Use specific filename for GGUF + "architecture": "GGUF", + "private": False, + "gated": False, + "model_type": "", + "library_name": "", + "formats": ["GGUF"], + "logo": "https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png", + }, + ) ) print(f"Created GGUF model metadata for {model}") except Exception as e: @@ -637,9 +640,9 @@ def download_blocking(model_is_downloaded, org_id): file_metadata, actual_total_size = get_repo_file_metadata(model, allow_patterns) # Update job_data with files_total - job_data = job.get_job_data() or {} + job_data = asyncio.run(job.get_job_data()) or {} job_data["files_total"] = len(file_metadata) - job.set_job_data(job_data) + asyncio.run(job.set_job_data(job_data)) # Start progress monitoring thread # Pass org_id so thread can set context @@ -679,25 +682,27 @@ def download_blocking(model_is_downloaded, org_id): # Use SDK to create model metadata from lab.model import Model as ModelService - model_service = ModelService.create(model) - model_service.set_metadata( - model_id=model, - name=model, - json_data={ - "uniqueId": model, - "name": model, - "description": f"Model downloaded from HuggingFace Hub: {model}", - "source": "huggingface", - "source_id_or_path": model, - "huggingface_repo": model, - "model_filename": "", # Empty for regular HuggingFace models - "architecture": "Unknown", # Will be updated by the system later - "private": False, - "gated": False, - "model_type": "", - "library_name": "", - "formats": [], - }, + model_service = asyncio.run(ModelService.create(model)) + asyncio.run( + model_service.set_metadata( + model_id=model, + name=model, + json_data={ + "uniqueId": model, + "name": model, + "description": f"Model downloaded from HuggingFace Hub: {model}", + "source": "huggingface", + "source_id_or_path": model, + "huggingface_repo": model, + "model_filename": "", # Empty for regular HuggingFace models + "architecture": "Unknown", # Will be updated by the system later + "private": False, + "gated": False, + "model_type": "", + "library_name": "", + "formats": [], + }, + ) ) print(f"Created model metadata for {model}") except Exception as e: @@ -727,7 +732,7 @@ def main(): try: from transformerlab.plugin_sdk.transformerlab.plugin import get_db_config_value - hf_token = get_db_config_value("HuggingfaceUserAccessToken", team_id=org_id, user_id=user_id) + hf_token = asyncio.run(get_db_config_value("HuggingfaceUserAccessToken", team_id=org_id, user_id=user_id)) if hf_token: os.environ["HF_TOKEN"] = hf_token print(f"Set HF_TOKEN from {'user' if user_id else 'team'} config") @@ -740,7 +745,7 @@ def main(): from lab.dirs import get_workspace_dir - workspace_dir = get_workspace_dir() + workspace_dir = asyncio.run(get_workspace_dir()) print(f"Workspace dir: {workspace_dir}") # Simple approach: just run the download with built-in progress tracking @@ -762,9 +767,9 @@ def main(): # for the same reason! Better catch and at least print a message. try: # Context is already set in main(), so we can get job here - job = Job.get(job_id) - job.update_status(status) - job.update_job_data_field("error_msg", str(error_msg)) + job = asyncio.run(Job.get(job_id)) + asyncio.run(job.update_status(status)) + asyncio.run(job.update_job_data_field("error_msg", str(error_msg))) except sqlite3.OperationalError: # NOTE: If we fail to write to the database the app won't get # the right error message. 
So set a different diff --git a/api/transformerlab/shared/galleries.py b/api/transformerlab/shared/galleries.py index 14c62b1a7..d74d8a951 100644 --- a/api/transformerlab/shared/galleries.py +++ b/api/transformerlab/shared/galleries.py @@ -34,45 +34,45 @@ TLAB_REMOTE_GALLERIES_URL = "https://raw.githubusercontent.com/transformerlab/galleries/main/" -def update_gallery_cache(): +async def update_gallery_cache(): """ Called when Transformer Lab starts up. Initializes any cached gallery files and tries to update from remote. """ for filename in GALLERY_FILES: - update_gallery_cache_file(filename) + await update_gallery_cache_file(filename) -def get_models_gallery(): - return get_gallery_file(MODEL_GALLERY_FILE) +async def get_models_gallery(): + return await get_gallery_file(MODEL_GALLERY_FILE) -def get_model_groups_gallery(): - return get_gallery_file(MODEL_GROUP_GALLERY_FILE) +async def get_model_groups_gallery(): + return await get_gallery_file(MODEL_GROUP_GALLERY_FILE) -def get_data_gallery(): - return get_gallery_file(DATA_GALLERY_FILE) +async def get_data_gallery(): + return await get_gallery_file(DATA_GALLERY_FILE) -def get_exp_recipe_gallery(): - return get_gallery_file(EXP_RECIPES_GALLERY_FILE) +async def get_exp_recipe_gallery(): + return await get_gallery_file(EXP_RECIPES_GALLERY_FILE) -def get_tasks_gallery(): - return get_gallery_file(TASKS_GALLERY_FILE) +async def get_tasks_gallery(): + return await get_gallery_file(TASKS_GALLERY_FILE) -def get_interactive_gallery(): +async def get_interactive_gallery(): """ Get the interactive tasks gallery. This contains templates for interactive task types (vscode, jupyter, vllm, ssh). """ - return get_gallery_file(INTERACTIVE_GALLERY_FILE) + return await get_gallery_file(INTERACTIVE_GALLERY_FILE) -def get_team_tasks_gallery(): +async def get_team_tasks_gallery(): """ Team-specific tasks gallery stored in the workspace directory. Falls back to an empty list when missing or unreadable. @@ -80,27 +80,27 @@ def get_team_tasks_gallery(): from lab.dirs import get_workspace_dir from lab import storage - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() gallery_path = storage.join(workspace_dir, TEAM_TASKS_GALLERY_FILE) try: # Ensure the workspace directory exists before checking the file - storage.makedirs(workspace_dir, exist_ok=True) + await storage.makedirs(workspace_dir, exist_ok=True) - if not storage.exists(gallery_path): + if not await storage.exists(gallery_path): # Initialize an empty gallery file - with storage.open(gallery_path, "w") as f: - json.dump([], f) + async with await storage.open(gallery_path, "w") as f: + await f.write(json.dumps([])) return [] - with storage.open(gallery_path, "r") as f: - return json.load(f) + async with await storage.open(gallery_path, "r") as f: + return json.loads(await f.read()) except Exception as e: print(f"❌ Failed to read team tasks gallery: {e}") return [] -def add_team_task_to_gallery(entry: dict): +async def add_team_task_to_gallery(entry: dict): """ Append (or upsert) a task entry to the team-specific gallery. Replaces an existing entry with the same id/title to avoid duplicates. 
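The gallery changes repeat one storage idiom: `json.dump`/`json.load` on a sync file handle become `json.dumps`/`json.loads` around awaited reads and writes. A small sketch of that idiom, assuming `storage.open` returns an async context manager as used above (the helper names are illustrative):

import json

from lab import storage


async def read_json_list(gallery_path: str):
    # Missing file -> empty gallery, matching the fallback behaviour above.
    if not await storage.exists(gallery_path):
        return []
    async with await storage.open(gallery_path, "r") as f:
        return json.loads(await f.read())


async def write_json_list(gallery_path: str, items: list):
    async with await storage.open(gallery_path, "w") as f:
        await f.write(json.dumps(items, indent=2))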
@@ -108,12 +108,12 @@ def add_team_task_to_gallery(entry: dict): from lab.dirs import get_workspace_dir from lab import storage - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() gallery_path = storage.join(workspace_dir, TEAM_TASKS_GALLERY_FILE) try: - storage.makedirs(workspace_dir, exist_ok=True) - current = get_team_tasks_gallery() + await storage.makedirs(workspace_dir, exist_ok=True) + current = await get_team_tasks_gallery() # De-duplicate on id or title new_id = entry.get("id") @@ -128,15 +128,15 @@ def add_team_task_to_gallery(entry: dict): filtered.append(entry) - with storage.open(gallery_path, "w") as f: - json.dump(filtered, f, indent=2) + async with await storage.open(gallery_path, "w") as f: + await f.write(json.dumps(filtered, indent=2)) return filtered except Exception as e: print(f"❌ Failed to write team tasks gallery: {e}") - return get_team_tasks_gallery() + return await get_team_tasks_gallery() -def delete_team_task_from_gallery(task_id: str): +async def delete_team_task_from_gallery(task_id: str): """ Delete a task entry from the team-specific gallery by id or title. Returns True if deleted, False if not found. @@ -144,12 +144,12 @@ def delete_team_task_from_gallery(task_id: str): from lab.dirs import get_workspace_dir from lab import storage - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() gallery_path = storage.join(workspace_dir, TEAM_TASKS_GALLERY_FILE) try: - storage.makedirs(workspace_dir, exist_ok=True) - current = get_team_tasks_gallery() + await storage.makedirs(workspace_dir, exist_ok=True) + current = await get_team_tasks_gallery() # Filter out the task with matching id or title filtered = [] @@ -161,8 +161,8 @@ def delete_team_task_from_gallery(task_id: str): filtered.append(item) if found: - with storage.open(gallery_path, "w") as f: - json.dump(filtered, f, indent=2) + async with await storage.open(gallery_path, "w") as f: + await f.write(json.dumps(filtered, indent=2)) return True return False except Exception as e: @@ -175,20 +175,20 @@ def delete_team_task_from_gallery(task_id: str): ###################### -def gallery_cache_file_path(filename: str): +async def gallery_cache_file_path(filename: str): from lab.dirs import get_galleries_cache_dir return os.path.join(get_galleries_cache_dir(), filename) -def update_gallery_cache_file(filename: str): +async def update_gallery_cache_file(filename: str): """ Initialize the gallery cache file if it doesn't exist from code, then try to update from remote. """ # First, if nothing is cached yet, then initialize with the local copy. - cached_gallery_file = gallery_cache_file_path(filename) + cached_gallery_file = await gallery_cache_file_path(filename) if not os.path.isfile(cached_gallery_file): print(f"✅ Initializing {filename} from local source.") @@ -203,16 +203,16 @@ def update_gallery_cache_file(filename: str): print("❌ Unable to find local gallery file", sourcefile) # Then, try to update from remote. 
- update_cache_from_remote(filename) + await update_cache_from_remote(filename) -def update_cache_from_remote(gallery_filename: str): +async def update_cache_from_remote(gallery_filename: str): """ Fetches a gallery file from source and updates the cache """ try: remote_gallery = TLAB_REMOTE_GALLERIES_URL + gallery_filename - local_cache_filename = gallery_cache_file_path(gallery_filename) + local_cache_filename = await gallery_cache_file_path(gallery_filename) # Stream download and write via fsspec with urllib.request.urlopen(remote_gallery) as resp: data = resp.read() @@ -226,15 +226,15 @@ def update_cache_from_remote(gallery_filename: str): print(f"❌ Failed to update gallery from remote: {remote_gallery} {e}") -def get_gallery_file(filename: str): +async def get_gallery_file(filename: str): # default empty gallery returned in case of failed gallery file open gallery = [] - gallery_path = gallery_cache_file_path(filename) + gallery_path = await gallery_cache_file_path(filename) # Check for the cached file. If it's not there then initialize. if not os.path.isfile(gallery_path): print(f"Updating gallery cache file {filename}") - update_gallery_cache_file(filename) + await update_gallery_cache_file(filename) with open(gallery_path, "r") as f: gallery = json.load(f) diff --git a/api/transformerlab/shared/github_utils.py b/api/transformerlab/shared/github_utils.py index ad1b57a8f..489842bb7 100644 --- a/api/transformerlab/shared/github_utils.py +++ b/api/transformerlab/shared/github_utils.py @@ -12,7 +12,7 @@ from lab.dirs import get_workspace_dir -def read_github_pat_from_workspace(workspace_dir: str) -> Optional[str]: +async def read_github_pat_from_workspace(workspace_dir: str) -> Optional[str]: """Read GitHub PAT from workspace/github_pat.txt file. 
Args: @@ -23,9 +23,9 @@ def read_github_pat_from_workspace(workspace_dir: str) -> Optional[str]: """ try: pat_path = storage.join(workspace_dir, "github_pat.txt") - if storage.exists(pat_path): - with storage.open(pat_path, "r") as f: - pat = f.read().strip() + if await storage.exists(pat_path): + async with await storage.open(pat_path, "r") as f: + pat = (await f.read()).strip() if pat: return pat except Exception as e: @@ -132,8 +132,8 @@ async def _fetch_task_json_impl( file_path = file_path.strip("/") # Get GitHub PAT from workspace - workspace_dir = get_workspace_dir() - github_pat = read_github_pat_from_workspace(workspace_dir) + workspace_dir = await get_workspace_dir() + github_pat = await read_github_pat_from_workspace(workspace_dir) # Build GitHub API URL with optional ref parameter api_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{file_path}" diff --git a/api/transformerlab/shared/shared.py b/api/transformerlab/shared/shared.py index e41c7170a..8a76597bd 100644 --- a/api/transformerlab/shared/shared.py +++ b/api/transformerlab/shared/shared.py @@ -51,10 +51,19 @@ def runInThread(onExit, popenArgs, popenKWArgs): # -------- OUTPUT FILE SETUP -------- if output_file is not None: - log = storage.open(output_file, "a") + # For subprocess, we need a regular file handle (not async) + # Write header using async storage, then open regular file for subprocess current_time = time.strftime("%Y-%m-%d %H:%M:%S") - log.write(f"\n\n-- RUN {current_time} --\n") - log.flush() + + async def _write_header(): + async with await storage.open(output_file, "a") as f: + await f.write(f"\n\n-- RUN {current_time} --\n") + await f.flush() + + asyncio.run(_write_header()) + + # Open regular file handle for subprocess + log = open(output_file, "a") else: log = subprocess.PIPE @@ -130,7 +139,7 @@ async def async_run_python_script_and_update_status( plugin_location = os.path.normpath(plugin_location) from lab.dirs import get_plugin_dir - plugin_dir_root = get_plugin_dir() + plugin_dir_root = await get_plugin_dir() if not plugin_location.startswith(plugin_dir_root): print(f"Plugin location {plugin_location} is not in {plugin_dir_root}") raise Exception(f"Plugin location {plugin_location} is not in {plugin_dir_root}") @@ -163,12 +172,12 @@ async def async_run_python_script_and_update_status( print(">> " + text) if begin_string in text: print(f"Job {job_id} now in progress!") - job = job_service.job_get(job_id) - experiment_id = job["experiment_id"] + job = await job_service.job_get(job_id) + experiment_id = job.get("experiment_id") if job else None await job_update_status(job_id=job_id, status="RUNNING", experiment_id=experiment_id) # Check the job_data column for the stop flag: - job_row = job_service.job_get(job_id) + job_row = await job_service.job_get(job_id) job_data = job_row.get("job_data", None) if job_data and job_data.get("stop", False): print(f"Job {job_id}: 'stop' flag detected. 
Cancelling job.") @@ -179,13 +188,13 @@ async def async_run_python_script_and_update_status( if process.returncode == 0: print(f"Job {job_id} completed successfully") - job = job_service.job_get(job_id) - experiment_id = job["experiment_id"] + job = await job_service.job_get(job_id) + experiment_id = job.get("experiment_id") if job else None await job_update_status(job_id=job_id, status="COMPLETE", experiment_id=experiment_id) else: print(f"ERROR: Job {job_id} failed with exit code {process.returncode}.") - job = job_service.job_get(job_id) - experiment_id = job["experiment_id"] + job = await job_service.job_get(job_id) + experiment_id = job.get("experiment_id") if job else None await job_update_status(job_id=job_id, status="FAILED", experiment_id=experiment_id) return process @@ -212,24 +221,24 @@ async def read_process_output(process, job_id, log_handle=None): # Close the log handle if one was passed (from async_run_python_daemon_and_update_status) if log_handle: try: - log_handle.close() + await log_handle.__aexit__(None, None, None) except Exception: pass # Wrap log write in try-except to handle errors gracefully during shutdown try: - with storage.open(get_global_log_path(), "a") as log: - log.write(f"Inference Server Terminated with {returncode}.\n") - log.flush() + async with await storage.open(await get_global_log_path(), "a") as log: + await log.write(f"Inference Server Terminated with {returncode}.\n") + await log.flush() except Exception: # Silently ignore logging errors during shutdown to prevent error bursts pass # so we should delete the pid file: from lab.dirs import get_temp_dir - pid_file = storage.join(get_temp_dir(), f"worker_job_{job_id}.pid") - if storage.exists(pid_file): - storage.rm(pid_file) + pid_file = storage.join(await get_temp_dir(), f"worker_job_{job_id}.pid") + if await storage.exists(pid_file): + await storage.rm(pid_file) # Clean up resources after process ends clear_vram_and_kill_sglang() @@ -267,17 +276,19 @@ async def async_run_python_daemon_and_update_status( # Open a file to write the output to: # Use context manager to ensure proper cleanup, but we need to keep it open - # so we'll use a different approach - store the handle and close it later + # so we'll use a different approach - manually enter the context manager log = None + log_cm = None try: - log = storage.open(get_global_log_path(), "a") + log_cm = await storage.open(await get_global_log_path(), "a") + log = await log_cm.__aenter__() # Check if plugin has a venv directory if plugin_location: plugin_location = os.path.normpath(plugin_location) from lab.dirs import get_plugin_dir - plugin_dir_root = get_plugin_dir() + plugin_dir_root = await get_plugin_dir() if not plugin_location.startswith(plugin_dir_root): print(f"Plugin location {plugin_location} is not in {plugin_dir_root}") raise Exception(f"Plugin location {plugin_location} is not in {plugin_dir_root}") @@ -309,9 +320,9 @@ async def async_run_python_daemon_and_update_status( pid = process.pid from lab.dirs import get_temp_dir - pid_file = storage.join(get_temp_dir(), f"worker_job_{job_id}.pid") - with storage.open(pid_file, "w") as f: - f.write(str(pid)) + pid_file = storage.join(await get_temp_dir(), f"worker_job_{job_id}.pid") + async with await storage.open(pid_file, "w") as f: + await f.write(str(pid)) # keep a tail of recent lines so we can show them on failure: recent_lines = deque(maxlen=10) @@ -328,16 +339,17 @@ async def async_run_python_daemon_and_update_status( if set_process_id_function: set_process_id_function(process) 
print(f"Worker job {job_id} started successfully") - job = job_service.job_get(job_id) - experiment_id = job["experiment_id"] + job = await job_service.job_get(job_id) + experiment_id = job.get("experiment_id") if job else None await job_update_status(job_id=job_id, status="COMPLETE", experiment_id=experiment_id) # Schedule the read_process_output coroutine in the current event # so we can keep watching this process, but return back to the caller # so that the REST call can complete - # Pass the log handle to read_process_output so it can close it - # Set log to None so the finally block doesn't close it - log_handle_to_pass = log + # Pass the log context manager to read_process_output so it can close it + # Set log_cm to None so the finally block doesn't close it + log_handle_to_pass = log_cm + log_cm = None log = None asyncio.create_task(read_process_output(process, job_id, log_handle_to_pass)) @@ -350,17 +362,17 @@ async def async_run_python_daemon_and_update_status( # Wrap log write in try-except to handle errors gracefully during shutdown if log: try: - log.write(decoded) - log.flush() + await log.write(decoded) + await log.flush() except Exception: # Silently ignore logging errors during shutdown pass line = await process.stdout.readline() finally: # Ensure log file is closed even if there's an error - if log: + if log_cm: try: - log.close() + await log_cm.__aexit__(None, None, None) except Exception: pass @@ -376,13 +388,13 @@ async def async_run_python_daemon_and_update_status( print(f"ERROR: Worker job {job_id} failed with exit code {returncode}.") print(error_msg) - job = job_service.job_get(job_id) - experiment_id = job["experiment_id"] + job = await job_service.job_get(job_id) + experiment_id = job.get("experiment_id") if job else None await job_update_status(job_id=job_id, status="FAILED", error_msg=error_msg, experiment_id=experiment_id) return process -def _get_org_id_for_subprocess(): +async def _get_org_id_for_subprocess(): """ Helper function to get organization_id from various contexts. Tries request context first, then lab SDK context. 
@@ -391,7 +403,7 @@ def _get_org_id_for_subprocess(): # get from lab dirs workspace path from lab.dirs import get_workspace_dir - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() if "/orgs/" in workspace_dir: return workspace_dir.split("/orgs/")[-1].split("/")[0] @@ -440,7 +452,7 @@ async def run_job(job_id: str, job_config, experiment_name: str = "default", job print(master_job_type) # Get organization_id and user_id for passing to plugin subprocesses - org_id = _get_org_id_for_subprocess() + org_id = await _get_org_id_for_subprocess() user_id = _get_user_id_for_subprocess(job_details) subprocess_env = {} if org_id: @@ -463,11 +475,11 @@ async def run_job(job_id: str, job_config, experiment_name: str = "default", job return {"status": "complete", "job_id": job_id, "message": "Task job completed successfully"} # Common setup using SDK classes - job_obj = Job.get(job_id) - exp_obj = Experiment(experiment_name) - output_temp_file_dir = job_obj.get_dir() + job_obj = await Job.get(job_id) + exp_obj = await Experiment.create_or_get(experiment_name) + output_temp_file_dir = await job_obj.get_dir() - experiment_details = experiment_get(experiment_name) + experiment_details = await experiment_get(experiment_name) # Extract plugin name consistently across all job types plugin_name = None @@ -480,7 +492,7 @@ async def run_job(job_id: str, job_config, experiment_name: str = "default", job # Common plugin location check for job types that use plugins if plugin_name: - plugin_location = lab_dirs.plugin_dir_by_name(plugin_name) + plugin_location = await lab_dirs.plugin_dir_by_name(plugin_name) if not os.path.exists(plugin_location): await job_service.job_update_status(job_id, "FAILED", experiment_id=experiment_name) error_msg = f"{master_job_type} job failed: No plugin found" @@ -494,17 +506,17 @@ async def run_job(job_id: str, job_config, experiment_name: str = "default", job evals_output_file = storage.join(output_temp_file_dir, f"output_{job_id}.txt") # Create output file if it doesn't exist - if not storage.exists(evals_output_file): - with storage.open(evals_output_file, "w") as f: - f.write("") + if not await storage.exists(evals_output_file): + async with await storage.open(evals_output_file, "w") as f: + await f.write("") # Pass user_id extracted from job_details if available user_id_from_job = _get_user_id_for_subprocess(job_details) await run_evaluation_script( experiment_name, plugin_name, eval_name, job_id, org_id=org_id, user_id=user_id_from_job ) # Check if stop button was clicked and update status accordingly - job_row = job_service.job_get(job_id) - job_data = job_row.get("job_data", None) + job_row = await job_service.job_get(job_id) + job_data = job_row.get("job_data", None) if job_row else None if job_data is None: await job_update_status(job_id, "FAILED", experiment_id=experiment_name) return {"status": "error", "job_id": job_id, "message": "Evaluation job failed: No job data found"} @@ -514,7 +526,7 @@ async def run_job(job_id: str, job_config, experiment_name: str = "default", job return {"status": "stopped", "job_id": job_id, "message": "Evaluation job was stopped by user"} else: # Only set to COMPLETE if not already FAILED - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) current_status = job.get("status") if current_status != "FAILED": await job_update_status(job_id, "COMPLETE", experiment_id=experiment_name) @@ -529,9 +541,9 @@ async def run_job(job_id: str, job_config, experiment_name: str = "default", job 
gen_output_file = storage.join(output_temp_file_dir, f"output_{job_id}.txt") # Create output file if it doesn't exist - if not storage.exists(gen_output_file): - with storage.open(gen_output_file, "w") as f: - f.write("") + if not await storage.exists(gen_output_file): + async with await storage.open(gen_output_file, "w") as f: + await f.write("") # Pass user_id extracted from job_details if available user_id_from_job = _get_user_id_for_subprocess(job_details) @@ -540,8 +552,8 @@ async def run_job(job_id: str, job_config, experiment_name: str = "default", job ) # Check should_stop flag and update status accordingly - job_row = job_service.job_get(job_id) - job_data = job_row.get("job_data", None) + job_row = await job_service.job_get(job_id) + job_data = job_row.get("job_data", None) if job_row else None if job_data is None: await job_update_status(job_id, "FAILED", experiment_id=experiment_name) return {"status": "error", "job_id": job_id, "message": "Generation job failed: No job data found"} @@ -551,7 +563,7 @@ async def run_job(job_id: str, job_config, experiment_name: str = "default", job return {"status": "stopped", "job_id": job_id, "message": "Generation job was stopped by user"} else: # Only set to COMPLETE if not already FAILED - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) current_status = job.get("status") if current_status != "FAILED": await job_update_status(job_id, "COMPLETE", experiment_id=experiment_name) @@ -564,9 +576,9 @@ async def run_job(job_id: str, job_config, experiment_name: str = "default", job export_output_file = storage.join(output_temp_file_dir, f"output_{job_id}.txt") # Create output file if it doesn't exist - if not storage.exists(export_output_file): - with storage.open(export_output_file, "w") as f: - f.write("") + if not await storage.exists(export_output_file): + async with await storage.open(export_output_file, "w") as f: + await f.write("") # Run the export script using the existing run_exporter_script function from transformerlab.routers.experiment.export import run_exporter_script @@ -600,7 +612,7 @@ async def run_job(job_id: str, job_config, experiment_name: str = "default", job # Check the result and update job status accordingly if result.get("status") == "success": # Only set to COMPLETE if not already FAILED - job = job_service.job_get(job_id) + job = await job_service.job_get(job_id) current_status = job.get("status") if current_status != "FAILED": await job_update_status(job_id, "COMPLETE", experiment_id=experiment_name) @@ -618,7 +630,7 @@ async def run_job(job_id: str, job_config, experiment_name: str = "default", job await job_service.job_update_status(job_id, "RUNNING", experiment_id=experiment_name) # Use existing job object and output directory - plugin_dir = lab_dirs.plugin_dir_by_name(plugin_name) + plugin_dir = await lab_dirs.plugin_dir_by_name(plugin_name) plugin_main_args = ["--plugin_dir", plugin_dir] # Flatten job_config["config"] into CLI args @@ -639,8 +651,8 @@ async def run_job(job_id: str, job_config, experiment_name: str = "default", job decoded = base64.b64decode(config[base64_key]) file_path = storage.join(output_temp_file_dir, f"{file_arg}.png") - with storage.open(file_path, "wb") as f: - f.write(decoded) + async with await storage.open(file_path, "wb") as f: + await f.write(decoded) config[file_arg] = file_path del config[base64_key] @@ -651,8 +663,8 @@ async def run_job(job_id: str, job_config, experiment_name: str = "default", job # Remove input_image and mask_image from 
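The diffusion branch above turns "<name>_base64" config entries into real image files before launching the plugin. A rough standalone sketch of that step using plain local file I/O instead of the async storage layer; the input_image/mask_image key names come from the surrounding code, everything else is illustrative:

import base64
import os

def materialize_base64_images(config: dict, out_dir: str) -> list:
    # For each "<name>_base64" entry, write the decoded bytes to "<name>.png",
    # point config["<name>"] at the new file, and drop the base64 payload.
    removed = []
    for file_arg in ("input_image", "mask_image"):
        b64_key = f"{file_arg}_base64"
        if config.get(b64_key):
            path = os.path.join(out_dir, f"{file_arg}.png")
            with open(path, "wb") as f:
                f.write(base64.b64decode(config[b64_key]))
            config[file_arg] = path
            del config[b64_key]
            removed.append(b64_key)
    return removed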
job_data['config'] in db if they were present if removed_base64_keys: - job_row = job_service.job_get(job_id) - job_data = job_row.get("job_data", {}) + job_row = await job_service.job_get(job_id) + job_data = job_row.get("job_data", {}) if job_row else {} # Handle job_data as str or dict if isinstance(job_data, str): try: @@ -686,7 +698,7 @@ async def run_job(job_id: str, job_config, experiment_name: str = "default", job if updated: if double_encoded: config_in_db = json.dumps(config_in_db) - job_service.job_update_job_data_insert_key_value( + await job_service.job_update_job_data_insert_key_value( job_id, "config", config_in_db, experiment_id=experiment_name ) @@ -721,10 +733,10 @@ async def run_job(job_id: str, job_config, experiment_name: str = "default", job subprocess_command = [python_bin, dirs.PLUGIN_HARNESS] + extra_args output_path = storage.join(output_temp_file_dir, f"output_{job_id}.txt") - storage.makedirs(storage.join(output_temp_file_dir), exist_ok=True) + await storage.makedirs(storage.join(output_temp_file_dir), exist_ok=True) print(f"[DIFFUSION] Running command: {subprocess_command}") try: - with storage.open(output_path, "w") as f: + async with await storage.open(output_path, "w") as f: process = await asyncio.create_subprocess_exec( *subprocess_command, stdout=f, @@ -757,10 +769,10 @@ async def run_job(job_id: str, job_config, experiment_name: str = "default", job # Use experiment details and SDK objects for path management print("Experiment Details: ", experiment_details) experiment_details_as_string = json.dumps(experiment_details) - experiment_dir = exp_obj.get_dir() + experiment_dir = await exp_obj.get_dir() # Use Job SDK for output file path - output_file = job_obj.get_log_path() + output_file = await job_obj.get_log_path() def on_train_complete(): print("Training Job: The process has finished") @@ -782,7 +794,7 @@ def on_job_complete(): adaptor_name = template_config.get("adaptor_name", "adaptor") template_config["job_id"] = job_id # Resolve org-aware workspace dir if multitenant via job_data (org_id may be persisted by caller) - workspace_dir = get_workspace_dir() + workspace_dir = await get_workspace_dir() template_config["adaptor_output_dir"] = storage.join(workspace_dir, "adaptors", model_name, adaptor_name) template_config["output_dir"] = storage.join( experiment_dir, @@ -797,8 +809,8 @@ def on_job_complete(): venv_python = os.path.join(venv_path, "bin", "python") tempdir = storage.join(workspace_dir, "temp") - if not storage.exists(tempdir): - storage.makedirs(tempdir, exist_ok=True) + if not await storage.exists(tempdir): + await storage.makedirs(tempdir, exist_ok=True) # Check if hyperparameter sweep is requested run_sweeps = template_config.get("run_sweeps", False) # if run_sweeps in ["on", "true", "yes"]: @@ -832,7 +844,7 @@ def on_job_complete(): # Create sweep directory to store results sweep_dir = storage.join(template_config["output_dir"], f"sweep_{job_id}") - storage.makedirs(sweep_dir, exist_ok=True) + await storage.makedirs(sweep_dir, exist_ok=True) # Generate all configurations from itertools import product @@ -851,8 +863,10 @@ def on_job_complete(): print(f"Generated {total_configs} configurations for sweep") # Initialize sweep tracking - job_service.job_update_job_data_insert_key_value(job_id, "sweep_total", str(total_configs), experiment_name) - job_service.job_update_job_data_insert_key_value(job_id, "sweep_current", "0", experiment_name) + await job_service.job_update_job_data_insert_key_value( + job_id, "sweep_total", str(total_configs), 
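For reference, the "generate all configurations" step driven by itertools.product boils down to a cartesian product over the swept parameters; a minimal sketch with made-up parameter names:

from itertools import product

def expand_sweep(sweep_params: dict) -> list:
    # {"learning_rate": [1e-4, 1e-5], "batch_size": [8, 16]} -> 4 concrete configs
    keys = list(sweep_params)
    return [dict(zip(keys, combo)) for combo in product(*(sweep_params[k] for k in keys))]

configs = expand_sweep({"learning_rate": [1e-4, 1e-5], "batch_size": [8, 16]})
assert len(configs) == 4  # this count is what gets persisted as "sweep_total"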
experiment_name + ) + await job_service.job_update_job_data_insert_key_value(job_id, "sweep_current", "0", experiment_name) # Get metrics configuration metric_name = template_config.get("sweep_metric", "eval/loss") @@ -870,13 +884,13 @@ def on_job_complete(): # Create a unique run directory run_dir = storage.join(sweep_dir, f"run_{i + 1}") - storage.makedirs(run_dir, exist_ok=True) + await storage.makedirs(run_dir, exist_ok=True) # Create a unique adaptor directory for this run run_adaptor_dir = storage.join( workspace_dir, "adaptors", secure_filename(model_name), f"{adaptor_name}_sweep_{i + 1}" ) - storage.makedirs(run_adaptor_dir, exist_ok=True) + await storage.makedirs(run_adaptor_dir, exist_ok=True) # Create a copy of the template config for this run run_config = template_config.copy() @@ -892,19 +906,21 @@ def on_job_complete(): # Create input file for this run run_input_file = storage.join(tempdir, f"plugin_input_{job_id}_run_{i + 1}.json") run_input_contents = {"experiment": experiment_details, "config": run_config} - with storage.open(run_input_file, "w") as outfile: - json.dump(run_input_contents, outfile, indent=4) + async with await storage.open(run_input_file, "w") as outfile: + await outfile.write(json.dumps(run_input_contents, indent=4)) # Update job progress - job_service.job_update_sweep_progress(job_id, int((i / total_configs) * 100), experiment_name) - job_service.job_update_job_data_insert_key_value(job_id, "sweep_current", str(i + 1), experiment_name) - job_service.job_update_job_data_insert_key_value( + await job_service.job_update_sweep_progress(job_id, int((i / total_configs) * 100), experiment_name) + await job_service.job_update_job_data_insert_key_value( + job_id, "sweep_current", str(i + 1), experiment_name + ) + await job_service.job_update_job_data_insert_key_value( job_id, "sweep_running_config", json.dumps(config_params), experiment_name ) # Run the training job with this configuration run_output_file = storage.join(sweep_dir, f"output_sweep_{job_id}.txt") - job_service.job_update_job_data_insert_key_value( + await job_service.job_update_job_data_insert_key_value( job_id, "sweep_output_file", storage.join(sweep_dir, f"output_sweep_{job_id}.txt"), experiment_name ) @@ -938,7 +954,7 @@ def on_job_complete(): # Replace synchronous subprocess.run with asyncio async def run_process_async(cmd, output_file): # Open file for writing - with storage.open(output_file, "a") as f: + async with await storage.open(output_file, "a") as f: # Create subprocess with piped stdout # Pass organization_id via environment variable process_env = {**os.environ, **subprocess_env} if subprocess_env_or_none else None @@ -954,8 +970,8 @@ async def run_process_async(cmd, output_file): # Decode and write to file decoded_line = line.decode("utf-8") - f.write(f"\n[Run {i + 1}/{total_configs}]: {decoded_line.strip()}") - f.flush() + await f.write(f"\n[Run {i + 1}/{total_configs}]: {decoded_line.strip()}") + await f.flush() # Wait for process to complete await process.wait() @@ -965,17 +981,17 @@ async def run_process_async(cmd, output_file): await run_process_async(run_command, run_output_file) # Delete the output adaptor directory if it exists - if storage.exists(run_adaptor_dir) and storage.isdir(run_adaptor_dir): + if await storage.exists(run_adaptor_dir) and await storage.isdir(run_adaptor_dir): print(f"Deleting adaptor directory: {run_adaptor_dir}") - storage.rm_tree(run_adaptor_dir) + await storage.rm_tree(run_adaptor_dir) # Check job data for training metrics try: # Get latest metrics 
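run_process_async above follows a standard asyncio shape: spawn the subprocess with a piped stdout, forward each decoded line to the (now async) output file, then wait for the exit code. A self-contained sketch, with an async printer standing in for the storage file handle:

import asyncio

async def run_and_stream(cmd: list, write_line) -> int:
    process = await asyncio.create_subprocess_exec(
        *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
    )
    while True:
        line = await process.stdout.readline()
        if not line:
            break
        await write_line(line.decode("utf-8"))
    return await process.wait()

async def main():
    async def printer(text):  # stand-in for `await f.write(...)` on the storage handle
        print(text, end="")
    print("exit code:", await run_and_stream(["echo", "hello sweep"], printer))

asyncio.run(main())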
from job_data (assuming plugin saved metrics there) metrics_path = storage.join(run_dir, "metrics.json") - if storage.exists(metrics_path): - with storage.open(metrics_path, "r") as f: - run_metrics = json.load(f) + if await storage.exists(metrics_path): + async with await storage.open(metrics_path, "r") as f: + run_metrics = json.loads(await f.read()) else: # Fallback to a default metric value if no metrics found run_metrics = {metric_name: 0.0} @@ -1002,10 +1018,10 @@ async def run_process_async(cmd, output_file): best_config = config_params.copy() # Update job data with current best - job_service.job_update_job_data_insert_key_value( + await job_service.job_update_job_data_insert_key_value( job_id, "sweep_best_config", json.dumps(best_config), experiment_name ) - job_service.job_update_job_data_insert_key_value( + await job_service.job_update_job_data_insert_key_value( job_id, "sweep_best_metric", json.dumps({metric_name: best_metric}), experiment_name ) except Exception as e: @@ -1025,17 +1041,17 @@ async def run_process_async(cmd, output_file): } sweep_results_file = storage.join(sweep_dir, "sweep_results.json") - with storage.open(sweep_results_file, "w") as f: - json.dump(sweep_results, f, indent=2) + async with await storage.open(sweep_results_file, "w") as f: + await f.write(json.dumps(sweep_results, indent=2)) - job_service.job_update_job_data_insert_key_value( + await job_service.job_update_job_data_insert_key_value( job_id, "sweep_results_file", sweep_results_file, experiment_name ) print("\n--- Sweep completed ---") print(f"Best configuration: {json.dumps(best_config, indent=2)}") print(f"Best {metric_name}: {best_metric}") - job_service.job_update_sweep_progress(job_id, 100, experiment_name) + await job_service.job_update_sweep_progress(job_id, 100, experiment_name) # Optionally train final model with best configuration train_final_model = template_config.get("train_final_model", True) @@ -1052,8 +1068,8 @@ async def run_process_async(cmd, output_file): # Create input file for final run final_input_file = storage.join(tempdir, f"plugin_input_{job_id}_final.json") final_input_contents = {"experiment": experiment_details, "config": final_config} - with storage.open(final_input_file, "w") as outfile: - json.dump(final_input_contents, outfile, indent=4) + async with await storage.open(final_input_file, "w") as outfile: + await outfile.write(json.dumps(final_input_contents, indent=4)) # Create command for final training if os.path.exists(venv_path) and os.path.isdir(venv_path): @@ -1096,8 +1112,8 @@ async def run_process_async(cmd, output_file): else: # Create a file in the temp directory to store the inputs: tempdir = storage.join(workspace_dir, "temp") - if not storage.exists(tempdir): - storage.makedirs(tempdir, exist_ok=True) + if not await storage.exists(tempdir): + await storage.makedirs(tempdir, exist_ok=True) input_file = storage.join(tempdir, f"plugin_input_{job_id}.json") # The following two ifs convert nested JSON strings to JSON objects -- this is a hack # and should be done in the API itself @@ -1112,11 +1128,11 @@ async def run_process_async(cmd, output_file): experiment_details["config"]["inferenceParams"] ) input_contents = {"experiment": experiment_details, "config": template_config} - with storage.open(input_file, "w") as outfile: - json.dump(input_contents, outfile, indent=4) + async with await storage.open(input_file, "w") as outfile: + await outfile.write(json.dumps(input_contents, indent=4)) start_time = time.strftime("%Y-%m-%d %H:%M:%S") - 
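The best-run bookkeeping above (sweep_best_config / sweep_best_metric) amounts to a running arg-min or arg-max over the per-run metrics. A small sketch under the assumption that the metric direction is configurable; the real code may fix one direction and stores its results in job_data rather than returning them:

def pick_best(results: list, metric_name: str = "eval/loss", lower_is_better: bool = True):
    # `results` entries look like {"params": {...}, "metrics": {"eval/loss": 0.42}}.
    best_params, best_value = None, None
    for run in results:
        value = run["metrics"].get(metric_name)
        if value is None:
            continue
        if best_value is None or (value < best_value if lower_is_better else value > best_value):
            best_params, best_value = run["params"], value
    return best_params, best_value

best_config, best_metric = pick_best([
    {"params": {"learning_rate": 1e-4}, "metrics": {"eval/loss": 0.50}},
    {"params": {"learning_rate": 1e-5}, "metrics": {"eval/loss": 0.42}},
])
assert best_metric == 0.42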
job_service.job_update_job_data_insert_key_value(job_id, "start_time", start_time, experiment_name) + await job_service.job_update_job_data_insert_key_value(job_id, "start_time", start_time, experiment_name) # Check if plugin has a venv directory venv_path = os.path.join(plugin_location, "venv") @@ -1169,8 +1185,8 @@ async def run_process_async(cmd, output_file): # Create a file in the temp directory to store the inputs: tempdir = storage.join(workspace_dir, "temp") - if not storage.exists(tempdir): - storage.makedirs(tempdir, exist_ok=True) + if not await storage.exists(tempdir): + await storage.makedirs(tempdir, exist_ok=True) input_file = storage.join(tempdir, f"plugin_input_{job_id}.json") # The following two ifs convert nested JSON strings to JSON objects -- this is a hack # and should be done in the API itself @@ -1181,8 +1197,8 @@ async def run_process_async(cmd, output_file): experiment_details["config"]["inferenceParams"] ) input_contents = {"experiment": experiment_details, "config": template_config} - with storage.open(input_file, "w") as outfile: - json.dump(input_contents, outfile, indent=4) + async with await storage.open(input_file, "w") as outfile: + await outfile.write(json.dumps(input_contents, indent=4)) start_time = time.strftime("%Y-%m-%d %H:%M:%S") job_service.job_update_job_data_insert_key_value(job_id, "start_time", start_time, experiment_name) @@ -1234,14 +1250,14 @@ async def run_process_async(cmd, output_file): template_config["template_name"], ) - if not storage.exists(output_file): - with storage.open(output_file, "w") as f: - f.write("") + if not await storage.exists(output_file): + async with await storage.open(output_file, "w") as f: + await f.write("") # Create a file in the temp directory to store the inputs: - tempdir = storage.join(get_workspace_dir(), "temp") - if not storage.exists(tempdir): - storage.makedirs(tempdir, exist_ok=True) + tempdir = storage.join(await get_workspace_dir(), "temp") + if not await storage.exists(tempdir): + await storage.makedirs(tempdir, exist_ok=True) input_file = storage.join(tempdir, f"plugin_input_{job_id}.json") # The following two ifs convert nested JSON strings to JSON objects -- this is a hack # and should be done in the API itself @@ -1252,8 +1268,8 @@ async def run_process_async(cmd, output_file): experiment_details["config"]["inferenceParams"] ) input_contents = {"experiment": experiment_details, "config": template_config} - with storage.open(input_file, "w") as outfile: - json.dump(input_contents, outfile, indent=4) + async with await storage.open(input_file, "w") as outfile: + await outfile.write(json.dumps(input_contents, indent=4)) start_time = time.strftime("%Y-%m-%d %H:%M:%S") job_service.job_update_job_data_insert_key_value(job_id, "start_time", start_time, experiment_name) @@ -1306,8 +1322,8 @@ async def run_process_async(cmd, output_file): async def get_job_output_file_name(job_id: str, plugin_name: str = None, experiment_name: str = None): try: - job_obj = Job(job_id) - output_file = job_obj.get_log_path() + job_obj = await Job.get(job_id) + output_file = await job_obj.get_log_path() return output_file except Exception as e: raise e diff --git a/api/transformerlab/shared/ssl_utils.py b/api/transformerlab/shared/ssl_utils.py index d35780b55..3e2cf8338 100644 --- a/api/transformerlab/shared/ssl_utils.py +++ b/api/transformerlab/shared/ssl_utils.py @@ -15,23 +15,22 @@ from lab import storage __all__ = [ - "CERT_DIR", - "CERT_PATH", - "KEY_PATH", "ensure_persistent_self_signed_cert", ] -CERT_DIR: Path = 
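The venv check that recurs in these branches prefers the plugin's own interpreter when one was installed, otherwise falls back to some default Python. A condensed sketch (choose_python is an illustrative name, and sys.executable is only one plausible fallback):

import os
import sys

def choose_python(plugin_location: str) -> str:
    # Prefer <plugin>/venv/bin/python when the plugin ships its own venv.
    venv_path = os.path.join(plugin_location, "venv")
    venv_python = os.path.join(venv_path, "bin", "python")
    if os.path.isdir(venv_path) and os.path.exists(venv_python):
        return venv_python
    return sys.executable  # assumed fallback; the real code may resolve this differently

print(choose_python("/tmp/nonexistent-plugin"))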
Path(get_workspace_dir()) / "certs" -CERT_PATH: Path = CERT_DIR / "server-cert.pem" -KEY_PATH: Path = CERT_DIR / "server-key.pem" +async def ensure_persistent_self_signed_cert() -> Tuple[str, str]: + # Compute paths lazily to avoid asyncio.run() at module level + workspace_dir = await get_workspace_dir() + cert_dir = Path(workspace_dir) / "certs" + cert_path = cert_dir / "server-cert.pem" + key_path = cert_dir / "server-key.pem" -def ensure_persistent_self_signed_cert() -> Tuple[str, str]: - lock = CERT_DIR / ".cert.lock" + lock = cert_dir / ".cert.lock" with FileLock(str(lock)): - if CERT_PATH.exists() and KEY_PATH.exists(): - return str(CERT_PATH), str(KEY_PATH) - CERT_DIR.mkdir(parents=True, exist_ok=True) + if await storage.exists(str(cert_path)) and await storage.exists(str(key_path)): + return str(cert_path), str(key_path) + await storage.makedirs(str(cert_dir), exist_ok=True) key = rsa.generate_private_key(public_exponent=65537, key_size=2048) subject = issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "TransformerLab-Selfhost")]) cert_builder = ( @@ -55,15 +54,15 @@ def ensure_persistent_self_signed_cert() -> Tuple[str, str]: ) cert = cert_builder.sign(key, hashes.SHA256()) # Write via fsspec storage - storage.makedirs(str(CERT_DIR), exist_ok=True) - with storage.open(str(CERT_PATH), "wb") as f: - f.write(cert.public_bytes(serialization.Encoding.PEM)) - with storage.open(str(KEY_PATH), "wb") as f: - f.write( + await storage.makedirs(str(cert_dir), exist_ok=True) + async with await storage.open(str(cert_path), "wb") as f: + await f.write(cert.public_bytes(serialization.Encoding.PEM)) + async with await storage.open(str(key_path), "wb") as f: + await f.write( key.private_bytes( serialization.Encoding.PEM, serialization.PrivateFormat.TraditionalOpenSSL, serialization.NoEncryption(), ) ) - return str(CERT_PATH), str(KEY_PATH) + return str(cert_path), str(key_path) diff --git a/api/transformerlab/shared/zip_utils.py b/api/transformerlab/shared/zip_utils.py index 0c83c791a..37a0b8d85 100644 --- a/api/transformerlab/shared/zip_utils.py +++ b/api/transformerlab/shared/zip_utils.py @@ -3,7 +3,7 @@ from typing import List -def create_zip_from_storage(file_paths: List[str], storage) -> io.BytesIO: +async def create_zip_from_storage(file_paths: List[str], storage) -> io.BytesIO: """ Create a zip file in an in-memory buffer from a list of storage file paths. 
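Since the certificate builder body is mostly elided in the ssl_utils hunk above, here is a standalone sketch of generating a self-signed certificate with the same cryptography primitives the new async ensure_persistent_self_signed_cert uses. The one-year validity window and the helper name are assumptions, and the storage/FileLock plumbing is deliberately left out:

import datetime
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID

def make_self_signed_cert(common_name: str = "TransformerLab-Selfhost"):
    # Returns PEM-encoded (cert_bytes, key_bytes) for a self-issued certificate.
    key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, common_name)])
    now = datetime.datetime.now(datetime.timezone.utc)
    cert = (
        x509.CertificateBuilder()
        .subject_name(name)
        .issuer_name(name)
        .public_key(key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + datetime.timedelta(days=365))  # assumed validity window
        .sign(key, hashes.SHA256())
    )
    cert_pem = cert.public_bytes(serialization.Encoding.PEM)
    key_pem = key.private_bytes(
        serialization.Encoding.PEM,
        serialization.PrivateFormat.TraditionalOpenSSL,
        serialization.NoEncryption(),
    )
    return cert_pem, key_pem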
@@ -24,17 +24,17 @@ def create_zip_from_storage(file_paths: List[str], storage) -> io.BytesIO: filename = file_path.split("/")[-1] if "/" in file_path else file_path # Check if file exists to avoid errors - if not storage.exists(file_path): + if not await storage.exists(file_path): print(f"File not found during zipping: {file_path}") continue - if not storage.isfile(file_path): + if not await storage.isfile(file_path): # Skip directories continue # Read file content from storage - with storage.open(file_path, "rb") as f: - file_content = f.read() + async with await storage.open(file_path, "rb") as f: + file_content = await f.read() zip_file.writestr(filename, file_content) except Exception as e: print(f"Error adding file {file_path} to zip: {e}") diff --git a/lab-sdk/pyproject.toml b/lab-sdk/pyproject.toml index f036aafa4..66d6ab8a0 100644 --- a/lab-sdk/pyproject.toml +++ b/lab-sdk/pyproject.toml @@ -4,13 +4,22 @@ build-backend = "setuptools.build_meta" [project] name = "transformerlab" -version = "0.0.61" +version = "0.0.62" description = "Python SDK for Transformer Lab" readme = "README.md" requires-python = ">=3.10" authors = [{ name = "Transformer Lab", email = "developers@lab.cloud" }] license = { file = "LICENSE" } -dependencies = ["werkzeug", "pytest", "wandb", "fsspec", "s3fs", "gcsfs"] +dependencies = [ + "werkzeug", + "pytest", + "pytest-asyncio", + "wandb", + "fsspec", + "s3fs", + "aiofiles", + "gcsfs", +] [project.urls] "Homepage" = "https://github.com/transformerlab/transformerlab-app" @@ -21,3 +30,5 @@ where = ["src"] [tool.pytest.ini_options] testpaths = ["tests"] +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" diff --git a/lab-sdk/scripts/examples/fake_dataset_generate.py b/lab-sdk/scripts/examples/fake_dataset_generate.py index c728ae66d..5ec3b74d3 100644 --- a/lab-sdk/scripts/examples/fake_dataset_generate.py +++ b/lab-sdk/scripts/examples/fake_dataset_generate.py @@ -82,7 +82,7 @@ def generate_dataset(): lab.log(f"Dataset generation completed in {generation_duration}") # Get generated dataset from job data - job_data = lab.job.get_job_data() + job_data = lab.get_job_data() generated_datasets = job_data.get("generated_datasets", []) if generated_datasets: diff --git a/lab-sdk/scripts/examples/fake_train.py b/lab-sdk/scripts/examples/fake_train.py index 3febf7446..25e370c71 100644 --- a/lab-sdk/scripts/examples/fake_train.py +++ b/lab-sdk/scripts/examples/fake_train.py @@ -159,7 +159,7 @@ def train(): config_artifact_path = lab.save_artifact(config_file, "training_config.json") lab.log(f"Saved training config: {config_artifact_path}") # Get the captured wandb URL from job data for reporting - job_data = lab.job.get_job_data() + job_data = lab.get_job_data() captured_wandb_url = job_data.get("wandb_run_url", "None") lab.log(f"📋 Final wandb URL stored in job data: {captured_wandb_url}") diff --git a/lab-sdk/scripts/examples/slow_test_script.py b/lab-sdk/scripts/examples/slow_test_script.py index b9e349f86..774c4ed93 100644 --- a/lab-sdk/scripts/examples/slow_test_script.py +++ b/lab-sdk/scripts/examples/slow_test_script.py @@ -159,7 +159,7 @@ def train(): config_artifact_path = lab.save_artifact(config_file, "training_config.json") lab.log(f"Saved training config: {config_artifact_path}") # Get the captured wandb URL from job data for reporting - job_data = lab.job.get_job_data() + job_data = lab.get_job_data() captured_wandb_url = job_data.get("wandb_run_url", "None") lab.log(f"📋 Final wandb URL stored in job data: {captured_wandb_url}") sleep(300) 
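The now-async create_zip_from_storage keeps the zipfile work in memory and only awaits the storage reads. A self-contained sketch of the same shape, with the storage layer reduced to an async read_bytes(path) callable (names are illustrative):

import asyncio
import io
import tempfile
import zipfile

async def zip_from_async_reads(file_paths, read_bytes) -> io.BytesIO:
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w") as zf:
        for path in file_paths:
            try:
                data = await read_bytes(path)
            except FileNotFoundError:
                print(f"File not found during zipping: {path}")
                continue
            zf.writestr(path.rsplit("/", 1)[-1], data)
    buf.seek(0)
    return buf

async def main():
    async def read_bytes(path):  # stand-in for an async storage open + read
        return await asyncio.to_thread(lambda: open(path, "rb").read())
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(b"hello")
    archive = await zip_from_async_reads([tmp.name], read_bytes)
    print("zip bytes:", len(archive.getvalue()))

asyncio.run(main())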
diff --git a/lab-sdk/scripts/examples/trl_train_script.py b/lab-sdk/scripts/examples/trl_train_script.py index 000c3f272..9aa7cba79 100644 --- a/lab-sdk/scripts/examples/trl_train_script.py +++ b/lab-sdk/scripts/examples/trl_train_script.py @@ -451,7 +451,7 @@ def train_with_trl(quick_test=True): lab.log(f"✅ Model saved to job models directory: {saved_path}") # Get the captured wandb URL from job data for reporting - job_data = lab.job.get_job_data() + job_data = lab.get_job_data() captured_wandb_url = job_data.get("wandb_run_url", "None") lab.log(f"📋 Final wandb URL stored in job data: {captured_wandb_url}") diff --git a/lab-sdk/src/lab/dataset.py b/lab-sdk/src/lab/dataset.py index ae323cd50..8c3581a2d 100644 --- a/lab-sdk/src/lab/dataset.py +++ b/lab-sdk/src/lab/dataset.py @@ -6,10 +6,11 @@ class Dataset(BaseLabResource): - def get_dir(self): + async def get_dir(self): """Abstract method on BaseLabResource""" dataset_id_safe = secure_filename(str(self.id)) - return storage.join(get_datasets_dir(), dataset_id_safe) + datasets_dir = await get_datasets_dir() + return storage.join(datasets_dir, dataset_id_safe) def _default_json(self): # Default metadata modeled after API dataset table fields @@ -21,7 +22,7 @@ def _default_json(self): "json_data": {}, } - def set_metadata( + async def set_metadata( self, *, location: str | None = None, @@ -29,7 +30,7 @@ def set_metadata( size: int | None = None, json_data: dict | None = None, ): - data = self.get_json_data() + data = await self.get_json_data() if location is not None: data["location"] = location if description is not None: @@ -43,29 +44,29 @@ def set_metadata( current = {} current.update(json_data) data["json_data"] = current - self._set_json_data(data) + await self._set_json_data(data) - def get_metadata(self): - return self.get_json_data() + async def get_metadata(self): + return await self.get_json_data() @staticmethod - def list_all(): + async def list_all(): results = [] - datasets_dir = get_datasets_dir() - if not storage.isdir(datasets_dir): + datasets_dir = await get_datasets_dir() + if not await storage.isdir(datasets_dir): return results try: - entries = storage.ls(datasets_dir, detail=False) + entries = await storage.ls(datasets_dir, detail=False) except Exception: entries = [] for full in entries: - if not storage.isdir(full): + if not await storage.isdir(full): continue # Attempt to read index.json (or latest snapshot) try: entry = full.rstrip("/").split("/")[-1] ds = Dataset(entry) - results.append(ds.get_metadata()) + results.append(await ds.get_metadata()) except Exception: continue return results diff --git a/lab-sdk/src/lab/dirs.py b/lab-sdk/src/lab/dirs.py index 05e7591b8..ec8a72a28 100644 --- a/lab-sdk/src/lab/dirs.py +++ b/lab-sdk/src/lab/dirs.py @@ -7,6 +7,8 @@ from .storage import _current_tfl_storage_uri, REMOTE_WORKSPACE_HOST # TFL_HOME_DIR +# Note: This is a temporary sync value for module initialization +# Actual async resolution happens via get functions if "TFL_HOME_DIR" in os.environ and not (_current_tfl_storage_uri.get() or os.getenv("TFL_STORAGE_URI")): HOME_DIR = os.environ["TFL_HOME_DIR"] if not os.path.exists(HOME_DIR): @@ -14,13 +16,11 @@ exit(1) print(f"Home directory is set to: {HOME_DIR}") else: - # If TFL_STORAGE_URI is set (via context or env), HOME_DIR concept maps to storage.root_uri() - HOME_DIR = ( - storage.root_uri() - if (_current_tfl_storage_uri.get() or os.getenv("TFL_STORAGE_URI")) - else os.path.join(os.path.expanduser("~"), ".transformerlab") - ) - if not (_current_tfl_storage_uri.get() 
or os.getenv("TFL_STORAGE_URI")): + # For remote storage, this is a placeholder - actual value resolved via async functions + if _current_tfl_storage_uri.get() or os.getenv("TFL_STORAGE_URI"): + HOME_DIR = os.getenv("TFL_STORAGE_URI", "") + else: + HOME_DIR = os.path.join(os.path.expanduser("~"), ".transformerlab") os.makedirs(name=HOME_DIR, exist_ok=True) print(f"Using default home directory: {HOME_DIR}") @@ -45,12 +45,12 @@ def set_organization_id(organization_id: str | None) -> None: _current_tfl_storage_uri.set(None) -def get_workspace_dir() -> str: +async def get_workspace_dir() -> str: # Remote SkyPilot workspace override (highest precedence) # Only return container workspace path when value is exactly "true" if os.getenv("_TFL_REMOTE_SKYPILOT_WORKSPACE") == "true": if os.getenv("TFL_STORAGE_URI") is not None: - return storage.root_uri() + return await storage.root_uri() return "/workspace" @@ -71,19 +71,20 @@ def get_workspace_dir() -> str: if _current_tfl_storage_uri.get() is not None: return _current_tfl_storage_uri.get() path = storage.join(HOME_DIR, "orgs", org_id, "workspace") - storage.makedirs(path, exist_ok=True) + await storage.makedirs(path, exist_ok=True) return path if os.getenv("TFL_STORAGE_URI"): - return storage.root_uri() + return await storage.root_uri() path = storage.join(HOME_DIR, "workspace") - storage.makedirs(path, exist_ok=True) + await storage.makedirs(path, exist_ok=True) return path -# Legacy constant for backward compatibility -WORKSPACE_DIR = get_workspace_dir() +# Legacy constant for backward compatibility - placeholder value +# Use await get_workspace_dir() in async code +WORKSPACE_DIR = HOME_DIR """ TFL_HOME_DIR is the directory that is the parent of the src and workspace directories. @@ -106,101 +107,114 @@ def get_workspace_dir() -> str: os.environ["LOGDIR"] = os.getenv("TFL_HOME_DIR", os.path.join(str(os.path.expanduser("~")), ".transformerlab")) -def get_experiments_dir() -> str: - path = storage.join(get_workspace_dir(), "experiments") - storage.makedirs(path, exist_ok=True) +async def get_experiments_dir() -> str: + workspace = await get_workspace_dir() + path = storage.join(workspace, "experiments") + await storage.makedirs(path, exist_ok=True) return path -def get_jobs_dir() -> str: - workspace_dir = get_workspace_dir() +async def get_jobs_dir() -> str: + workspace_dir = await get_workspace_dir() path = storage.join(workspace_dir, "jobs") - storage.makedirs(path, exist_ok=True) + await storage.makedirs(path, exist_ok=True) return path -def get_global_log_path() -> str: - return storage.join(get_workspace_dir(), "transformerlab.log") +async def get_global_log_path() -> str: + workspace = await get_workspace_dir() + return storage.join(workspace, "transformerlab.log") -def get_logs_dir() -> str: +async def get_logs_dir() -> str: path = storage.join(HOME_DIR, "logs") - storage.makedirs(path, exist_ok=True) + await storage.makedirs(path, exist_ok=True) return path # TODO: Move this to Experiment -def experiment_dir_by_name(experiment_name: str) -> str: - experiments_dir = get_experiments_dir() +async def experiment_dir_by_name(experiment_name: str) -> str: + experiments_dir = await get_experiments_dir() return storage.join(experiments_dir, experiment_name) -def get_plugin_dir() -> str: - return storage.join(get_workspace_dir(), "plugins") +async def get_plugin_dir() -> str: + workspace = await get_workspace_dir() + return storage.join(workspace, "plugins") -def plugin_dir_by_name(plugin_name: str) -> str: +async def plugin_dir_by_name(plugin_name: 
str) -> str: plugin_name = secure_filename(plugin_name) - return storage.join(get_plugin_dir(), plugin_name) + plugin_dir = await get_plugin_dir() + return storage.join(plugin_dir, plugin_name) -def get_models_dir() -> str: - path = storage.join(get_workspace_dir(), "models") - storage.makedirs(path, exist_ok=True) +async def get_models_dir() -> str: + workspace = await get_workspace_dir() + path = storage.join(workspace, "models") + await storage.makedirs(path, exist_ok=True) return path -def get_datasets_dir() -> str: - path = storage.join(get_workspace_dir(), "datasets") - storage.makedirs(path, exist_ok=True) +async def get_datasets_dir() -> str: + workspace = await get_workspace_dir() + path = storage.join(workspace, "datasets") + await storage.makedirs(path, exist_ok=True) return path -def get_tasks_dir() -> str: +async def get_tasks_dir() -> str: tfl_storage_uri = _current_tfl_storage_uri.get() if tfl_storage_uri is not None: return storage.join(tfl_storage_uri, "tasks") - path = storage.join(get_workspace_dir(), "tasks") - storage.makedirs(path, exist_ok=True) + workspace = await get_workspace_dir() + path = storage.join(workspace, "tasks") + await storage.makedirs(path, exist_ok=True) return path -def get_task_dir() -> str: +async def dataset_dir_by_id(dataset_id: str) -> str: + datasets_dir = await get_datasets_dir() + return storage.join(datasets_dir, dataset_id) + + +async def get_task_dir() -> str: tfl_storage_uri = _current_tfl_storage_uri.get() if tfl_storage_uri is not None: return storage.join(tfl_storage_uri, "task") - path = storage.join(get_workspace_dir(), "task") - storage.makedirs(path, exist_ok=True) + workspace = await get_workspace_dir() + path = storage.join(workspace, "task") + await storage.makedirs(path, exist_ok=True) return path -def dataset_dir_by_id(dataset_id: str) -> str: - return storage.join(get_datasets_dir(), dataset_id) - - -def get_temp_dir() -> str: - path = storage.join(get_workspace_dir(), "temp") - storage.makedirs(path, exist_ok=True) +async def get_temp_dir() -> str: + workspace = await get_workspace_dir() + path = storage.join(workspace, "temp") + await storage.makedirs(path, exist_ok=True) return path -def get_prompt_templates_dir() -> str: - path = storage.join(get_workspace_dir(), "prompt_templates") - storage.makedirs(path, exist_ok=True) +async def get_prompt_templates_dir() -> str: + workspace = await get_workspace_dir() + path = storage.join(workspace, "prompt_templates") + await storage.makedirs(path, exist_ok=True) return path -def get_tools_dir() -> str: - path = storage.join(get_workspace_dir(), "tools") - storage.makedirs(path, exist_ok=True) +async def get_tools_dir() -> str: + workspace = await get_workspace_dir() + path = storage.join(workspace, "tools") + await storage.makedirs(path, exist_ok=True) return path -def get_batched_prompts_dir() -> str: - path = storage.join(get_workspace_dir(), "batched_prompts") - storage.makedirs(path, exist_ok=True) +async def get_batched_prompts_dir() -> str: + workspace = await get_workspace_dir() + path = storage.join(workspace, "batched_prompts") + await storage.makedirs(path, exist_ok=True) return path @@ -211,43 +225,47 @@ def get_galleries_cache_dir() -> str: return path -def get_job_dir(job_id: str | int) -> str: +async def get_job_dir(job_id: str | int) -> str: """ Return the filesystem directory for a specific job id under the jobs root. Mirrors `Job.get_dir()` but provided here for convenience where a `Job` instance is not readily available. 
""" job_id_safe = secure_filename(str(job_id)) - return storage.join(get_jobs_dir(), job_id_safe) + jobs_dir = await get_jobs_dir() + return storage.join(jobs_dir, job_id_safe) -def get_job_artifacts_dir(job_id: str | int) -> str: +async def get_job_artifacts_dir(job_id: str | int) -> str: """ Return the artifacts directory for a specific job, creating it if needed. Example: ~/.transformerlab/workspace/jobs//artifacts """ - path = storage.join(get_job_dir(job_id), "artifacts") - storage.makedirs(path, exist_ok=True) + job_dir = await get_job_dir(job_id) + path = storage.join(job_dir, "artifacts") + await storage.makedirs(path, exist_ok=True) return path -def get_job_checkpoints_dir(job_id: str | int) -> str: +async def get_job_checkpoints_dir(job_id: str | int) -> str: """ Return the checkpoints directory for a specific job, creating it if needed. Example: ~/.transformerlab/workspace/jobs//checkpoints """ - path = storage.join(get_job_dir(job_id), "checkpoints") - storage.makedirs(path, exist_ok=True) + job_dir = await get_job_dir(job_id) + path = storage.join(job_dir, "checkpoints") + await storage.makedirs(path, exist_ok=True) return path -def get_job_eval_results_dir(job_id: str | int) -> str: +async def get_job_eval_results_dir(job_id: str | int) -> str: """ Return the eval_results directory for a specific job, creating it if needed. Example: ~/.transformerlab/workspace/jobs//eval_results """ - path = storage.join(get_job_dir(job_id), "eval_results") - storage.makedirs(path, exist_ok=True) + job_dir = await get_job_dir(job_id) + path = storage.join(job_dir, "eval_results") + await storage.makedirs(path, exist_ok=True) return path @@ -256,16 +274,16 @@ def get_job_eval_results_dir(job_id: str | int) -> str: async def eval_output_file(experiment_name: str, eval_name: str) -> str: - experiment_dir = experiment_dir_by_name(experiment_name) + experiment_dir = await experiment_dir_by_name(experiment_name) eval_name = secure_filename(eval_name) p = storage.join(experiment_dir, "evals", eval_name) - storage.makedirs(p, exist_ok=True) + await storage.makedirs(p, exist_ok=True) return storage.join(p, "output.txt") async def generation_output_file(experiment_name: str, generation_name: str) -> str: - experiment_dir = experiment_dir_by_name(experiment_name) + experiment_dir = await experiment_dir_by_name(experiment_name) generation_name = secure_filename(generation_name) p = storage.join(experiment_dir, "generations", generation_name) - storage.makedirs(p, exist_ok=True) + await storage.makedirs(p, exist_ok=True) return storage.join(p, "output.txt") diff --git a/lab-sdk/src/lab/experiment.py b/lab-sdk/src/lab/experiment.py index 91982202b..b7777e7ec 100644 --- a/lab-sdk/src/lab/experiment.py +++ b/lab-sdk/src/lab/experiment.py @@ -21,45 +21,66 @@ class Experiment(BaseLabResource): _cache_rebuild_lock = threading.Lock() _cache_rebuild_thread = None - def __init__(self, experiment_id, create_new=False): + def __init__(self, experiment_id): # For consistency and simplicity, let's edit experiment name to match # the directory (which requires experiment_id) self.id = secure_filename(str(experiment_id)) - # Auto-initialize if create_new=True and experiment doesn't exist - if create_new and (not storage.exists(self.get_dir()) or not storage.exists(self._get_json_file())): - self._initialize() + @classmethod + async def create_or_get(cls, experiment_id, create_new=False): + """ + Factory method to create or get an experiment. + If create_new=True, will initialize a new experiment if it doesn't exist. 
+ + Args: + experiment_id: The experiment identifier + create_new: If True, create the experiment if it doesn't exist + + Returns: + Experiment instance + """ + exp = cls(experiment_id) - def get_dir(self): + if create_new: + exp_dir = await exp.get_dir() + json_file = await exp._get_json_file() + # Auto-initialize if experiment doesn't exist + if not await storage.exists(exp_dir) or not await storage.exists(json_file): + await exp._initialize() + + return exp + + async def get_dir(self): """Abstract method on BaseLabResource""" experiment_id_safe = secure_filename(str(self.id)) - return storage.join(get_experiments_dir(), experiment_id_safe) + experiments_dir = await get_experiments_dir() + return storage.join(experiments_dir, experiment_id_safe) def _default_json(self): return {"name": self.id, "id": self.id, "config": {}} - def _initialize(self): - super()._initialize() + async def _initialize(self): + await super()._initialize() # Create a empty jobs index and write - jobs_json_path = self._jobs_json_file() + jobs_json_path = await self._jobs_json_file() empty_jobs_data = {"index": self.DEFAULT_JOBS_INDEX, "cached_jobs": {}} - with storage.open(jobs_json_path, "w") as f: - json.dump(empty_jobs_data, f, indent=4) + async with await storage.open(jobs_json_path, "w") as f: + await f.write(json.dumps(empty_jobs_data, indent=4)) - def update_config_field(self, key, value): + async def update_config_field(self, key, value): """Update a single key in config.""" - current_config = self._get_json_data_field("config", {}) + current_config = await self._get_json_data_field("config", {}) if isinstance(current_config, str): try: current_config = json.loads(current_config) except json.JSONDecodeError: current_config = {} current_config[key] = value - self._update_json_data_field("config", current_config) + await self._update_json_data_field("config", current_config) @classmethod - def create_with_config(cls, name: str, config: dict) -> "Experiment": + async def create_with_config(cls, name: str, config: dict) -> "Experiment": """Create an experiment with config.""" if isinstance(config, str): try: @@ -68,38 +89,39 @@ def create_with_config(cls, name: str, config: dict) -> "Experiment": raise TypeError("config must be a dict or valid JSON string") elif not isinstance(config, dict): raise TypeError("config must be a dict") - exp = cls.create(name) - exp._update_json_data_field("config", config) + exp = await cls.create(name) + await exp._update_json_data_field("config", config) return exp - def update_config(self, config: dict): + async def update_config(self, config: dict): """Update entire config.""" - current_config = self._get_json_data_field("config", {}) + current_config = await self._get_json_data_field("config", {}) if isinstance(current_config, str): try: current_config = json.loads(current_config) except json.JSONDecodeError: current_config = {} current_config.update(config) - self._update_json_data_field("config", current_config) + await self._update_json_data_field("config", current_config) @classmethod - def get_all(cls): + async def get_all(cls): """Get all experiments as list of dicts.""" experiments = [] - exp_root = get_experiments_dir() - if storage.exists(exp_root): + exp_root = await get_experiments_dir() + if await storage.exists(exp_root): try: - entries = storage.ls(exp_root, detail=False) + entries = await storage.ls(exp_root, detail=False) except Exception: entries = [] for exp_path in entries: try: - if storage.isdir(exp_path): + if await storage.isdir(exp_path): index_file = 
storage.join(exp_path, "index.json") - if storage.exists(index_file): - with storage.open(index_file, "r", uncached=True) as f: - data = json.load(f) + if await storage.exists(index_file): + async with await storage.open(index_file, "r", uncached=True) as f: + content = await f.read() + data = json.loads(content) name = data.get("name") exp_id = data.get("id") @@ -113,8 +135,9 @@ def get_all(cls): if not name and exp_id: data["name"] = exp_id try: - with storage.open(index_file, "w") as wf: - json.dump(data, wf, indent=4) + async with await storage.open(index_file, "w") as wf: + content = json.dumps(data, indent=4) + await wf.write(content) name = exp_id except Exception: # If we couldn't persist, skip to avoid inconsistent state @@ -124,8 +147,9 @@ def get_all(cls): if not exp_id and name: data["id"] = name try: - with storage.open(index_file, "w") as wf: - json.dump(data, wf, indent=4) + async with await storage.open(index_file, "w") as wf: + content = json.dumps(data, indent=4) + await wf.write(content) exp_id = name except Exception as e: print( @@ -139,7 +163,7 @@ def get_all(cls): pass return experiments - def create_job(self): + async def create_job(self): """ Creates a new job with a blank template and returns a Job object. """ @@ -148,14 +172,14 @@ def create_job(self): # Scan the jobs directory for subdirectories with numberic names # Find the largest number and increment to get the new job ID largest_numeric_subdir = 0 - jobs_dir = get_jobs_dir() + jobs_dir = await get_jobs_dir() try: - entries = storage.ls(jobs_dir, detail=False) + entries = await storage.ls(jobs_dir, detail=False) except Exception: entries = [] for full_path in entries: entry = full_path.rstrip("/").split("/")[-1] - if entry.isdigit() and storage.isdir(full_path): + if entry.isdigit() and await storage.isdir(full_path): job_id = int(entry) if job_id > largest_numeric_subdir: largest_numeric_subdir = job_id @@ -163,12 +187,12 @@ def create_job(self): new_job_id = largest_numeric_subdir + 1 # Create job with next available job_id and associate the new job with this experiment - new_job = Job.create(new_job_id) - new_job.set_experiment(self.id) + new_job = await Job.create(new_job_id) + await new_job.set_experiment(self.id) return new_job - def get_jobs(self, type: str = "", status: str = ""): + async def get_jobs(self, type: str = "", status: str = ""): """ Get a list of jobs stored in this experiment. Uses cached data from jobs.json for completed jobs, only reads individual files for RUNNING jobs. @@ -179,12 +203,12 @@ def get_jobs(self, type: str = "", status: str = ""): # First get jobs of the passed type job_list = [] if type: - job_list = self._get_jobs_of_type(type) + job_list = await self._get_jobs_of_type(type) else: - job_list = self._get_all_jobs() + job_list = await self._get_all_jobs() # Get cached job data from jobs.json - cached_jobs = self._get_cached_jobs_data() + cached_jobs = await self._get_cached_jobs_data() # Iterate through the job list to return Job objects for valid jobs. # Also filter for status if that parameter was passed. 
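The id allocation in create_job above scans the jobs directory for numeric subdirectory names and takes max + 1. The arithmetic in isolation (the real code also filters on isdir, which this sketch skips):

def next_job_id(entries: list) -> int:
    # `entries` are paths listed under the jobs directory; numeric basenames
    # are existing job ids, and the next id is max + 1 (1 when none exist).
    largest = 0
    for entry in entries:
        name = entry.rstrip("/").split("/")[-1]
        if name.isdigit():
            largest = max(largest, int(name))
    return largest + 1

assert next_job_id([]) == 1
assert next_job_id(["jobs/1/", "jobs/7/", "jobs/readme.txt"]) == 8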
@@ -204,20 +228,22 @@ def get_jobs(self, type: str = "", status: str = ""): ]: old_status = job_json.get("status", "") del cached_jobs[job_id] - job = Job.get(job_id) - job_json = job.get_json_data(uncached=True) + job = await Job.get(job_id) + job_json = await job.get_json_data(uncached=True) # Trigger rebuild cache if old status and new status are different if old_status != job_json.get("status", ""): - self._trigger_cache_rebuild(get_workspace_dir()) + workspace = await get_workspace_dir() + self._trigger_cache_rebuild(workspace) cached_jobs[job_id] = job_json else: # Job not in cache - job = Job.get(job_id) - job_json = job.get_json_data(uncached=True) + job = await Job.get(job_id) + job_json = await job.get_json_data(uncached=True) # Check if job is COMPLETE, STOPPED or FAILED, then update cache if job_json.get("status", "") in ["COMPLETE", "STOPPED", "FAILED"]: - self._trigger_cache_rebuild(get_workspace_dir()) + workspace = await get_workspace_dir() + self._trigger_cache_rebuild(workspace) except Exception as e: print("ERROR getting job", job_id, e) continue @@ -241,7 +267,7 @@ def get_jobs(self, type: str = "", status: str = ""): # Index for tracking which jobs belong to this Experiment ############################### - def _read_jobs_json_file(self, jobs_json_path, max_retries=5): + async def _read_jobs_json_file(self, jobs_json_path, max_retries=5): """ Read jobs.json file with retry logic for Etag mismatch errors. This handles race conditions where the file is being rebuilt while being read. @@ -254,11 +280,14 @@ def _read_jobs_json_file(self, jobs_json_path, max_retries=5): Returns: dict: The parsed JSON data from jobs.json """ + import asyncio + for attempt in range(max_retries): try: # Use uncached=True to avoid Etag caching issues - with storage.open(jobs_json_path, "r", uncached=True) as f: - jobs_data = json.load(f) + async with await storage.open(jobs_json_path, "r", uncached=True) as f: + content = await f.read() + jobs_data = json.loads(content) return jobs_data except FileNotFoundError: # File doesn't exist, let caller handle it @@ -276,13 +305,14 @@ def _read_jobs_json_file(self, jobs_json_path, max_retries=5): if attempt < max_retries - 1: # Wait a short time before retrying (exponential backoff) # Start with 0.5s and increase to give cache rebuild time - time.sleep(0.5 * (2**attempt)) + await asyncio.sleep(0.5 * (2**attempt)) continue else: # Last attempt failed, try one more time try: - with storage.open(jobs_json_path, "r", uncached=True) as f: - jobs_data = json.load(f) + async with await storage.open(jobs_json_path, "r", uncached=True) as f: + content = await f.read() + jobs_data = json.loads(content) return jobs_data except Exception: raise e @@ -290,16 +320,17 @@ def _read_jobs_json_file(self, jobs_json_path, max_retries=5): # Different exception, re-raise it raise - def _jobs_json_file(self, workspace_dir=None, experiment_id=None): + async def _jobs_json_file(self, workspace_dir=None, experiment_id=None): """ Path to jobs.json index file for this experiment. 
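The retry loop in _read_jobs_json_file above is a plain exponential backoff (0.5s, 1s, 2s, ...) around a transient read error, now driven with asyncio.sleep instead of time.sleep. A compact sketch of that shape; the real code additionally special-cases Etag-mismatch errors and re-reads once more at the end:

import asyncio

async def read_with_retry(read_once, max_retries: int = 5):
    for attempt in range(max_retries):
        try:
            return await read_once()
        except FileNotFoundError:
            raise  # missing file is the caller's problem, not a transient error
        except Exception:
            if attempt < max_retries - 1:
                await asyncio.sleep(0.5 * (2 ** attempt))
            else:
                raise

async def main():
    calls = {"n": 0}
    async def flaky():
        calls["n"] += 1
        if calls["n"] < 3:
            raise RuntimeError("simulated Etag mismatch")
        return {"index": {}, "cached_jobs": {}}
    print(await read_with_retry(flaky))

asyncio.run(main())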
""" if workspace_dir and experiment_id: return storage.join(workspace_dir, "experiments", experiment_id, "jobs.json") - return storage.join(self.get_dir(), "jobs.json") + exp_dir = await self.get_dir() + return storage.join(exp_dir, "jobs.json") - def rebuild_jobs_index(self, workspace_dir=None): + async def rebuild_jobs_index(self, workspace_dir=None): results = {} cached_jobs = {} @@ -307,23 +338,23 @@ def rebuild_jobs_index(self, workspace_dir=None): fs_override = None if workspace_dir: # Use uncached filesystem to avoid stale directory listings and file reads - fs_override = storage._get_uncached_filesystem(workspace_dir) + fs_override = await storage._get_uncached_filesystem(workspace_dir) else: # For local workspace, also use uncached to ensure fresh data - jobs_dir = get_jobs_dir() - fs_override = storage._get_uncached_filesystem(jobs_dir) + jobs_dir = await get_jobs_dir() + fs_override = await storage._get_uncached_filesystem(jobs_dir) try: # Use provided jobs_dir or get current one if workspace_dir: jobs_directory = storage.join(workspace_dir, "jobs") else: - jobs_directory = get_jobs_dir() + jobs_directory = await get_jobs_dir() # Iterate through jobs directories and check for index.json # Sort entries numerically since job IDs are numeric strings (descending order) try: - job_entries_full = storage.ls(jobs_directory, detail=False, fs=fs_override) + job_entries_full = await storage.ls(jobs_directory, detail=False, fs=fs_override) except Exception as e: print(f"Error getting job entries full: {e}") job_entries_full = [] @@ -342,13 +373,16 @@ def rebuild_jobs_index(self, workspace_dir=None): sorted_entries = sorted(job_entries, key=lambda x: int(x), reverse=True) for entry in sorted_entries: entry_path = storage.join(jobs_directory, entry) - if not storage.isdir(entry_path, fs=fs_override): + if not await storage.isdir(entry_path, fs=fs_override): continue # Prefer the latest snapshot if available; fall back to index.json index_file = storage.join(entry_path, "index.json") try: - with storage.open(index_file, "r", encoding="utf-8", fs=fs_override, uncached=True) as lf: - content = lf.read().strip() + async with await storage.open( + index_file, "r", encoding="utf-8", fs=fs_override, uncached=True + ) as lf: + content = await lf.read() + content = content.strip() if not content: # Skip empty files continue @@ -377,12 +411,13 @@ def rebuild_jobs_index(self, workspace_dir=None): jobs_data = {"index": results, "cached_jobs": cached_jobs} if results: try: - with storage.open( - self._jobs_json_file(workspace_dir=workspace_dir, experiment_id=self.id), + jobs_json_path = await self._jobs_json_file(workspace_dir=workspace_dir, experiment_id=self.id) + async with await storage.open( + jobs_json_path, "w", fs=fs_override, ) as out: - json.dump(jobs_data, out, indent=4) + await out.write(json.dumps(jobs_data, indent=4)) except Exception as e: print(f"Error writing jobs index: {e}") pass @@ -390,14 +425,14 @@ def rebuild_jobs_index(self, workspace_dir=None): print(f"Error rebuilding jobs index: {e}") pass - def _get_cached_jobs_data(self): + async def _get_cached_jobs_data(self): """ Get cached job data from jobs.json file. If the file doesn't exist, create it with default structure. 
""" - jobs_json_path = self._jobs_json_file() + jobs_json_path = await self._jobs_json_file() try: - jobs_data = self._read_jobs_json_file(jobs_json_path) + jobs_data = await self._read_jobs_json_file(jobs_json_path) # Handle both old format (just index) and new format (with cached_jobs) if "cached_jobs" in jobs_data: return jobs_data["cached_jobs"] @@ -406,10 +441,10 @@ def _get_cached_jobs_data(self): return {} except FileNotFoundError: # Rebuild jobs index to discover and create jobs.json - self.rebuild_jobs_index() + await self.rebuild_jobs_index() # Try to read the newly created file try: - jobs_data = self._read_jobs_json_file(jobs_json_path) + jobs_data = await self._read_jobs_json_file(jobs_json_path) if "cached_jobs" in jobs_data: return jobs_data["cached_jobs"] else: @@ -419,14 +454,14 @@ def _get_cached_jobs_data(self): except Exception: return {} - def _get_all_jobs(self): + async def _get_all_jobs(self): """ Amalgamates all jobs in the index file. If the file doesn't exist, create it with default structure. """ - jobs_json_path = self._jobs_json_file() + jobs_json_path = await self._jobs_json_file() try: - jobs_data = self._read_jobs_json_file(jobs_json_path) + jobs_data = await self._read_jobs_json_file(jobs_json_path) # Handle both old format (just index) and new format (with index key) if "index" in jobs_data: jobs = jobs_data["index"] @@ -439,10 +474,10 @@ def _get_all_jobs(self): return results except FileNotFoundError: # Rebuild jobs index to discover and create jobs.json - self.rebuild_jobs_index() + await self.rebuild_jobs_index() # Try to read the newly created file try: - jobs_data = self._read_jobs_json_file(jobs_json_path) + jobs_data = await self._read_jobs_json_file(jobs_json_path) if "index" in jobs_data: jobs = jobs_data["index"] else: @@ -457,14 +492,16 @@ def _get_all_jobs(self): except Exception: return [] - def _get_jobs_of_type(self, type="TRAIN"): + async def _get_jobs_of_type(self, type="TRAIN"): """ " Returns all jobs of a specific type in this experiment's index file. If the file doesn't exist, create it with default structure. 
""" - jobs_json_path = self._jobs_json_file() + import asyncio + + jobs_json_path = await self._jobs_json_file() try: - jobs_data = self._read_jobs_json_file(jobs_json_path) + jobs_data = await self._read_jobs_json_file(jobs_json_path) # Handle both old format (just index) and new format (with index key) if "index" in jobs_data: jobs = jobs_data["index"] @@ -474,10 +511,10 @@ def _get_jobs_of_type(self, type="TRAIN"): return result except FileNotFoundError: # Rebuild jobs index to discover and create jobs.json - self.rebuild_jobs_index() + await self.rebuild_jobs_index() # Try to read the newly created file try: - jobs_data = self._read_jobs_json_file(jobs_json_path) + jobs_data = await self._read_jobs_json_file(jobs_json_path) if "index" in jobs_data: jobs = jobs_data["index"] else: @@ -496,9 +533,9 @@ def _get_jobs_of_type(self, type="TRAIN"): if is_etag_error: # Wait a bit longer for cache rebuild to complete, then retry - time.sleep(0.5) + await asyncio.sleep(0.5) try: - jobs_data = self._read_jobs_json_file(jobs_json_path) + jobs_data = await self._read_jobs_json_file(jobs_json_path) if "index" in jobs_data: jobs = jobs_data["index"] else: @@ -512,9 +549,10 @@ def _get_jobs_of_type(self, type="TRAIN"): print("Failed getting jobs:", e) return [] - def _add_job(self, job_id, type): + async def _add_job(self, job_id, type): try: - jobs_data = self._read_jobs_json_file(self._jobs_json_file()) + jobs_json_path = await self._jobs_json_file() + jobs_data = await self._read_jobs_json_file(jobs_json_path) except Exception: jobs_data = {"index": {}, "cached_jobs": {}} @@ -531,11 +569,13 @@ def _add_job(self, job_id, type): jobs[type] = [job_id] # Update the file with new structure - with storage.open(self._jobs_json_file(), "w") as f: - json.dump(jobs_data, f, indent=4) + jobs_json_path = await self._jobs_json_file() + async with await storage.open(jobs_json_path, "w") as f: + await f.write(json.dumps(jobs_data, indent=4)) # Trigger background cache rebuild - self._trigger_cache_rebuild(get_workspace_dir()) + workspace = await get_workspace_dir() + self._trigger_cache_rebuild(workspace) @classmethod def _start_background_cache_rebuild(cls): @@ -548,6 +588,8 @@ def _start_background_cache_rebuild(cls): @classmethod def _background_cache_rebuild_worker(cls): """Background worker that rebuilds caches for pending experiments.""" + import asyncio + print("STARTING CACHE REBUILD WORKER") while True: try: @@ -560,7 +602,8 @@ def _background_cache_rebuild_worker(cls): for experiment_id, workspace_dir in pending_experiments: try: exp = cls(experiment_id) - exp.rebuild_jobs_index(workspace_dir=workspace_dir) + # Run async method in sync context using asyncio.run + asyncio.run(exp.rebuild_jobs_index(workspace_dir=workspace_dir)) except Exception as e: print( f"Error rebuilding cache for experiment {experiment_id} in workspace {workspace_dir}: {e}" @@ -574,9 +617,11 @@ def _background_cache_rebuild_worker(cls): def _trigger_cache_rebuild(self, workspace_dir, sync=False): """Trigger a cache rebuild for this experiment.""" + import asyncio + if sync: - # Run synchronously (useful for tests) - self.rebuild_jobs_index(workspace_dir=workspace_dir) + # Run synchronously (useful for tests) - run async method in sync context + asyncio.run(self.rebuild_jobs_index(workspace_dir=workspace_dir)) else: # Start background thread if not running self._start_background_cache_rebuild() @@ -586,23 +631,24 @@ def _trigger_cache_rebuild(self, workspace_dir, sync=False): self._cache_rebuild_pending.add((self.id, workspace_dir)) 
# TODO: For experiments, delete the same way as jobs - def delete(self): + async def delete(self): """Delete the experiment and all associated jobs.""" # Delete all associated jobs - self.delete_all_jobs() + await self.delete_all_jobs() # Delete the experiment directory - exp_dir = self.get_dir() - if storage.exists(exp_dir): - storage.rm_tree(exp_dir) + exp_dir = await self.get_dir() + if await storage.exists(exp_dir): + await storage.rm_tree(exp_dir) - def delete_all_jobs(self): + async def delete_all_jobs(self): """Delete all jobs associated with this experiment.""" - all_jobs = self._get_all_jobs() + all_jobs = await self._get_all_jobs() for job_id in all_jobs: try: - job = Job.get(job_id) - job.delete() + job = await Job.get(job_id) + await job.delete() except Exception: pass # Job might not exist - self._trigger_cache_rebuild(get_workspace_dir()) + workspace = await get_workspace_dir() + self._trigger_cache_rebuild(workspace) diff --git a/lab-sdk/src/lab/job.py b/lab-sdk/src/lab/job.py index 8e5e8de71..bcd40aca8 100644 --- a/lab-sdk/src/lab/job.py +++ b/lab-sdk/src/lab/job.py @@ -16,23 +16,25 @@ def __init__(self, job_id): self.id = job_id self.should_stop = False - def get_dir(self): + async def get_dir(self): """Abstract method on BaseLabResource""" job_id_safe = secure_filename(str(self.id)) - job_dir = storage.join(dirs.get_jobs_dir(), job_id_safe) + jobs_dir = await dirs.get_jobs_dir() + job_dir = storage.join(jobs_dir, job_id_safe) return job_dir - def get_log_path(self): + async def get_log_path(self): """ Returns the path where this job should write logs. """ # Default location for log file - log_path = storage.join(self.get_dir(), f"output_{self.id}.txt") + job_dir = await self.get_dir() + log_path = storage.join(job_dir, f"output_{self.id}.txt") - if not storage.exists(log_path): + if not await storage.exists(log_path): # Then check if there is a path explicitly set in the job data try: - job_data = self.get_job_data() + job_data = await self.get_job_data() if isinstance(job_data, dict): override_path = job_data.get("output_file_path", "") if isinstance(override_path, str) and override_path.strip() != "": @@ -42,15 +44,17 @@ def get_log_path(self): # Make sure whatever log_path we return actually exists # Put an empty file there if not - if not storage.exists(log_path): - with storage.open(log_path, "w") as f: - f.write("") + if not await storage.exists(log_path): + async with await storage.open(log_path, "w") as f: + await f.write("") return log_path def _default_json(self): + # Note: _default_json can't be async as it's called during object initialization + # The output_file_path will be set properly when the job is actually used default_job_data = { - "output_file_path": self.get_log_path(), + "output_file_path": "", # Will be set when first accessed } return { "id": self.id, @@ -61,9 +65,9 @@ def _default_json(self): "progress": 0, } - def set_experiment(self, experiment_id: str, sync_rebuild: bool = False): - self._update_json_data_field("experiment_id", experiment_id) - self.update_job_data_field("experiment_name", experiment_id) + async def set_experiment(self, experiment_id: str, sync_rebuild: bool = False): + await self._update_json_data_field("experiment_id", experiment_id) + await self.update_job_data_field("experiment_name", experiment_id) # Trigger cache rebuild for the experiment to discover this job try: @@ -71,72 +75,74 @@ def set_experiment(self, experiment_id: str, sync_rebuild: bool = False): from .dirs import get_workspace_dir exp = 
Experiment(experiment_id) - exp._trigger_cache_rebuild(workspace_dir=get_workspace_dir(), sync=sync_rebuild) + workspace = await get_workspace_dir() + exp._trigger_cache_rebuild(workspace_dir=workspace, sync=sync_rebuild) except Exception: # Don't fail if cache rebuild trigger fails pass - def update_progress(self, progress: int): + async def update_progress(self, progress: int): """ Update the percent complete for this job. progress: int representing percent complete """ - self._update_json_data_field("progress", progress) + await self._update_json_data_field("progress", progress) - def update_status(self, status: str): + async def update_status(self, status: str): """ Update the status of this job. status: str representing the status of the job """ - self._update_json_data_field("status", status) + await self._update_json_data_field("status", status) # Trigger rebuild on every status update try: from .experiment import Experiment - experiment_id = self.get_experiment_id() + experiment_id = await self.get_experiment_id() if experiment_id: exp = Experiment(experiment_id) - exp._trigger_cache_rebuild(workspace_dir=get_workspace_dir()) + workspace = await get_workspace_dir() + exp._trigger_cache_rebuild(workspace_dir=workspace) except Exception: # Don't fail if cache rebuild trigger fails pass - def get_status(self): + async def get_status(self): """ Get the status of this job. """ - return self._get_json_data_field("status") + return await self._get_json_data_field("status") - def get_progress(self): + async def get_progress(self): """ Get the progress of this job. """ - return self._get_json_data_field("progress") + return await self._get_json_data_field("progress") - def get_job_data(self): + async def get_job_data(self): """ Get the job_data of this job. """ - return self._get_json_data_field("job_data", {}) + return await self._get_json_data_field("job_data", {}) - def set_job_data(self, job_data): - self._update_json_data_field("job_data", job_data) + async def set_job_data(self, job_data): + await self._update_json_data_field("job_data", job_data) - def set_tensorboard_output_dir(self, tensorboard_dir: str): + async def set_tensorboard_output_dir(self, tensorboard_dir: str): """ Sets the directory that tensorboard output is stored. """ - self.update_job_data_field("tensorboard_output_dir", tensorboard_dir) + await self.update_job_data_field("tensorboard_output_dir", tensorboard_dir) - def update_job_data_field(self, key: str, value): + async def update_job_data_field(self, key: str, value): """ Updates a key-value pair in the job_data JSON object. """ # Fetch current job_data (use uncached to avoid stale data) - json_data = self.get_json_data(uncached=True) + json_data = await self.get_json_data(uncached=True) # If there isn't a job_data property then make one if "job_data" not in json_data: @@ -144,9 +150,9 @@ def update_job_data_field(self, key: str, value): # Set the key property to value and save the whole object json_data["job_data"][key] = value - self._set_json_data(json_data) + await self._set_json_data(json_data) - def log_info(self, message): + async def log_info(self, message): """ Save info message to output log file and display to terminal. 
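# Hedged usage sketch of the Job helpers made async above (update_status, update_progress,
# update_job_data_field, get_job_data); `job` is assumed to come from `await Job.get(job_id)`.
async def mark_halfway(job) -> dict:
    # After this change every accessor and mutator on Job must be awaited.
    await job.update_status("RUNNING")
    await job.update_progress(50)
    await job.update_job_data_field("stage", "halfway")
    return await job.get_job_data()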
@@ -166,14 +172,14 @@ def log_info(self, message): # Read existing content, append new message, and write back to log file try: - log_path = self.get_log_path() - storage.makedirs(posixpath.dirname(log_path), exist_ok=True) + log_path = await self.get_log_path() + await storage.makedirs(posixpath.dirname(log_path), exist_ok=True) # Read existing content if file exists existing_content = "" - if storage.exists(log_path): - with storage.open(log_path, "r", encoding="utf-8") as f: - existing_content = f.read() + if await storage.exists(log_path): + async with await storage.open(log_path, "r", encoding="utf-8") as f: + existing_content = await f.read() # Append new message to existing content on a new line if existing_content and not existing_content.endswith("\n"): @@ -181,54 +187,54 @@ def log_info(self, message): new_content = existing_content + message_str # Write back the complete content - with storage.open(log_path, "w", encoding="utf-8") as f: - f.write(new_content) - f.flush() + async with await storage.open(log_path, "w", encoding="utf-8") as f: + await f.write(new_content) + # Note: async file objects may not have flush() except Exception: # Best-effort file logging; ignore file errors to avoid crashing job pass - def set_type(self, job_type: str): + async def set_type(self, job_type: str): """ Set the type of this job. """ - self._update_json_data_field("type", job_type) + await self._update_json_data_field("type", job_type) - def get_experiment_id(self): + async def get_experiment_id(self): """ Get the experiment_id of this job. """ - return self._get_json_data_field("experiment_id") + return await self._get_json_data_field("experiment_id") - def set_error_message(self, error_msg: str): + async def set_error_message(self, error_msg: str): """ Set an error message in the job_data. """ - self.update_job_data_field("error_msg", str(error_msg)) + await self.update_job_data_field("error_msg", str(error_msg)) - def update_sweep_progress(self, value): + async def update_sweep_progress(self, value): """ Update the 'sweep_progress' key in the job_data JSON object. """ - self.update_job_data_field("sweep_progress", value) + await self.update_job_data_field("sweep_progress", value) @classmethod - def count_running_jobs(cls): + async def count_running_jobs(cls): """ Count how many jobs are currently running. """ count = 0 - jobs_dir = dirs.get_jobs_dir() + jobs_dir = await dirs.get_jobs_dir() try: - entries = storage.ls(jobs_dir, detail=False) + entries = await storage.ls(jobs_dir, detail=False) except Exception: entries = [] for job_path in entries: - if storage.isdir(job_path): + if await storage.isdir(job_path): entry = job_path.rstrip("/").split("/")[-1] try: - job = cls.get(entry) - job_data = job.get_json_data(uncached=True) + job = await cls.get(entry) + job_data = await job.get_json_data(uncached=True) if job_data.get("status") == "RUNNING": count += 1 except Exception: @@ -236,23 +242,23 @@ def count_running_jobs(cls): return count @classmethod - def get_next_queued_job(cls): + async def get_next_queued_job(cls): """ Get the next queued job (oldest first based on directory creation time). Returns Job data dict or None if no queued jobs. 
""" queued_jobs = [] - jobs_dir = dirs.get_jobs_dir() + jobs_dir = await dirs.get_jobs_dir() try: - entries = storage.ls(jobs_dir, detail=False) + entries = await storage.ls(jobs_dir, detail=False) except Exception: entries = [] for job_path in entries: - if storage.isdir(job_path): + if await storage.isdir(job_path): entry = job_path.rstrip("/").split("/")[-1] try: - job = cls.get(entry) - job_data = job.get_json_data(uncached=True) + job = await cls.get(entry) + job_data = await job.get_json_data(uncached=True) if job_data.get("status") == "QUEUED": # Without ctime in object stores, sort lexicographically by job id queued_jobs.append((int(entry) if entry.isdigit() else 0, job_data)) @@ -264,30 +270,30 @@ def get_next_queued_job(cls): return queued_jobs[0][1] return None - def get_checkpoints_dir(self): + async def get_checkpoints_dir(self): """ Get the checkpoints directory path for this job. """ - return dirs.get_job_checkpoints_dir(self.id) + return await dirs.get_job_checkpoints_dir(self.id) - def get_artifacts_dir(self): + async def get_artifacts_dir(self): """ Get the artifacts directory path for this job. """ - return dirs.get_job_artifacts_dir(self.id) + return await dirs.get_job_artifacts_dir(self.id) - def get_checkpoint_paths(self): + async def get_checkpoint_paths(self): """ Get list of checkpoint paths for this job. Returns list of all items (files and dirs) in the checkpoints directory. """ try: # Scan the checkpoints directory for all items (files and dirs) - checkpoints_dir = self.get_checkpoints_dir() - if storage.exists(checkpoints_dir): + checkpoints_dir = await self.get_checkpoints_dir() + if await storage.exists(checkpoints_dir): checkpoint_files = [] try: - items = storage.ls(checkpoints_dir, detail=False) + items = await storage.ls(checkpoints_dir, detail=False) except Exception: items = [] for item_path in items: @@ -298,43 +304,44 @@ def get_checkpoint_paths(self): except Exception: return [] - def get_artifact_paths(self): + async def get_artifact_paths(self): """ Get list of artifact file paths for this job. Returns list of artifact paths from job_data or scans directory. """ try: # Scan the artifacts directory - artifacts_dir = self.get_artifacts_dir() - if storage.exists(artifacts_dir): + artifacts_dir = await self.get_artifacts_dir() + if await storage.exists(artifacts_dir): artifact_files = [] try: - items = storage.ls(artifacts_dir, detail=False) + items = await storage.ls(artifacts_dir, detail=False) except Exception: items = [] for item_path in items: - if storage.isfile(item_path): + if await storage.isfile(item_path): artifact_files.append(item_path) return sorted(artifact_files) except Exception: return [] return [] - def delete(self): + async def delete(self): """ Mark this job as deleted. 
""" - self.update_status("DELETED") + await self.update_status("DELETED") # Trigger cache rebuild since deleted jobs are removed from cache # This is non-blocking - just adds to pending queue try: from .experiment import Experiment - experiment_id = self.get_experiment_id() + experiment_id = await self.get_experiment_id() if experiment_id: exp = Experiment(experiment_id) - exp._trigger_cache_rebuild(workspace_dir=get_workspace_dir()) + workspace = await get_workspace_dir() + exp._trigger_cache_rebuild(workspace_dir=workspace) except Exception: # Don't fail if cache rebuild trigger fails pass diff --git a/lab-sdk/src/lab/lab_facade.py b/lab-sdk/src/lab/lab_facade.py index cd8a27af8..bfe261311 100644 --- a/lab-sdk/src/lab/lab_facade.py +++ b/lab-sdk/src/lab/lab_facade.py @@ -1,6 +1,7 @@ from __future__ import annotations import time +import asyncio from typing import Optional, Dict, Any, Union import os import io @@ -14,6 +15,41 @@ from .dataset import Dataset +def _run_async(coro): + """ + Helper to run async code from sync context. + Handles event loop detection intelligently: + - If there's a running event loop, we need to use the async version instead + - If no running loop, uses asyncio.run() or run_until_complete() + """ + try: + # Try to get the running event loop + loop = asyncio.get_running_loop() + # If we're in an async context, we can't use blocking calls + # The sync wrapper should not be used from async contexts + # Use the async version (a_* methods) instead + raise RuntimeError( + "Cannot use sync method when already in async context. " + "Use the async version instead (e.g., await lab.a_save_artifact() instead of lab.save_artifact())." + ) + except RuntimeError as e: + # Check if this is our custom error or a "no running loop" error + if "Cannot use sync method" in str(e): + raise + # No running loop - we can safely create/use one + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + # Loop is closed, create a new one + return asyncio.run(coro) + else: + # Loop exists but not running, use it + return loop.run_until_complete(coro) + except RuntimeError: + # No event loop at all, create one + return asyncio.run(coro) + + class Lab: """ Simple facade over Experiment and Job for easy usage: @@ -53,25 +89,27 @@ def init(self, experiment_id: str | None = None, config: Optional[Dict[str, Any] if existing_job_id: # Use existing job from environment variable # This will raise an error if the job doesn't exist - self._experiment = Experiment(experiment_id, create_new=False) - self._job = Job.get(existing_job_id) + self._experiment = _run_async(Experiment.create_or_get(experiment_id, create_new=False)) + self._job = _run_async(Job.get(existing_job_id)) if self._job is None: raise RuntimeError(f"Job with ID {existing_job_id} not found. 
Check _TFL_JOB_ID environment variable.") print(f"Using existing job ID: {existing_job_id}") # Set start_time if not already set (for remote jobs launched through providers) - job_data = self._job.get_job_data() + job_data = _run_async(self._job.get_job_data()) if not job_data.get("start_time"): - self._job.update_job_data_field("start_time", time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime())) + _run_async( + self._job.update_job_data_field("start_time", time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime())) + ) else: # Create new job as before - self._experiment = Experiment(experiment_id, create_new=True) - self._job = self._experiment.create_job() - self._job.update_job_data_field("start_time", time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime())) - self._job.set_experiment(experiment_id) + self._experiment = _run_async(Experiment.create_or_get(experiment_id, create_new=True)) + self._job = _run_async(self._experiment.create_job()) + _run_async(self._job.update_job_data_field("start_time", time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))) + _run_async(self._job.set_experiment(experiment_id)) print(f"Created new job ID: {self._job.id}") # Update status to RUNNING for both cases - self._job.update_status("RUNNING") + _run_async(self._job.update_status("RUNNING")) # Check for wandb integration and capture URL if available self._detect_and_capture_wandb_url() @@ -97,9 +135,9 @@ def set_config(self, config: Dict[str, Any]) -> None: if isinstance(config, dict) and "experiment_name" not in config and self._experiment is not None: config = {**config, "experiment_name": self._experiment.id} # keep the existing config with fields that are not in the new config - config_old = self._job.get_job_data() + config_old = _run_async(self._job.get_job_data()) config_new = {**config_old, **config} - self._job.set_job_data(config_new) # type: ignore[union-attr] + _run_async(self._job.set_job_data(config_new)) # type: ignore[union-attr] def get_config(self) -> Dict[str, Any]: """ @@ -125,7 +163,7 @@ def get_config(self) -> Dict[str, Any]: # ------------- convenience logging ------------- def log(self, message: str) -> None: self._ensure_initialized() - self._job.log_info(message) # type: ignore[union-attr] + _run_async(self._job.log_info(message)) # type: ignore[union-attr] # Check for wandb URL on every log operation self._check_and_capture_wandb_url() @@ -134,12 +172,21 @@ def update_progress(self, progress: int) -> None: Update job progress and check for wandb URL detection. """ self._ensure_initialized() - self._job.update_progress(progress) # type: ignore[union-attr] + _run_async(self._job.update_progress(progress)) # type: ignore[union-attr] # Check for wandb URL on every progress update self._check_and_capture_wandb_url() # ------------- checkpoint resume support ------------- def get_checkpoint_to_resume(self) -> Optional[str]: + """ + Get the checkpoint path to resume training from (sync version). + + This is a sync wrapper around the async implementation. + Use a_get_checkpoint_to_resume() if you're already in an async context. + """ + return _run_async(self.async_get_checkpoint_to_resume()) + + async def async_get_checkpoint_to_resume(self) -> Optional[str]: """ Get the checkpoint path to resume training from. 
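# Hedged sketch of the sync/async split introduced for the Lab facade: plain scripts keep calling
# the sync methods (which route through _run_async), while code already inside an event loop must
# await the async_* variants, since _run_async raises RuntimeError when a loop is running.
# `lab` is assumed to be an already-initialized Lab instance.
def resume_from_sync_script(lab):
    # Sync context: the wrapper spins up its own event loop internally.
    return lab.get_checkpoint_to_resume()

async def resume_from_async_code(lab):
    # Async context: call the coroutine directly instead of the blocking wrapper.
    return await lab.async_get_checkpoint_to_resume()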
@@ -153,7 +200,7 @@ def get_checkpoint_to_resume(self) -> Optional[str]: if not self._job: return None - job_data = self._job.get_job_data() + job_data = await self._job.get_job_data() if not job_data: return None
@@ -164,15 +211,24 @@ def get_checkpoint_to_resume(self) -> Optional[str]: return None # Build the checkpoint path from parent job's checkpoints directory - checkpoint_path = self.get_parent_job_checkpoint_path(parent_job_id, checkpoint_name) + checkpoint_path = await self.async_get_parent_job_checkpoint_path(parent_job_id, checkpoint_name) # Verify the checkpoint exists - if checkpoint_path and storage.exists(checkpoint_path): + if checkpoint_path and await storage.exists(checkpoint_path): return checkpoint_path return None def get_parent_job_checkpoint_path(self, parent_job_id: str, checkpoint_name: str) -> Optional[str]: + """ + Get the full path to a checkpoint from a parent job (sync version). + + This is a sync wrapper around the async implementation. + Use async_get_parent_job_checkpoint_path() if you're already in an async context. + """ + return _run_async(self.async_get_parent_job_checkpoint_path(parent_job_id, checkpoint_name)) + + async def async_get_parent_job_checkpoint_path(self, parent_job_id: str, checkpoint_name: str) -> Optional[str]: """ Get the full path to a checkpoint from a parent job.
@@ -187,7 +243,7 @@ def get_parent_job_checkpoint_path(self, parent_job_id: str, checkpoint_name: st Optional[str]: The full path to the checkpoint, or None if it doesn't exist """ try: - checkpoints_dir = dirs.get_job_checkpoints_dir(parent_job_id) + checkpoints_dir = await dirs.get_job_checkpoints_dir(parent_job_id) checkpoint_path = storage.join(checkpoints_dir, checkpoint_name) # Security check: ensure the checkpoint path is within the checkpoints directory
@@ -212,7 +268,7 @@ def get_parent_job_checkpoint_path(self, parent_job_id: str, checkpoint_name: st if not checkpoint_path_normalized.startswith(checkpoints_dir_normalized + "/"): return None - if storage.exists(checkpoint_path_normalized): + if await storage.exists(checkpoint_path_normalized): return checkpoint_path_normalized return None
@@ -232,17 +288,17 @@ def finish( Mark the job as successfully completed and set completion metadata.
""" self._ensure_initialized() - self._job.update_progress(100) # type: ignore[union-attr] - self._job.update_status("COMPLETE") # type: ignore[union-attr] - self._job.update_job_data_field("completion_status", "success") # type: ignore[union-attr] - self._job.update_job_data_field("completion_details", message) # type: ignore[union-attr] - self._job.update_job_data_field("end_time", time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime())) # type: ignore[union-attr] + _run_async(self._job.update_progress(100)) # type: ignore[union-attr] + _run_async(self._job.update_status("COMPLETE")) # type: ignore[union-attr] + _run_async(self._job.update_job_data_field("completion_status", "success")) # type: ignore[union-attr] + _run_async(self._job.update_job_data_field("completion_details", message)) # type: ignore[union-attr] + _run_async(self._job.update_job_data_field("end_time", time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))) # type: ignore[union-attr] if score is not None: - self._job.update_job_data_field("score", score) # type: ignore[union-attr] + _run_async(self._job.update_job_data_field("score", score)) # type: ignore[union-attr] if additional_output_path is not None and additional_output_path.strip() != "": - self._job.update_job_data_field("additional_output_path", additional_output_path) # type: ignore[union-attr] + _run_async(self._job.update_job_data_field("additional_output_path", additional_output_path)) # type: ignore[union-attr] if plot_data_path is not None and plot_data_path.strip() != "": - self._job.update_job_data_field("plot_data_path", plot_data_path) # type: ignore[union-attr] + _run_async(self._job.update_job_data_field("plot_data_path", plot_data_path)) # type: ignore[union-attr] def save_artifact( self, @@ -250,6 +306,21 @@ def save_artifact( name: Optional[str] = None, type: Optional[str] = None, config: Optional[Dict[str, Any]] = None, + ) -> str: + """ + Save an artifact file or directory into this job's artifacts folder (sync version). + + This is a sync wrapper around the async implementation. + Use a_save_artifact() if you're already in an async context. + """ + return _run_async(self.async_save_artifact(source_path, name, type, config)) + + async def async_save_artifact( + self, + source_path: Union[str, Any], + name: Optional[str] = None, + type: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, ) -> str: """ Save an artifact file or directory into this job's artifacts folder. 
@@ -317,7 +388,7 @@ def save_artifact( is_image = config["is_image"] # Use the existing save_dataset method - output_path = self.save_dataset( + output_path = await self.async_save_dataset( df=df, dataset_id=dataset_id, additional_metadata=additional_metadata if additional_metadata else None,
@@ -327,18 +398,20 @@ def save_artifact( # Track dataset_id in job_data try: - job_data = self._job.get_job_data() + job_data = await self._job.get_job_data() generated_datasets_list = [] if isinstance(job_data, dict): existing = job_data.get("generated_datasets", []) if isinstance(existing, list): generated_datasets_list = existing generated_datasets_list.append(dataset_id) - self._job.update_job_data_field("generated_datasets", generated_datasets_list) + await self._job.update_job_data_field("generated_datasets", generated_datasets_list) except Exception: pass - self.log(f"Dataset saved to '{output_path}' and registered as generated dataset '{dataset_id}'") + await self._job.log_info( + f"Dataset saved to '{output_path}' and registered as generated dataset '{dataset_id}'" + ) # type: ignore[union-attr] return output_path # Handle DataFrame input when type="evals"
@@ -380,7 +453,7 @@ def save_artifact( raise ValueError(f"Missing required columns in DataFrame: {missing_columns}") # Determine destination directory and filename - dest_dir = dirs.get_job_eval_results_dir(job_id) + dest_dir = await dirs.get_job_eval_results_dir(job_id) if name is None or (isinstance(name, str) and name.strip() == ""): import time
@@ -393,7 +466,7 @@ def save_artifact( dest = storage.join(dest_dir, filename) # Create parent directories - storage.makedirs(dest_dir, exist_ok=True) + await storage.makedirs(dest_dir, exist_ok=True) # Save DataFrame to CSV using storage module try:
@@ -406,25 +479,25 @@ def save_artifact( df.to_csv(buffer, index=False) buffer.seek(0) # Then write buffer content to storage - with storage.open(dest, "w", encoding="utf-8") as f: - f.write(buffer.getvalue()) + async with await storage.open(dest, "w", encoding="utf-8") as f: + await f.write(buffer.getvalue()) except Exception as e: raise RuntimeError(f"Failed to save evaluation results to {dest}: {str(e)}") # Track in job_data try: - job_data = self._job.get_job_data() + job_data = await self._job.get_job_data() eval_results_list = [] if isinstance(job_data, dict): existing = job_data.get("eval_results", []) if isinstance(existing, list): eval_results_list = existing eval_results_list.append(dest) - self._job.update_job_data_field("eval_results", eval_results_list) + await self._job.update_job_data_field("eval_results", eval_results_list) except Exception: pass - self.log(f"Evaluation results saved to '{dest}'") + await self._job.log_info(f"Evaluation results saved to '{dest}'") # type: ignore[union-attr] return dest # Handle file path input when type="model"
@@ -439,7 +512,7 @@ def save_artifact( # Check source existence: use local filesystem for local paths, storage backend for remote if is_remote: - if not storage.exists(src): + if not await storage.exists(src): raise FileNotFoundError(f"Model source does not exist: {src}") else: if not os.path.exists(src):
@@ -476,25 +549,25 @@ def save_artifact( base_name = f"{job_id}_{posixpath.basename(src)}" # Save to main workspace models directory for Model Zoo visibility - models_dir = dirs.get_models_dir() + models_dir = await dirs.get_models_dir() dest = storage.join(models_dir, base_name) # Create parent directories - storage.makedirs(models_dir, exist_ok=True) + await storage.makedirs(models_dir,
exist_ok=True) # Copy file or directory # Check if source is directory: use local filesystem for local paths, storage backend for remote if is_remote: - src_is_dir = storage.isdir(src) + src_is_dir = await storage.isdir(src) else: src_is_dir = os.path.isdir(src) if src_is_dir: - if storage.exists(dest): - storage.rm_tree(dest) - storage.copy_dir(src, dest) + if await storage.exists(dest): + await storage.rm_tree(dest) + await storage.copy_dir(src, dest) else: - storage.copy_file(src, dest) + await storage.copy_file(src, dest) # Initialize model service for metadata creation model_service = ModelService(base_name) @@ -503,7 +576,7 @@ def save_artifact( try: # Use provided architecture or detect it if architecture is None: - architecture = model_service.detect_architecture(dest) + architecture = await model_service.detect_architecture(dest) # Handle pipeline tag logic if pipeline_tag is None and parent_model is not None: @@ -511,7 +584,7 @@ def save_artifact( pipeline_tag = model_service.fetch_pipeline_tag(parent_model) # Determine model_filename for single-file models - model_filename = "" if storage.isdir(dest) else posixpath.basename(dest) + model_filename = "" if await storage.isdir(dest) else posixpath.basename(dest) # Prepare json_data with basic info json_data = { @@ -524,25 +597,25 @@ def save_artifact( json_data["pipeline_tag"] = pipeline_tag # Use the Model class's generate_model_json method to create metadata - model_service.generate_model_json( + await model_service.generate_model_json( architecture=architecture, model_filename=model_filename, json_data=json_data, ) - self.log(f"Model saved to Model Zoo as '{base_name}'") + await self._job.log_info(f"Model saved to Model Zoo as '{base_name}'") # type: ignore[union-attr] except Exception as e: self.log(f"Warning: Model saved but metadata creation failed: {str(e)}") # Track in job_data try: - job_data = self._job.get_job_data() + job_data = await self._job.get_job_data() model_list = [] if isinstance(job_data, dict): existing = job_data.get("models", []) if isinstance(existing, list): model_list = existing model_list.append(dest) - self._job.update_job_data_field("models", model_list) + await self._job.update_job_data_field("models", model_list) except Exception: pass @@ -559,7 +632,7 @@ def save_artifact( # Check source existence: use local filesystem for local paths, storage backend for remote if is_remote: - if not storage.exists(src): + if not await storage.exists(src): raise FileNotFoundError(f"Artifact source does not exist: {src}") else: if not os.path.exists(src): @@ -567,33 +640,33 @@ def save_artifact( # Determine destination directory based on type if type == "evals": - dest_dir = dirs.get_job_eval_results_dir(job_id) + dest_dir = await dirs.get_job_eval_results_dir(job_id) else: - dest_dir = dirs.get_job_artifacts_dir(job_id) + dest_dir = await dirs.get_job_artifacts_dir(job_id) base_name = name if (isinstance(name, str) and name.strip() != "") else posixpath.basename(src) dest = storage.join(dest_dir, base_name) # Create parent directories - storage.makedirs(dest_dir, exist_ok=True) + await storage.makedirs(dest_dir, exist_ok=True) # Copy file or directory # Check if source is directory: use local filesystem for local paths, storage backend for remote if is_remote: - src_is_dir = storage.isdir(src) + src_is_dir = await storage.isdir(src) else: src_is_dir = os.path.isdir(src) if src_is_dir: - if storage.exists(dest): - storage.rm_tree(dest) - storage.copy_dir(src, dest) + if await storage.exists(dest): + await 
storage.rm_tree(dest) + await storage.copy_dir(src, dest) else: - storage.copy_file(src, dest) + await storage.copy_file(src, dest) # Track in job_data based on type try: - job_data = self._job.get_job_data() + job_data = await self._job.get_job_data() if type == "evals": # For eval results, track in eval_results list eval_results_list = [] @@ -602,7 +675,7 @@ def save_artifact( if isinstance(existing, list): eval_results_list = existing eval_results_list.append(dest) - self._job.update_job_data_field("eval_results", eval_results_list) + await self._job.update_job_data_field("eval_results", eval_results_list) else: # For regular artifacts, track in artifacts list artifact_list = [] @@ -611,7 +684,7 @@ def save_artifact( if isinstance(existing, list): artifact_list = existing artifact_list.append(dest) - self._job.update_job_data_field("artifacts", artifact_list) + await self._job.update_job_data_field("artifacts", artifact_list) except Exception: pass @@ -624,6 +697,22 @@ def save_dataset( additional_metadata: Optional[Dict[str, Any]] = None, suffix: Optional[str] = None, is_image: bool = False, + ) -> str: + """ + Save a dataset under the workspace datasets directory (sync version). + + This is a sync wrapper around the async implementation. + Use a_save_dataset() if you're already in an async context. + """ + return _run_async(self.async_save_dataset(df, dataset_id, additional_metadata, suffix, is_image)) + + async def async_save_dataset( + self, + df, + dataset_id: str, + additional_metadata: Optional[Dict[str, Any]] = None, + suffix: Optional[str] = None, + is_image: bool = False, ) -> str: """ Save a dataset under the workspace datasets directory and mark it as generated. @@ -651,11 +740,11 @@ def save_dataset( # Prepare dataset directory dataset_id_safe = dataset_id.strip() - dataset_dir = dirs.dataset_dir_by_id(dataset_id_safe) + dataset_dir = await dirs.dataset_dir_by_id(dataset_id_safe) # If exists, then raise an error - if storage.exists(dataset_dir): + if await storage.exists(dataset_dir): raise FileExistsError(f"Dataset with ID {dataset_id_safe} already exists") - storage.makedirs(dataset_dir, exist_ok=True) + await storage.makedirs(dataset_dir, exist_ok=True) # Determine output filename if is_image: @@ -679,17 +768,17 @@ def save_dataset( df.to_json(buffer, orient="records", lines=lines) buffer.seek(0) # Then write buffer content to storage - with storage.open(output_path, "w", encoding="utf-8") as f: - f.write(buffer.getvalue()) + async with await storage.open(output_path, "w", encoding="utf-8") as f: + await f.write(buffer.getvalue()) except Exception as e: raise RuntimeError(f"Failed to save dataset to {output_path}: {str(e)}") # Create or update filesystem metadata so it appears under generated datasets try: try: - ds = Dataset.get(dataset_id_safe) + ds = await Dataset.get(dataset_id_safe) except FileNotFoundError: - ds = Dataset.create(dataset_id_safe) + ds = await Dataset.create(dataset_id_safe) # Base json_data with generated flag for UI filtering json_data: Dict[str, Any] = { @@ -700,7 +789,7 @@ def save_dataset( if additional_metadata and isinstance(additional_metadata, dict): json_data.update(additional_metadata) - ds.set_metadata( + await ds.set_metadata( location="local", description=json_data.get("description", ""), size=-1, @@ -710,20 +799,31 @@ def save_dataset( # Do not fail the save if metadata write fails; log to job data print(f"Warning: Failed to create dataset metadata: {str(e)}") try: - self._job.update_job_data_field("dataset_metadata_error", str(e)) # 
type: ignore[union-attr] + await self._job.update_job_data_field("dataset_metadata_error", str(e)) # type: ignore[union-attr] except Exception as e2: print(f"Warning: Failed to log dataset metadata error: {str(e2)}") # Track dataset on the job for provenance try: - self._job.update_job_data_field("dataset_id", dataset_id_safe) # type: ignore[union-attr] + await self._job.update_job_data_field("dataset_id", dataset_id_safe) # type: ignore[union-attr] except Exception as e: print(f"Warning: Failed to track dataset in job_data: {str(e)}") - self.log(f"Dataset saved to '{output_path}' and registered as generated dataset '{dataset_id_safe}'") + await self._job.log_info( + f"Dataset saved to '{output_path}' and registered as generated dataset '{dataset_id_safe}'" + ) # type: ignore[union-attr] return output_path def save_checkpoint(self, source_path: str, name: Optional[str] = None) -> str: + """ + Save a checkpoint file or directory into this job's checkpoints folder (sync version). + + This is a sync wrapper around the async implementation. + Use async_save_checkpoint() if you're already in an async context. + """ + return _run_async(self.async_save_checkpoint(source_path, name)) + + async def async_save_checkpoint(self, source_path: str, name: Optional[str] = None) -> str: """ Save a checkpoint file or directory into this job's checkpoints folder. Returns the destination path on disk. @@ -739,45 +839,45 @@ def save_checkpoint(self, source_path: str, name: Optional[str] = None) -> str: # Check source existence: use local filesystem for local paths, storage backend for remote if is_remote: - if not storage.exists(src): + if not await storage.exists(src): raise FileNotFoundError(f"Checkpoint source does not exist: {src}") else: if not os.path.exists(src): raise FileNotFoundError(f"Checkpoint source does not exist: {src}") job_id = self._job.id # type: ignore[union-attr] - ckpts_dir = dirs.get_job_checkpoints_dir(job_id) + ckpts_dir = await dirs.get_job_checkpoints_dir(job_id) base_name = name if (isinstance(name, str) and name.strip() != "") else posixpath.basename(src) dest = storage.join(ckpts_dir, base_name) # Create parent directories - storage.makedirs(ckpts_dir, exist_ok=True) + await storage.makedirs(ckpts_dir, exist_ok=True) # Copy file or directory # Check if source is directory: use local filesystem for local paths, storage backend for remote if is_remote: - src_is_dir = storage.isdir(src) + src_is_dir = await storage.isdir(src) else: src_is_dir = os.path.isdir(src) if src_is_dir: - if storage.exists(dest): - storage.rm_tree(dest) - storage.copy_dir(src, dest) + if await storage.exists(dest): + await storage.rm_tree(dest) + await storage.copy_dir(src, dest) else: - storage.copy_file(src, dest) + await storage.copy_file(src, dest) # Track in job_data and update latest pointer try: - job_data = self._job.get_job_data() + job_data = await self._job.get_job_data() ckpt_list = [] if isinstance(job_data, dict): existing = job_data.get("checkpoints", []) if isinstance(existing, list): ckpt_list = existing ckpt_list.append(dest) - self._job.update_job_data_field("checkpoints", ckpt_list) - self._job.update_job_data_field("latest_checkpoint", dest) + await self._job.update_job_data_field("checkpoints", ckpt_list) + await self._job.update_job_data_field("latest_checkpoint", dest) except Exception as e: print(f"Warning: Failed to track checkpoint in job_data: {str(e)}") @@ -790,6 +890,22 @@ def save_model( architecture: Optional[str] = None, pipeline_tag: Optional[str] = None, parent_model: 
Optional[str] = None, + ) -> str: + """ + Save a model file or directory to the workspace models directory (sync version). + + This is a sync wrapper around the async implementation. + Use async_save_model() if you're already in an async context. + """ + return _run_async(self.async_save_model(source_path, name, architecture, pipeline_tag, parent_model)) + + async def async_save_model( + self, + source_path: str, + name: Optional[str] = None, + architecture: Optional[str] = None, + pipeline_tag: Optional[str] = None, + parent_model: Optional[str] = None, ) -> str: """ Save a model file or directory to the workspace models directory. @@ -821,7 +937,7 @@ def save_model( config["parent_model"] = parent_model # Use save_artifact with type="model" - return self.save_artifact( + return await self.a_save_artifact( source_path=source_path, name=name, type="model", @@ -836,11 +952,11 @@ def error( Mark the job as failed and set completion metadata. """ self._ensure_initialized() - self._job.update_status("COMPLETE") # type: ignore[union-attr] - self._job.update_job_data_field("completion_status", "failed") # type: ignore[union-attr] - self._job.update_job_data_field("completion_details", message) # type: ignore[union-attr] - self._job.update_job_data_field("end_time", time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime())) # type: ignore[union-attr] - self._job.update_job_data_field("status", "FAILED") # type: ignore[union-attr] + _run_async(self._job.update_status("COMPLETE")) # type: ignore[union-attr] + _run_async(self._job.update_job_data_field("completion_status", "failed")) # type: ignore[union-attr] + _run_async(self._job.update_job_data_field("completion_details", message)) # type: ignore[union-attr] + _run_async(self._job.update_job_data_field("end_time", time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))) # type: ignore[union-attr] + _run_async(self._job.update_job_data_field("status", "FAILED")) # type: ignore[union-attr] def _detect_and_capture_wandb_url(self) -> None: """ @@ -854,7 +970,7 @@ def _detect_and_capture_wandb_url(self) -> None: # Method 1: Check environment variables set by wandb wandb_url = os.environ.get("WANDB_URL") if wandb_url: - self._job.update_job_data_field("wandb_run_url", wandb_url) + _run_async(self._job.update_job_data_field("wandb_run_url", wandb_url)) print(f"📊 Detected wandb run URL: {wandb_url}") return @@ -865,7 +981,7 @@ def _detect_and_capture_wandb_url(self) -> None: if wandb.run is not None: wandb_url = wandb.run.url if wandb_url: - self._job.update_job_data_field("wandb_run_url", wandb_url) + _run_async(self._job.update_job_data_field("wandb_run_url", wandb_url)) print(f"📊 Detected wandb run URL: {wandb_url}") return except ImportError: @@ -883,7 +999,7 @@ def _detect_and_capture_wandb_url(self) -> None: if current_run and hasattr(current_run, "url"): wandb_url = current_run.url if wandb_url: - self._job.update_job_data_field("wandb_run_url", wandb_url) + _run_async(self._job.update_job_data_field("wandb_run_url", wandb_url)) print(f"📊 Detected wandb run URL: {wandb_url}") return except (ImportError, AttributeError): @@ -900,14 +1016,14 @@ def _check_and_capture_wandb_url(self) -> None: """ try: # Only check if we haven't already captured a wandb URL - job_data = self._job.get_job_data() + job_data = _run_async(self._job.get_job_data()) if job_data.get("wandb_run_url"): return # Already have a wandb URL # Method 1: Check environment variables wandb_url = os.environ.get("WANDB_URL") if wandb_url: - self._job.update_job_data_field("wandb_run_url", wandb_url) 
+ _run_async(self._job.update_job_data_field("wandb_run_url", wandb_url)) print(f"📊 Auto-detected wandb URL from environment: {wandb_url}") return @@ -918,7 +1034,7 @@ def _check_and_capture_wandb_url(self) -> None: if wandb.run is not None and hasattr(wandb.run, "url"): wandb_url = wandb.run.url if wandb_url: - self._job.update_job_data_field("wandb_run_url", wandb_url) + _run_async(self._job.update_job_data_field("wandb_run_url", wandb_url)) print(f"📊 Auto-detected wandb URL from wandb.run: {wandb_url}") return except ImportError: @@ -935,7 +1051,7 @@ def capture_wandb_url(self, wandb_url: str) -> None: """ if wandb_url and wandb_url.strip(): self._ensure_initialized() - self._job.update_job_data_field("wandb_run_url", wandb_url.strip()) + _run_async(self._job.update_job_data_field("wandb_run_url", wandb_url.strip())) print(f"📊 Captured wandb run URL: {wandb_url.strip()}") # ------------- helpers ------------- @@ -949,32 +1065,68 @@ def job(self) -> Job: return self._job # type: ignore[return-value] def get_checkpoints_dir(self) -> str: + """ + Get the checkpoints directory path for the current job (sync version). + + This is a sync wrapper around the async implementation. + Use async_get_checkpoints_dir() if you're already in an async context. + """ + return _run_async(self.async_get_checkpoints_dir()) + + async def async_get_checkpoints_dir(self) -> str: """ Get the checkpoints directory path for the current job. """ self._ensure_initialized() - return self._job.get_checkpoints_dir() # type: ignore[union-attr] + return await self._job.get_checkpoints_dir() # type: ignore[union-attr] def get_artifacts_dir(self) -> str: + """ + Get the artifacts directory path for the current job (sync version). + + This is a sync wrapper around the async implementation. + Use a_get_artifacts_dir() if you're already in an async context. + """ + return _run_async(self.async_get_artifacts_dir()) + + async def async_get_artifacts_dir(self) -> str: """ Get the artifacts directory path for the current job. """ self._ensure_initialized() - return self._job.get_artifacts_dir() # type: ignore[union-attr] + return await self._job.get_artifacts_dir() # type: ignore[union-attr] def get_checkpoint_paths(self) -> list[str]: + """ + Get list of checkpoint file paths for the current job (sync version). + + This is a sync wrapper around the async implementation. + Use async_get_checkpoint_paths() if you're already in an async context. + """ + return _run_async(self.async_get_checkpoint_paths()) + + async def async_get_checkpoint_paths(self) -> list[str]: """ Get list of checkpoint file paths for the current job. """ self._ensure_initialized() - return self._job.get_checkpoint_paths() # type: ignore[union-attr] + return await self._job.get_checkpoint_paths() # type: ignore[union-attr] def get_artifact_paths(self) -> list[str]: + """ + Get list of artifact file paths for the current job (sync version). + + This is a sync wrapper around the async implementation. + Use a_get_artifact_paths() if you're already in an async context. + """ + return _run_async(self.async_get_artifact_paths()) + + async def async_get_artifact_paths(self) -> list[str]: """ Get list of artifact file paths for the current job. 
""" self._ensure_initialized() - return self._job.get_artifact_paths() # type: ignore[union-attr] + return await self._job.get_artifact_paths() # type: ignore[union-attr] def list_models(self) -> list[Dict[str, Any]]: """ @@ -986,7 +1138,7 @@ def list_models(self) -> list[Dict[str, Any]]: - name: The model name - json_data: Additional model metadata """ - return ModelService.list_all() + return _run_async(ModelService.list_all()) def get_model(self, model_id: str) -> ModelService: """ @@ -1001,7 +1153,7 @@ def get_model(self, model_id: str) -> ModelService: Raises: FileNotFoundError: If the model directory doesn't exist """ - return ModelService.get(model_id) + return _run_async(ModelService.get(model_id)) def get_model_path(self, model_id: str) -> str: """ @@ -1016,14 +1168,31 @@ def get_model_path(self, model_id: str) -> str: Raises: FileNotFoundError: If the model doesn't exist """ - model = ModelService.get(model_id) - return model.get_dir() + model = _run_async(ModelService.get(model_id)) + return _run_async(model.get_dir()) @property def experiment(self) -> Experiment: self._ensure_initialized() return self._experiment # type: ignore[return-value] + def get_job_data(self) -> Dict[str, Any]: + """ + Get the job data dictionary (sync version). + + This is a sync wrapper around the async Job.get_job_data() method. + Use async_get_job_data() if you're already in an async context. + """ + self._ensure_initialized() + return _run_async(self._job.get_job_data()) # type: ignore[union-attr] + + async def async_get_job_data(self) -> Dict[str, Any]: + """ + Get the job data dictionary (async version). + """ + self._ensure_initialized() + return await self._job.get_job_data() # type: ignore[union-attr] + def get_hf_callback(self): """ Get a HuggingFace TrainerCallback instance for TransformerLab integration. diff --git a/lab-sdk/src/lab/labresource.py b/lab-sdk/src/lab/labresource.py index 6410bd756..96e50b5aa 100644 --- a/lab-sdk/src/lab/labresource.py +++ b/lab-sdk/src/lab/labresource.py @@ -20,29 +20,29 @@ def get_dir(self) -> str: pass @classmethod - def create(cls, id): + async def create(cls, id): """ Default method to create a new entity and initialize it with defualt metadata. """ newobj = cls(id) - newobj._initialize() + await newobj._initialize() return newobj @classmethod - def get(cls, id): + async def get(cls, id): """ Default method to get entity if it exists in the file system. If the entity's directory doesn't exist then throw an error. If the entity's metadata file does not exist then create a default. """ newobj = cls(id) - resource_dir = newobj.get_dir() - if not storage.isdir(resource_dir): + resource_dir = await newobj.get_dir() + if not await storage.isdir(resource_dir): raise FileNotFoundError(f"Directory for {cls.__name__} with id '{id}' not found") - json_file = newobj._get_json_file() - if not storage.exists(json_file): - with storage.open(json_file, "w", encoding="utf-8") as f: - json.dump(newobj._default_json(), f) + json_file = await newobj._get_json_file() + if not await storage.exists(json_file): + async with await storage.open(json_file, "w", encoding="utf-8") as f: + await f.write(json.dumps(newobj._default_json())) return newobj ### @@ -50,33 +50,33 @@ def get(cls, id): # There are used by all subclasses to initialize, get and set JSON data ### - def _initialize(self): + async def _initialize(self): """ Default function to initialize the file system and json object. To alter the default metadata update the _default_json method. 
""" # Create directory for this resource - dir = self.get_dir() - storage.makedirs(dir, exist_ok=True) + dir = await self.get_dir() + await storage.makedirs(dir, exist_ok=True) print(f"Created directory for {type(self).__name__} with id '{self.id}'") # Create a default json file. Throw an error if one already exists. - json_file = self._get_json_file() - if storage.exists(json_file): + json_file = await self._get_json_file() + if await storage.exists(json_file): raise FileExistsError(f"{type(self).__name__} with id '{self.id}' already exists") - with storage.open(json_file, "w", encoding="utf-8") as f: - json.dump(self._default_json(), f) + async with await storage.open(json_file, "w", encoding="utf-8") as f: + await f.write(json.dumps(self._default_json())) def _default_json(self): """Override in subclasses to support the initialize method.""" return {"id": self.id} - def _get_json_file(self): + async def _get_json_file(self): """Get json file containing metadata for this resource.""" - return storage.join(self.get_dir(), "index.json") + return storage.join(await self.get_dir(), "index.json") - def get_json_data(self, uncached: bool = False, max_retries: int = 5): + async def get_json_data(self, uncached: bool = False, max_retries: int = 5): """ Return the JSON data that is stored for this resource in the filesystem. If the file doesn't exist then return an empty dict. @@ -85,16 +85,16 @@ def get_json_data(self, uncached: bool = False, max_retries: int = 5): uncached: If True, use an uncached filesystem to avoid Etag caching issues max_retries: Maximum number of retries for Etag errors (default: 5) """ - import time + import asyncio - json_file = self._get_json_file() + json_file = await self._get_json_file() # Try opening this file location and parsing the json inside # On any error return an empty dict for attempt in range(max_retries): try: - with storage.open(json_file, "r", encoding="utf-8", uncached=uncached) as f: - content = f.read() + async with await storage.open(json_file, "r", encoding="utf-8") as f: + content = await f.read() # Clean the content - remove trailing whitespace and extra characters content = content.strip() # Remove any trailing % characters (common in some shell outputs) @@ -118,7 +118,7 @@ def get_json_data(self, uncached: bool = False, max_retries: int = 5): if is_etag_error: if attempt < max_retries - 1: # Wait a short time before retrying (exponential backoff) - time.sleep(0.5 * (2**attempt)) + await asyncio.sleep(0.5 * (2**attempt)) continue else: # Last attempt failed, return empty dict @@ -127,7 +127,7 @@ def get_json_data(self, uncached: bool = False, max_retries: int = 5): # Different exception, return empty dict return {} - def _set_json_data(self, json_data): + async def _set_json_data(self, json_data): """ Sets the entire JSON data that is stored for this resource in the filesystem. This will overwrite whatever is stored now. 
@@ -140,26 +140,26 @@ def _set_json_data(self, json_data): raise TypeError("json_data must be a dict") # Write directly to index.json - json_file = self._get_json_file() - with storage.open(json_file, "w", encoding="utf-8") as f: - json.dump(json_data, f, ensure_ascii=False) + json_file = await self._get_json_file() + async with await storage.open(json_file, "w", encoding="utf-8") as f: + await f.write(json.dumps(json_data, ensure_ascii=False)) - def _get_json_data_field(self, key, default=""): + async def _get_json_data_field(self, key, default=""): """Gets the value of a single top-level field in a JSON object""" - json_data = self.get_json_data(uncached=True) + json_data = await self.get_json_data(uncached=True) return json_data.get(key, default) - def _update_json_data_field(self, key: str, value): + async def _update_json_data_field(self, key: str, value): """Sets the value of a single top-level field in a JSON object""" - json_data = self.get_json_data(uncached=True) + json_data = await self.get_json_data(uncached=True) json_data[key] = value - self._set_json_data(json_data) + await self._set_json_data(json_data) - def delete(self): + async def delete(self): """ Delete this resource by deleting the containing directory. TODO: We should change to soft delete """ - resource_dir = self.get_dir() - if storage.exists(resource_dir): - storage.rm_tree(resource_dir) + resource_dir = await self.get_dir() + if await storage.exists(resource_dir): + await storage.rm_tree(resource_dir) diff --git a/lab-sdk/src/lab/model.py b/lab-sdk/src/lab/model.py index d0cc4f227..8971ea48b 100644 --- a/lab-sdk/src/lab/model.py +++ b/lab-sdk/src/lab/model.py @@ -7,10 +7,11 @@ class Model(BaseLabResource): - def get_dir(self): + async def get_dir(self): """Abstract method on BaseLabResource""" model_id_safe = secure_filename(str(self.id)) - return storage.join(get_models_dir(), model_id_safe) + models_dir = await get_models_dir() + return storage.join(models_dir, model_id_safe) def _default_json(self): # Default metadata modeled after API model table fields @@ -20,7 +21,7 @@ def _default_json(self): "json_data": {}, } - def set_metadata( + async def set_metadata( self, *, model_id: str | None = None, @@ -28,7 +29,7 @@ def set_metadata( json_data: dict | None = None, ): """Set model metadata, similar to dataset service""" - data = self.get_json_data() + data = await self.get_json_data() if model_id is not None: data["model_id"] = model_id if name is not None: @@ -40,42 +41,42 @@ def set_metadata( current = {} current.update(json_data) data["json_data"] = current - self._set_json_data(data) + await self._set_json_data(data) - def get_metadata(self): + async def get_metadata(self): """Get model metadata""" - return self.get_json_data() + return await self.get_json_data() @staticmethod - def list_all(): + async def list_all(): """List all models in the filesystem, similar to dataset service""" results = [] - models_dir = get_models_dir() - if not storage.isdir(models_dir): + models_dir = await get_models_dir() + if not await storage.isdir(models_dir): return results try: - entries = storage.ls(models_dir, detail=False) + entries = await storage.ls(models_dir, detail=False) except Exception: entries = [] for full in entries: - if not storage.isdir(full): + if not await storage.isdir(full): continue # Attempt to read index.json (or latest snapshot) try: entry = full.rstrip("/").split("/")[-1] model = Model(entry) - results.append(model.get_metadata()) + results.append(await model.get_metadata()) except Exception: 
continue return results - def import_model(self, model_name, model_path): + async def import_model(self, model_name, model_path): """ Given a model name and path, create a new model that can be used in the workspace. """ - self.generate_model_json(model_name, model_path) + await self.generate_model_json(model_name, model_path) - def detect_architecture(self, model_path: str) -> str: + async def detect_architecture(self, model_path: str) -> str: """ Detect the model architecture from a model directory's config.json. @@ -87,12 +88,13 @@ def detect_architecture(self, model_path: str) -> str: """ architecture = "Unknown" - if storage.isdir(model_path): + if await storage.isdir(model_path): config_path = storage.join(model_path, "config.json") - if storage.exists(config_path): + if await storage.exists(config_path): try: - with storage.open(config_path, "r") as f: - config = json.load(f) + async with await storage.open(config_path, "r") as f: + content = await f.read() + config = json.loads(content) architectures = config.get("architectures", []) if architectures: architecture = architectures[0] @@ -121,7 +123,7 @@ def fetch_pipeline_tag(self, parent_model: str) -> str | None: print(f"Could not fetch pipeline tag from parent model '{parent_model}': {type(e).__name__}: {e}") return None - def create_md5_checksums(self, model_path: str) -> list: + async def create_md5_checksums(self, model_path: str) -> list: """ Create MD5 checksums for all files in the model directory. @@ -133,36 +135,36 @@ def create_md5_checksums(self, model_path: str) -> list: """ import hashlib - def compute_md5(file_path): + async def compute_md5(file_path): md5 = hashlib.md5() - with storage.open(file_path, "rb") as f: - while chunk := f.read(8192): + async with await storage.open(file_path, "rb") as f: + while chunk := await f.read(8192): md5.update(chunk) return md5.hexdigest() md5_objects = [] - if not storage.isdir(model_path): + if not await storage.isdir(model_path): print(f"Model path '{model_path}' is not a directory, skipping MD5 checksum creation") return md5_objects # Use fsspec's walk equivalent for directory traversal try: - files = storage.find(model_path) + files = await storage.find(model_path) for file_path in files: try: - md5_hash = compute_md5(file_path) + md5_hash = await compute_md5(file_path) md5_objects.append({"file_path": file_path, "md5_hash": md5_hash}) except Exception as e: print(f"Warning: Could not compute MD5 for {file_path}: {str(e)}") except Exception: # Fallback: if find doesn't work, try listing the directory try: - entries = storage.ls(model_path, detail=False) + entries = await storage.ls(model_path, detail=False) for entry in entries: - if storage.isfile(entry): + if await storage.isfile(entry): try: - md5_hash = compute_md5(entry) + md5_hash = await compute_md5(entry) md5_objects.append({"file_path": entry, "md5_hash": md5_hash}) except Exception as e: print(f"Warning: Could not compute MD5 for {entry}: {str(e)}") @@ -171,7 +173,7 @@ def compute_md5(file_path): return md5_objects - def create_provenance_file( + async def create_provenance_file( self, model_path: str, model_name: str = None, @@ -223,12 +225,12 @@ def create_provenance_file( # Write provenance to file provenance_path = storage.join(model_path, "_tlab_provenance.json") - with storage.open(provenance_path, "w") as f: - json.dump(final_provenance, f, indent=2) + async with await storage.open(provenance_path, "w") as f: + await f.write(json.dumps(final_provenance, indent=2)) return provenance_path - def 
generate_model_json( + async def generate_model_json( self, architecture: str, model_filename: str = "", @@ -265,7 +267,8 @@ def generate_model_json( model_description["json_data"].update(json_data) # Output the json to the file - with storage.open(storage.join(self.get_dir(), "index.json"), "w") as outfile: - json.dump(model_description, outfile) + model_dir = await self.get_dir() + async with await storage.open(storage.join(model_dir, "index.json"), "w") as outfile: + await outfile.write(json.dumps(model_description)) return model_description diff --git a/lab-sdk/src/lab/storage.py b/lab-sdk/src/lab/storage.py index bd2937bea..db7c3c688 100644 --- a/lab-sdk/src/lab/storage.py +++ b/lab-sdk/src/lab/storage.py @@ -1,8 +1,112 @@ import os import posixpath import contextvars +from types import TracebackType +from typing import Optional, Type import fsspec +import aiofiles + + +class AsyncFileWrapper: + """ + Wrapper to make sync file objects work with async context managers. + This allows sync filesystem file objects to be used with 'async with'. + """ + + def __init__(self, file_obj): + # Store the file object (which may be a context manager) + self._file_obj = file_obj + self.file_obj = None + self._is_context_manager = hasattr(file_obj, "__enter__") and hasattr(file_obj, "__exit__") + + async def __aenter__(self): + # Enter the sync context manager if it is one + if self._is_context_manager: + self.file_obj = self._file_obj.__enter__() + else: + self.file_obj = self._file_obj + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if self._is_context_manager: + # Exit the sync context manager + self._file_obj.__exit__(exc_type, exc_val, exc_tb) + elif self.file_obj and hasattr(self.file_obj, "close"): + # Just close if no context manager protocol + self.file_obj.close() + self.file_obj = None + + # Override common I/O methods to make them async-compatible + async def read(self, size=-1): + """Read from the file (async wrapper for sync read).""" + if self.file_obj is None: + raise ValueError("File object not initialized. Use 'async with' to open the file.") + return self.file_obj.read(size) + + async def write(self, data): + """Write to the file (async wrapper for sync write).""" + if self.file_obj is None: + raise ValueError("File object not initialized. Use 'async with' to open the file.") + return self.file_obj.write(data) + + async def readline(self, size=-1): + """Read a line from the file (async wrapper for sync readline).""" + if self.file_obj is None: + raise ValueError("File object not initialized. Use 'async with' to open the file.") + return self.file_obj.readline(size) + + async def readlines(self, hint=-1): + """Read all lines from the file (async wrapper for sync readlines).""" + if self.file_obj is None: + raise ValueError("File object not initialized. Use 'async with' to open the file.") + return self.file_obj.readlines(hint) + + async def seek(self, offset, whence=0): + """Seek to a position in the file (async wrapper for sync seek).""" + if self.file_obj is None: + raise ValueError("File object not initialized. Use 'async with' to open the file.") + return self.file_obj.seek(offset, whence) + + async def tell(self): + """Get current file position (async wrapper for sync tell).""" + if self.file_obj is None: + raise ValueError("File object not initialized. 
Use 'async with' to open the file.") + return self.file_obj.tell() + + async def flush(self): + """Flush the file buffer (async wrapper for sync flush).""" + if self.file_obj is None: + raise ValueError("File object not initialized. Use 'async with' to open the file.") + return self.file_obj.flush() + + def __getattr__(self, name): + # Delegate all other attributes to the underlying file object + if self.file_obj is None: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + return getattr(self.file_obj, name) + + def __iter__(self): + if self.file_obj is None: + raise ValueError("File object not initialized. Use 'async with' to open the file.") + return iter(self.file_obj) + + def __aiter__(self): + # For async iteration, we need to wrap the sync iterator + return self + + async def __anext__(self): + if self.file_obj is None: + raise ValueError("File object not initialized. Use 'async with' to open the file.") + try: + return next(self.file_obj) + except StopIteration: + raise StopAsyncIteration # Context variable for storage URI (set by host app/session) @@ -54,17 +158,17 @@ def _get_fs_and_root(): return fs, root -def root_uri() -> str: +async def root_uri() -> str: _, root = _get_fs_and_root() return root -def filesystem(): +async def filesystem(): fs, _ = _get_fs_and_root() return fs -def debug_info() -> dict: +async def debug_info() -> dict: """Debug information about the current storage configuration.""" context_uri = _current_tfl_storage_uri.get() env_uri = os.getenv("TFL_STORAGE_URI") @@ -84,41 +188,45 @@ def join(*parts: str) -> str: return posixpath.join(*parts) -def root_join(*parts: str) -> str: - return join(root_uri(), *parts) +async def root_join(*parts: str) -> str: + root = await root_uri() + return join(root, *parts) -def exists(path: str) -> bool: - return filesystem().exists(path) +async def exists(path: str) -> bool: + fs = await filesystem() + return fs.exists(path) -def isdir(path: str, fs=None) -> bool: +async def isdir(path: str, fs=None) -> bool: try: - filesys = fs if fs is not None else filesystem() + filesys = fs if fs is not None else await filesystem() return filesys.isdir(path) except Exception: return False -def isfile(path: str) -> bool: +async def isfile(path: str) -> bool: try: - return filesystem().isfile(path) + fs = await filesystem() + return fs.isfile(path) except Exception: return False -def makedirs(path: str, exist_ok: bool = True) -> None: +async def makedirs(path: str, exist_ok: bool = True) -> None: + fs = await filesystem() try: - filesystem().makedirs(path, exist_ok=exist_ok) + fs.makedirs(path, exist_ok=exist_ok) except TypeError: # Some filesystems don't support exist_ok parameter - if not exist_ok or not exists(path): - filesystem().makedirs(path) + if not exist_ok or not await exists(path): + fs.makedirs(path) -def ls(path: str, detail: bool = False, fs=None): +async def ls(path: str, detail: bool = False, fs=None): # Use provided filesystem or get default - filesys = fs if fs is not None else filesystem() + filesys = fs if fs is not None else await filesystem() # Let fsspec parse the URI paths = filesys.ls(path, detail=detail) # Dont include the current path in the list @@ -139,11 +247,12 @@ def ls(path: str, detail: bool = False, fs=None): return paths -def find(path: str) -> list[str]: - return filesystem().find(path) +async def find(path: str) -> list[str]: + fs = await filesystem() + return fs.find(path) -def walk(path: str, maxdepth=None, topdown=True, on_error="omit"): +async def walk(path: str, 
maxdepth=None, topdown=True, on_error="omit"): """ Walk directory tree, yielding (root, dirs, files) tuples. @@ -156,30 +265,36 @@ def walk(path: str, maxdepth=None, topdown=True, on_error="omit"): Yields: (root, dirs, files) tuples similar to os.walk() """ - return filesystem().walk(path, maxdepth=maxdepth, topdown=topdown, on_error=on_error) + fs = await filesystem() + return fs.walk(path, maxdepth=maxdepth, topdown=topdown, on_error=on_error) -def rm(path: str) -> None: - if exists(path): - filesystem().rm(path) +async def rm(path: str) -> None: + if await exists(path): + fs = await filesystem() + fs.rm(path) -def rm_tree(path: str) -> None: - if exists(path): +async def rm_tree(path: str) -> None: + if await exists(path): + fs = await filesystem() try: - filesystem().rm(path, recursive=True) + fs.rm(path, recursive=True) except TypeError: # Some filesystems don't support recursive parameter # Use find() to get all files and remove them individually - files = find(path) + files = await find(path) for file_path in reversed(files): # Remove files before directories - filesystem().rm(file_path) + fs.rm(file_path) -def open(path: str, mode: str = "r", fs=None, uncached: bool = False, **kwargs): +async def open(path: str, mode: str = "r", fs=None, uncached: bool = False, **kwargs): """ Open a file for reading or writing. + For local files, uses aiofiles for truly async file I/O. + For remote files (S3, GCS, etc.), uses fsspec sync file objects. + Args: path: Path to the file mode: File mode ('r', 'w', etc.) @@ -188,18 +303,29 @@ def open(path: str, mode: str = "r", fs=None, uncached: bool = False, **kwargs): **kwargs: Additional arguments passed to filesystem.open() Returns: - File-like object + File-like object (context manager for remote, async context manager for local) """ if uncached: # Create an uncached filesystem instance # If fs is provided, use it to infer protocol/storage options, otherwise infer from path - filesys = _get_uncached_filesystem(path, fs=fs) + filesys = await _get_uncached_filesystem(path, fs=fs) else: - filesys = fs if fs is not None else filesystem() - return filesys.open(path, mode=mode, **kwargs) + filesys = fs if fs is not None else await filesystem() + # Check if this is a local filesystem + is_local = isinstance(filesys, fsspec.implementations.local.LocalFileSystem) -def _get_uncached_filesystem(path: str, fs=None): + if is_local: + # Use aiofiles for local files to get truly async file I/O + return aiofiles.open(path, mode=mode, **kwargs) + else: + # Use sync filesystem open method, but wrap it in async context manager + # so it can be used with 'async with' + sync_file = filesys.open(path, mode=mode, **kwargs) + return AsyncFileWrapper(sync_file) + + +async def _get_uncached_filesystem(path: str, fs=None): """ Get a filesystem instance without caching for reading files. This prevents Etag caching issues when files are being modified concurrently. 
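Editorial aside, not part of the patch: a minimal usage sketch of the async storage API introduced above. open() is now a coroutine that returns either an aiofiles handle (local paths) or an AsyncFileWrapper (remote fsspec files), so callers await open() and then enter the result with "async with", exactly as the call sites earlier in this diff do. The "from lab import storage" import and the example path are assumptions for illustration.

import asyncio

from lab import storage  # assumed import path for the lab-sdk storage module


async def append_note(path: str) -> str:
    # Read the current contents; both branches of open() hand back an async
    # context manager whose read()/write() methods are awaitable.
    async with await storage.open(path, "r") as f:
        text = await f.read()
    # Append a line using the same double-await pattern.
    async with await storage.open(path, "a") as f:
        await f.write("checked\n")
    return text


# asyncio.run(append_note("/tmp/example-notes.txt"))  # hypothetical local path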
@@ -245,13 +371,13 @@ def _get_uncached_filesystem(path: str, fs=None): if protocol: # Create a new uncached filesystem with the same protocol and options - fs_uncached = fsspec.filesystem( - protocol, - skip_instance_cache=True, - default_fill_cache=False, - use_listings_cache=False, - **storage_options, - ) + fs_kwargs = { + "skip_instance_cache": True, + "default_fill_cache": False, + "use_listings_cache": False, + } + fs_kwargs.update(storage_options) + fs_uncached = fsspec.filesystem(protocol, **fs_kwargs) return fs_uncached except Exception: # If extraction fails, fall through to path-based inference @@ -291,19 +417,36 @@ def _get_uncached_filesystem(path: str, fs=None): return fs_uncached else: # For local filesystems, just use the default - return filesystem() + return await filesystem() + +def _get_fs_for_path(path: str): + """ + Get filesystem for a given path, handling S3 storage_options correctly. + Returns (filesystem, parsed_path) tuple. + """ + storage_options = {} + if path.startswith("s3://") and _AWS_PROFILE: + storage_options["profile"] = _AWS_PROFILE + return fsspec.core.url_to_fs(path, storage_options=storage_options if storage_options else None) -def copy_file(src: str, dest: str) -> None: + +async def copy_file(src: str, dest: str) -> None: """Copy a single file from src to dest across arbitrary filesystems.""" # Use streaming copy to be robust across different filesystems - with fsspec.open(src, "rb") as r, fsspec.open(dest, "wb") as w: - for chunk in iter_chunks(r): - w.write(chunk) + # Get sync filesystems with proper storage_options handling + src_fs, _ = _get_fs_for_path(src) + dest_fs, _ = _get_fs_for_path(dest) + + # Use sync filesystem methods (wrapped in async function for API compatibility) + with src_fs.open(src, "rb") as r: + with dest_fs.open(dest, "wb") as w: + for chunk in iter_chunks(r): + w.write(chunk) def iter_chunks(file_obj, chunk_size: int = 8 * 1024 * 1024): - """Helper to read file in chunks.""" + """Helper to read file in chunks (synchronous).""" while True: data = file_obj.read(chunk_size) if not data: @@ -311,11 +454,11 @@ def iter_chunks(file_obj, chunk_size: int = 8 * 1024 * 1024): yield data -def copy_dir(src_dir: str, dest_dir: str) -> None: +async def copy_dir(src_dir: str, dest_dir: str) -> None: """Recursively copy a directory tree across arbitrary filesystems.""" - makedirs(dest_dir, exist_ok=True) + await makedirs(dest_dir, exist_ok=True) # Determine the source filesystem independently of destination - src_fs, _ = fsspec.core.url_to_fs(src_dir) + src_fs, _ = _get_fs_for_path(src_dir) try: src_files = src_fs.find(src_dir) except Exception: @@ -332,6 +475,6 @@ def copy_dir(src_dir: str, dest_dir: str) -> None: # Ensure destination directory exists dest_parent = posixpath.dirname(dest_file) if dest_parent: - makedirs(dest_parent, exist_ok=True) + await makedirs(dest_parent, exist_ok=True) # Copy the file using streaming (robust across FSes) - copy_file(src_file, dest_file) + await copy_file(src_file, dest_file) diff --git a/lab-sdk/src/lab/task.py b/lab-sdk/src/lab/task.py index 889781c4e..928e15610 100644 --- a/lab-sdk/src/lab/task.py +++ b/lab-sdk/src/lab/task.py @@ -7,10 +7,11 @@ class Task(BaseLabResource): - def get_dir(self): + async def get_dir(self): """Abstract method on BaseLabResource""" task_id_safe = secure_filename(str(self.id)) - return storage.join(get_tasks_dir(), task_id_safe) + tasks_dir = await get_tasks_dir() + return storage.join(tasks_dir, task_id_safe) def _default_json(self): # Default metadata modeled after 
API tasks table fields @@ -28,7 +29,7 @@ def _default_json(self): "updated_at": datetime.utcnow().isoformat(), } - def set_metadata( + async def set_metadata( self, *, name: str | None = None, @@ -41,7 +42,7 @@ def set_metadata( remote_task: bool | None = None, ): """Set task metadata""" - data = self.get_json_data() + data = await self.get_json_data() if name is not None: data["name"] = name if type is not None: @@ -61,29 +62,29 @@ def set_metadata( # Always update the updated_at timestamp data["updated_at"] = datetime.utcnow().isoformat() - self._set_json_data(data) + await self._set_json_data(data) - def get_metadata(self): + async def get_metadata(self): """Get task metadata""" - data = self.get_json_data() + data = await self.get_json_data() # Fix experiment_id if it's a digit - convert to experiment name if data.get("experiment_id") and str(data["experiment_id"]).isdigit(): - experiment_name = self._get_experiment_name_by_id(data["experiment_id"]) + experiment_name = await self._get_experiment_name_by_id(data["experiment_id"]) if experiment_name: data["experiment_id"] = experiment_name # Save the corrected data back to the file - self._set_json_data(data) + await self._set_json_data(data) return data - def _get_experiment_name_by_id(self, experiment_id): + async def _get_experiment_name_by_id(self, experiment_id): """Get experiment name by ID, return None if not found""" try: from .experiment import Experiment # Get all experiments and search for one with matching db_experiment_id - all_experiments = Experiment.get_all() + all_experiments = await Experiment.get_all() for exp_data in all_experiments: if exp_data.get("db_experiment_id") == int(experiment_id): return exp_data.get("name", experiment_id) @@ -94,73 +95,88 @@ def _get_experiment_name_by_id(self, experiment_id): return experiment_id @staticmethod - def list_all(): + async def list_all(): """List all tasks in the filesystem""" results = [] - tasks_dir = get_tasks_dir() - if not storage.isdir(tasks_dir): + tasks_dir = await get_tasks_dir() + if not await storage.isdir(tasks_dir): print(f"Tasks directory does not exist: {tasks_dir}") return results try: - entries = storage.ls(tasks_dir, detail=False) + entries = await storage.ls(tasks_dir, detail=False) except Exception as e: print(f"Exception listing tasks directory: {e}") entries = [] for full in entries: - if not storage.isdir(full): + if not await storage.isdir(full): continue # Attempt to read index.json (or latest snapshot) try: entry = full.rstrip("/").split("/")[-1] task = Task(entry) - results.append(task.get_metadata()) + results.append(await task.get_metadata()) except Exception: print(f"Exception getting metadata for task: {entry}") continue + # Sort by created_at descending to match database behavior - results.sort(key=lambda x: x.get("created_at", ""), reverse=True) + def sort_key(x): + created_at = x.get("created_at") + if created_at is None: + # Put items without created_at at the end (will sort last when reverse=True) + return "" + # Handle datetime objects + if isinstance(created_at, datetime): + return created_at.timestamp() + # Handle numeric timestamps + if isinstance(created_at, (int, float)): + return created_at + # Handle string dates (ISO format strings sort correctly) + return str(created_at) + + results.sort(key=sort_key, reverse=True) return results @staticmethod - def list_by_type(task_type: str): + async def list_by_type(task_type: str): """List all tasks of a specific type""" - all_tasks = Task.list_all() + all_tasks = await Task.list_all() return 
[task for task in all_tasks if task.get("type") == task_type] @staticmethod - def list_by_experiment(experiment_id: int): + async def list_by_experiment(experiment_id: int): """List all tasks for a specific experiment""" - all_tasks = Task.list_all() + all_tasks = await Task.list_all() return [task for task in all_tasks if task.get("experiment_id") == experiment_id] @staticmethod - def list_by_type_in_experiment(task_type: str, experiment_id: int): + async def list_by_type_in_experiment(task_type: str, experiment_id: int): """List all tasks of a specific type in a specific experiment""" - all_tasks = Task.list_all() + all_tasks = await Task.list_all() return [ task for task in all_tasks if task.get("type") == task_type and task.get("experiment_id") == experiment_id ] @staticmethod - def get_by_id(task_id: str): + async def get_by_id(task_id: str): """Get a specific task by ID""" try: - task = Task.get(task_id) - return task.get_metadata() + task = await Task.get(task_id) + return await task.get_metadata() except FileNotFoundError: return None @staticmethod - def delete_all(): + async def delete_all(): """Delete all tasks""" - tasks_dir = get_tasks_dir() - if not storage.isdir(tasks_dir): + tasks_dir = await get_tasks_dir() + if not await storage.isdir(tasks_dir): return try: - entries = storage.ls(tasks_dir, detail=False) + entries = await storage.ls(tasks_dir, detail=False) except Exception: entries = [] for full in entries: - if storage.isdir(full): - storage.rm_tree(full) + if await storage.isdir(full): + await storage.rm_tree(full) diff --git a/lab-sdk/src/lab/task_template.py b/lab-sdk/src/lab/task_template.py index 604e62d00..9965ba600 100644 --- a/lab-sdk/src/lab/task_template.py +++ b/lab-sdk/src/lab/task_template.py @@ -7,10 +7,11 @@ class TaskTemplate(BaseLabResource): - def get_dir(self): + async def get_dir(self): """Abstract method on BaseLabResource""" task_id_safe = secure_filename(str(self.id)) - return storage.join(get_task_dir(), task_id_safe) + task_dir = await get_task_dir() + return storage.join(task_dir, task_id_safe) def _default_json(self): # Default metadata - all fields stored directly (not nested in inputs/outputs/config) @@ -24,9 +25,9 @@ def _default_json(self): "updated_at": datetime.utcnow().isoformat(), } - def set_metadata(self, **kwargs): + async def set_metadata(self, **kwargs): """Set task metadata - all fields stored directly in JSON""" - data = self.get_json_data() + data = await self.get_json_data() # Update any provided fields for key, value in kwargs.items(): @@ -35,29 +36,29 @@ def set_metadata(self, **kwargs): # Always update the updated_at timestamp data["updated_at"] = datetime.utcnow().isoformat() - self._set_json_data(data) + await self._set_json_data(data) - def get_metadata(self): + async def get_metadata(self): """Get task metadata""" - data = self.get_json_data() + data = await self.get_json_data() # Fix experiment_id if it's a digit - convert to experiment name if data.get("experiment_id") and str(data["experiment_id"]).isdigit(): - experiment_name = self._get_experiment_name_by_id(data["experiment_id"]) + experiment_name = await self._get_experiment_name_by_id(data["experiment_id"]) if experiment_name: data["experiment_id"] = experiment_name # Save the corrected data back to the file - self._set_json_data(data) + await self._set_json_data(data) return data - def _get_experiment_name_by_id(self, experiment_id): + async def _get_experiment_name_by_id(self, experiment_id): """Get experiment name by ID, return None if not found""" try: from 
.experiment import Experiment # Get all experiments and search for one with matching db_experiment_id - all_experiments = Experiment.get_all() + all_experiments = await Experiment.get_all() for exp_data in all_experiments: if exp_data.get("db_experiment_id") == int(experiment_id): return exp_data.get("name", experiment_id) @@ -68,58 +69,73 @@ def _get_experiment_name_by_id(self, experiment_id): return experiment_id @staticmethod - def list_all(): + async def list_all(): """List all tasks in the filesystem""" results = [] - task_dir = get_task_dir() - if not storage.isdir(task_dir): + task_dir = await get_task_dir() + if not await storage.isdir(task_dir): print(f"Task directory does not exist: {task_dir}") return results try: - entries = storage.ls(task_dir, detail=False) + entries = await storage.ls(task_dir, detail=False) except Exception as e: print(f"Exception listing task directory: {e}") entries = [] for full in entries: - if not storage.isdir(full): + if not await storage.isdir(full): continue # Attempt to read index.json (or latest snapshot) try: entry = full.rstrip("/").split("/")[-1] task = TaskTemplate(entry) - results.append(task.get_metadata()) + results.append(await task.get_metadata()) except Exception: print(f"Exception getting metadata for task: {entry}") continue + # Sort by created_at descending to match database behavior - results.sort(key=lambda x: x.get("created_at", ""), reverse=True) + def sort_key(x): + created_at = x.get("created_at") + if created_at is None: + # Put items without created_at at the end (will sort last when reverse=True) + return "" + # Handle datetime objects + if isinstance(created_at, datetime): + return created_at.timestamp() + # Handle numeric timestamps + if isinstance(created_at, (int, float)): + return created_at + # Handle string dates (ISO format strings sort correctly) + return str(created_at) + + results.sort(key=sort_key, reverse=True) return results @staticmethod - def list_by_type(task_type: str): + async def list_by_type(task_type: str): """List all tasks of a specific type""" - all_tasks = TaskTemplate.list_all() + all_tasks = await TaskTemplate.list_all() return [task for task in all_tasks if task.get("type") == task_type] @staticmethod - def list_by_experiment(experiment_id: int): + async def list_by_experiment(experiment_id: int): """List all tasks for a specific experiment""" - all_tasks = TaskTemplate.list_all() + all_tasks = await TaskTemplate.list_all() return [task for task in all_tasks if task.get("experiment_id") == experiment_id] @staticmethod - def list_by_type_in_experiment(task_type: str, experiment_id: int): + async def list_by_type_in_experiment(task_type: str, experiment_id: int): """List all tasks of a specific type in a specific experiment""" - all_tasks = TaskTemplate.list_all() + all_tasks = await TaskTemplate.list_all() return [ task for task in all_tasks if task.get("type") == task_type and task.get("experiment_id") == experiment_id ] @staticmethod - def list_by_subtype_in_experiment(experiment_id: int, subtype: str, task_type: str = None): + async def list_by_subtype_in_experiment(experiment_id: int, subtype: str, task_type: str = None): """List all tasks for a specific experiment filtered by subtype and optionally by type""" - all_tasks = TaskTemplate.list_all() + all_tasks = await TaskTemplate.list_all() return [ task for task in all_tasks @@ -129,24 +145,24 @@ def list_by_subtype_in_experiment(experiment_id: int, subtype: str, task_type: s ] @staticmethod - def get_by_id(task_id: str): + async def 
get_by_id(task_id: str): """Get a specific task by ID""" try: - task = TaskTemplate.get(task_id) - return task.get_metadata() + task = await TaskTemplate.get(task_id) + return await task.get_metadata() except FileNotFoundError: return None @staticmethod - def delete_all(): + async def delete_all(): """Delete all tasks""" - task_dir = get_task_dir() - if not storage.isdir(task_dir): + task_dir = await get_task_dir() + if not await storage.isdir(task_dir): return try: - entries = storage.ls(task_dir, detail=False) + entries = await storage.ls(task_dir, detail=False) except Exception: entries = [] for full in entries: - if storage.isdir(full): - storage.rm_tree(full) + if await storage.isdir(full): + await storage.rm_tree(full) diff --git a/lab-sdk/tests/conftest.py b/lab-sdk/tests/conftest.py index 332a14300..590eb16a5 100644 --- a/lab-sdk/tests/conftest.py +++ b/lab-sdk/tests/conftest.py @@ -2,6 +2,9 @@ import sys import pytest +# Configure pytest-asyncio +pytest_plugins = ("pytest_asyncio",) + @pytest.fixture(autouse=True) def _isolate_imports_and_home(monkeypatch, tmp_path): diff --git a/lab-sdk/tests/test_dataset.py b/lab-sdk/tests/test_dataset.py index 3c986c2eb..43d811190 100644 --- a/lab-sdk/tests/test_dataset.py +++ b/lab-sdk/tests/test_dataset.py @@ -1,5 +1,6 @@ import os import importlib +import pytest def _fresh(monkeypatch): @@ -8,7 +9,8 @@ def _fresh(monkeypatch): importlib.sys.modules.pop(mod) -def test_dataset_get_dir(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_dataset_get_dir(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -20,11 +22,12 @@ def test_dataset_get_dir(tmp_path, monkeypatch): from lab.dataset import Dataset ds = Dataset("test-dataset") - d = ds.get_dir() + d = await ds.get_dir() assert d.endswith(os.path.join("datasets", "test-dataset")) -def test_dataset_create_and_get(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_dataset_create_and_get(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -36,21 +39,23 @@ def test_dataset_create_and_get(tmp_path, monkeypatch): from lab.dataset import Dataset # Create dataset and verify it exists - ds = Dataset.create("test_dataset") + ds = await Dataset.create("test_dataset") assert ds is not None - assert os.path.isdir(ds.get_dir()) - index_file = os.path.join(ds.get_dir(), "index.json") + ds_dir = await ds.get_dir() + assert os.path.isdir(ds_dir) + index_file = os.path.join(ds_dir, "index.json") assert os.path.isfile(index_file) # Get the dataset and verify its properties - ds2 = Dataset.get("test_dataset") + ds2 = await Dataset.get("test_dataset") assert isinstance(ds2, Dataset) - data = ds2.get_json_data() + data = await ds2.get_json_data() assert data["dataset_id"] == "test_dataset" assert data["location"] == "local" -def test_dataset_default_json(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_dataset_default_json(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -61,8 +66,8 @@ def test_dataset_default_json(tmp_path, monkeypatch): from lab.dataset import Dataset - ds = Dataset.create("test_dataset_default") - data = ds.get_json_data() + ds = await Dataset.create("test_dataset_default") + data = await ds.get_json_data() assert data["dataset_id"] == "test_dataset_default" assert data["location"] == "local" assert data["description"] == "" @@ -70,7 +75,8 @@ def test_dataset_default_json(tmp_path, monkeypatch): assert 
data["json_data"] == {} -def test_dataset_set_metadata(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_dataset_set_metadata(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -81,30 +87,31 @@ def test_dataset_set_metadata(tmp_path, monkeypatch): from lab.dataset import Dataset - ds = Dataset.create("test_dataset_metadata") + ds = await Dataset.create("test_dataset_metadata") # Test setting individual metadata fields - ds.set_metadata(location="remote", description="Test dataset", size=1000) - data = ds.get_json_data() + await ds.set_metadata(location="remote", description="Test dataset", size=1000) + data = await ds.get_json_data() assert data["location"] == "remote" assert data["description"] == "Test dataset" assert data["size"] == 1000 # Test setting json_data - ds.set_metadata(json_data={"key1": "value1", "key2": "value2"}) - data = ds.get_json_data() + await ds.set_metadata(json_data={"key1": "value1", "key2": "value2"}) + data = await ds.get_json_data() assert data["json_data"]["key1"] == "value1" assert data["json_data"]["key2"] == "value2" # Test merging json_data (shallow merge) - ds.set_metadata(json_data={"key2": "updated", "key3": "value3"}) - data = ds.get_json_data() + await ds.set_metadata(json_data={"key2": "updated", "key3": "value3"}) + data = await ds.get_json_data() assert data["json_data"]["key1"] == "value1" # Preserved assert data["json_data"]["key2"] == "updated" # Updated assert data["json_data"]["key3"] == "value3" # New key -def test_dataset_get_metadata(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_dataset_get_metadata(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -115,15 +122,16 @@ def test_dataset_get_metadata(tmp_path, monkeypatch): from lab.dataset import Dataset - ds = Dataset.create("test_dataset_get") - ds.set_metadata(description="My dataset", size=500) - metadata = ds.get_metadata() + ds = await Dataset.create("test_dataset_get") + await ds.set_metadata(description="My dataset", size=500) + metadata = await ds.get_metadata() assert metadata["dataset_id"] == "test_dataset_get" assert metadata["description"] == "My dataset" assert metadata["size"] == 500 -def test_dataset_list_all(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_dataset_list_all(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -135,13 +143,13 @@ def test_dataset_list_all(tmp_path, monkeypatch): from lab.dataset import Dataset # Create multiple datasets - ds1 = Dataset.create("dataset1") - ds1.set_metadata(description="First dataset") - ds2 = Dataset.create("dataset2") - ds2.set_metadata(description="Second dataset") + ds1 = await Dataset.create("dataset1") + await ds1.set_metadata(description="First dataset") + ds2 = await Dataset.create("dataset2") + await ds2.set_metadata(description="Second dataset") # List all datasets - all_datasets = Dataset.list_all() + all_datasets = await Dataset.list_all() assert isinstance(all_datasets, list) assert len(all_datasets) >= 2 @@ -151,7 +159,8 @@ def test_dataset_list_all(tmp_path, monkeypatch): assert "dataset2" in dataset_ids -def test_dataset_list_all_empty_dir(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_dataset_list_all_empty_dir(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -163,12 +172,13 @@ def test_dataset_list_all_empty_dir(tmp_path, monkeypatch): from lab.dataset 
import Dataset # List all datasets when none exist - all_datasets = Dataset.list_all() + all_datasets = await Dataset.list_all() assert isinstance(all_datasets, list) assert len(all_datasets) == 0 -def test_dataset_secure_filename(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_dataset_secure_filename(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -181,9 +191,9 @@ def test_dataset_secure_filename(tmp_path, monkeypatch): # Test that secure_filename sanitizes the dataset ID # secure_filename converts "/" to "_" and ".." to "__" - ds = Dataset.create("test/../dataset") + ds = await Dataset.create("test/../dataset") # The directory should be sanitized - dir_path = ds.get_dir() + dir_path = await ds.get_dir() # Should not contain actual path traversal (../ as a path component) # secure_filename converts "test/../dataset" to "test_.._dataset" # which is safe because ".." is part of the filename, not a path separator diff --git a/lab-sdk/tests/test_dirs.py b/lab-sdk/tests/test_dirs.py index 2e5807717..d3996071a 100644 --- a/lab-sdk/tests/test_dirs.py +++ b/lab-sdk/tests/test_dirs.py @@ -1,5 +1,6 @@ import os import importlib +import pytest def _fresh_import_dirs(monkeypatch): @@ -11,7 +12,8 @@ def _fresh_import_dirs(monkeypatch): return importlib.import_module("lab.dirs") -def test_dirs_structure_created(monkeypatch, tmp_path): +@pytest.mark.asyncio +async def test_dirs_structure_created(monkeypatch, tmp_path): home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" home.mkdir() @@ -22,12 +24,12 @@ def test_dirs_structure_created(monkeypatch, tmp_path): dirs = _fresh_import_dirs(monkeypatch) # Key directories exist - assert os.path.isdir(dirs.get_experiments_dir()) - assert os.path.isdir(dirs.get_jobs_dir()) - assert os.path.isdir(dirs.get_models_dir()) - assert os.path.isdir(dirs.get_datasets_dir()) - assert os.path.isdir(dirs.get_temp_dir()) - assert os.path.isdir(dirs.get_prompt_templates_dir()) - assert os.path.isdir(dirs.get_tools_dir()) - assert os.path.isdir(dirs.get_batched_prompts_dir()) + assert os.path.isdir(await dirs.get_experiments_dir()) + assert os.path.isdir(await dirs.get_jobs_dir()) + assert os.path.isdir(await dirs.get_models_dir()) + assert os.path.isdir(await dirs.get_datasets_dir()) + assert os.path.isdir(await dirs.get_temp_dir()) + assert os.path.isdir(await dirs.get_prompt_templates_dir()) + assert os.path.isdir(await dirs.get_tools_dir()) + assert os.path.isdir(await dirs.get_batched_prompts_dir()) assert os.path.isdir(dirs.get_galleries_cache_dir()) diff --git a/lab-sdk/tests/test_dirs_workspace.py b/lab-sdk/tests/test_dirs_workspace.py index b58a001eb..0cc8efb98 100644 --- a/lab-sdk/tests/test_dirs_workspace.py +++ b/lab-sdk/tests/test_dirs_workspace.py @@ -1,5 +1,6 @@ import os import importlib +import pytest def test_default_dirs_created(monkeypatch, tmp_path): @@ -21,7 +22,8 @@ def test_default_dirs_created(monkeypatch, tmp_path): assert dirs_workspace.HOME_DIR.startswith(str(tmp_path)) -def test_env_override_existing_paths(monkeypatch, tmp_path): +@pytest.mark.asyncio +async def test_env_override_existing_paths(monkeypatch, tmp_path): # Create explicit dirs and set env home = tmp_path / "custom_home" ws = tmp_path / "custom_ws" @@ -42,10 +44,13 @@ def test_env_override_existing_paths(monkeypatch, tmp_path): from lab import dirs as dirs_workspace assert dirs_workspace.HOME_DIR == str(home) - assert dirs_workspace.WORKSPACE_DIR == str(ws) + # WORKSPACE_DIR is a legacy placeholder, use 
get_workspace_dir() instead + workspace_dir = await dirs_workspace.get_workspace_dir() + assert workspace_dir == str(ws) -def test_org_scoped_workspace_dir(monkeypatch, tmp_path): +@pytest.mark.asyncio +async def test_org_scoped_workspace_dir(monkeypatch, tmp_path): # Ensure no explicit WS override and set a custom home monkeypatch.delenv("TFL_WORKSPACE_DIR", raising=False) home = tmp_path / "tfl_home" @@ -60,14 +65,14 @@ def test_org_scoped_workspace_dir(monkeypatch, tmp_path): # Set organization id → should route to org-scoped workspace dirs_workspace.set_organization_id("acme") - ws = dirs_workspace.get_workspace_dir() + ws = await dirs_workspace.get_workspace_dir() expected = os.path.join(dirs_workspace.HOME_DIR, "orgs", "acme", "workspace") assert ws == expected assert os.path.isdir(ws) # Reset organization_id → should fall back to default single-tenant path dirs_workspace.set_organization_id(None) - ws_default = dirs_workspace.get_workspace_dir() + ws_default = await dirs_workspace.get_workspace_dir() expected_default = os.path.join(dirs_workspace.HOME_DIR, "workspace") assert ws_default == expected_default assert os.path.isdir(ws_default) diff --git a/lab-sdk/tests/test_experiment.py b/lab-sdk/tests/test_experiment.py index b63ce4cba..d868e1d7c 100644 --- a/lab-sdk/tests/test_experiment.py +++ b/lab-sdk/tests/test_experiment.py @@ -1,6 +1,7 @@ import os import json import importlib +import pytest def _fresh(monkeypatch): @@ -9,7 +10,8 @@ def _fresh(monkeypatch): importlib.sys.modules.pop(mod) -def test_experiment_dir_and_jobs_index(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_experiment_dir_and_jobs_index(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -21,8 +23,8 @@ def test_experiment_dir_and_jobs_index(tmp_path, monkeypatch): from lab.experiment import Experiment from lab.job import Job - exp = Experiment.create("exp1") - exp_dir = exp.get_dir() + exp = await Experiment.create("exp1") + exp_dir = await exp.get_dir() assert exp_dir.endswith(os.path.join("experiments", "exp1")) assert os.path.isdir(exp_dir) @@ -35,16 +37,26 @@ def test_experiment_dir_and_jobs_index(tmp_path, monkeypatch): assert "TRAIN" in data["index"] # Create two jobs and assign to experiment - j1 = Job.create("10") - j1.set_experiment("exp1", sync_rebuild=True) - j2 = Job.create("11") - j2.set_experiment("exp1", sync_rebuild=True) + j1 = await Job.create("10") + await j1.set_experiment("exp1", sync_rebuild=True) + j2 = await Job.create("11") + await j2.set_experiment("exp1", sync_rebuild=True) - all_jobs = exp._get_all_jobs() - assert set(all_jobs) >= {"10", "11"} + # Manually trigger rebuild to ensure jobs are in the index + from lab.dirs import get_workspace_dir + workspace = await get_workspace_dir() + await exp.rebuild_jobs_index(workspace_dir=workspace) -def test_get_jobs_filters(tmp_path, monkeypatch): + all_jobs = await exp._get_all_jobs() + # Jobs should now be visible after rebuild + job_ids = set(all_jobs) + assert "10" in job_ids + assert "11" in job_ids + + +@pytest.mark.asyncio +async def test_get_jobs_filters(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -56,25 +68,26 @@ def test_get_jobs_filters(tmp_path, monkeypatch): from lab.experiment import Experiment from lab.job import Job - exp = Experiment.create("exp2") + exp = await Experiment.create("exp2") - j1 = Job.create("21") - j1.set_experiment("exp2", sync_rebuild=True) - j1.update_status("RUNNING") + j1 = await 
Job.create("21") + await j1.set_experiment("exp2", sync_rebuild=True) + await j1.update_status("RUNNING") - j2 = Job.create("22") - j2.set_experiment("exp2", sync_rebuild=True) - j2.update_status("NOT_STARTED") + j2 = await Job.create("22") + await j2.set_experiment("exp2", sync_rebuild=True) + await j2.update_status("NOT_STARTED") # get all - jobs = exp.get_jobs() + jobs = await exp.get_jobs() assert isinstance(jobs, list) # filter by status - running = exp.get_jobs(status="RUNNING") + running = await exp.get_jobs(status="RUNNING") assert all(j.get("status") == "RUNNING" for j in running) -def test_experiment_create_and_get(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_experiment_create_and_get(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -86,16 +99,16 @@ def test_experiment_create_and_get(tmp_path, monkeypatch): from lab.experiment import Experiment # Create experiment and verify it exists - exp = Experiment.create("test_experiment") + exp = await Experiment.create("test_experiment") assert exp is not None # Get the experiment and verify its properties - exp_data = exp.get_json_data() + exp_data = await exp.get_json_data() assert exp_data["name"] == "test_experiment" # Try to get an experiment that doesn't exist try: - nonexistent = Experiment.get("999999") + nonexistent = await Experiment.get("999999") # If we get here, the experiment should be None or indicate it doesn't exist assert nonexistent is None except Exception: @@ -103,7 +116,8 @@ def test_experiment_create_and_get(tmp_path, monkeypatch): pass -def test_experiment_config_validation(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_experiment_config_validation(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -115,12 +129,12 @@ def test_experiment_config_validation(tmp_path, monkeypatch): from lab.experiment import Experiment # Test creating experiment with valid config - exp = Experiment.create_with_config("test_experiment_config", {"key": "value"}) + exp = await Experiment.create_with_config("test_experiment_config", {"key": "value"}) assert exp is not None # Test creating experiment with invalid config (string instead of dict) try: - Experiment.create_with_config("test_experiment_invalid", "not_a_dict") + await Experiment.create_with_config("test_experiment_invalid", "not_a_dict") assert False, "Should have raised an exception for invalid config" except TypeError: # Expected behavior - should raise TypeError for non-dict config diff --git a/lab-sdk/tests/test_job_and_labresource.py b/lab-sdk/tests/test_job_and_labresource.py index cf69ce892..7eed0bb67 100644 --- a/lab-sdk/tests/test_job_and_labresource.py +++ b/lab-sdk/tests/test_job_and_labresource.py @@ -1,9 +1,11 @@ import os import json import importlib +import pytest -def test_baselabresource_create_get(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_baselabresource_create_get(tmp_path, monkeypatch): # Create a simple subclass inline by importing Job which uses BaseLabResource for mod in ["lab.job", "lab.dirs"]: if mod in importlib.sys.modules: @@ -18,16 +20,18 @@ def test_baselabresource_create_get(tmp_path, monkeypatch): from lab.job import Job - job = Job.create("123") - assert os.path.isdir(job.get_dir()) - index_file = os.path.join(job.get_dir(), "index.json") + job = await Job.create("123") + job_dir = await job.get_dir() + assert os.path.isdir(job_dir) + index_file = os.path.join(job_dir, "index.json") assert 
os.path.isfile(index_file) - job2 = Job.get("123") + job2 = await Job.get("123") assert isinstance(job2, Job) -def test_job_default_json_and_updates(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_job_default_json_and_updates(tmp_path, monkeypatch): for mod in ["lab.job", "lab.dirs"]: if mod in importlib.sys.modules: importlib.sys.modules.pop(mod) @@ -41,26 +45,28 @@ def test_job_default_json_and_updates(tmp_path, monkeypatch): from lab.job import Job - job = Job.create("1") + job = await Job.create("1") # On create, defaults are written to index.json - data_path = os.path.join(job.get_dir(), "index.json") + job_dir = await job.get_dir() + data_path = os.path.join(job_dir, "index.json") with open(data_path) as f: data = json.load(f) assert data["status"] == "NOT_STARTED" assert data["progress"] == 0 - job.update_status("RUNNING") - job.update_progress(50) - job.update_job_data_field("k", "v") + await job.update_status("RUNNING") + await job.update_progress(50) + await job.update_job_data_field("k", "v") # After updates, read using BaseLabResource helper (prefers latest snapshot) - data = job.get_json_data() + data = await job.get_json_data() assert data["status"] == "RUNNING" assert data["progress"] == 50 assert data["job_data"]["k"] == "v" -def test_job_data_field_updates(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_job_data_field_updates(tmp_path, monkeypatch): for mod in ["lab.job", "lab.dirs"]: if mod in importlib.sys.modules: importlib.sys.modules.pop(mod) @@ -74,14 +80,14 @@ def test_job_data_field_updates(tmp_path, monkeypatch): from lab.job import Job - job = Job.create("2") + job = await Job.create("2") # Test updating job data fields directly - job.update_job_data_field("completion_status", "success") - job.update_job_data_field("completion_details", "ok") - job.update_job_data_field("score", {"acc": 1}) + await job.update_job_data_field("completion_status", "success") + await job.update_job_data_field("completion_details", "ok") + await job.update_job_data_field("score", {"acc": 1}) - data = job.get_json_data() + data = await job.get_json_data() assert data["job_data"]["completion_status"] == "success" assert data["job_data"]["completion_details"] == "ok" assert data["job_data"]["score"] == {"acc": 1} diff --git a/lab-sdk/tests/test_lab_facade.py b/lab-sdk/tests/test_lab_facade.py index 6e93a8ad6..8370f3548 100644 --- a/lab-sdk/tests/test_lab_facade.py +++ b/lab-sdk/tests/test_lab_facade.py @@ -1,4 +1,5 @@ import os +import asyncio import json import importlib @@ -33,10 +34,10 @@ def test_lab_init(tmp_path, monkeypatch): # Verify experiment and job are initialized assert lab._experiment is not None assert lab._job is not None - assert lab._job.get_status() == "RUNNING" + assert asyncio.run(lab._job.get_status()) == "RUNNING" # Verify job data has start_time - job_data = lab._job.get_job_data() + job_data = lab.get_job_data() assert "start_time" in job_data @@ -53,8 +54,8 @@ def test_lab_init_with_existing_job(tmp_path, monkeypatch): from lab.experiment import Experiment # Create an experiment and job first - exp = Experiment.create("test_exp") - job = exp.create_job() + exp = asyncio.run(Experiment.create("test_exp")) + job = asyncio.run(exp.create_job()) job_id = str(job.id) # Set environment variable to use existing job @@ -65,7 +66,7 @@ def test_lab_init_with_existing_job(tmp_path, monkeypatch): # Verify it's using the existing job assert lab._job.id == job_id - assert lab._job.get_status() == "RUNNING" + assert 
asyncio.run(lab._job.get_status()) == "RUNNING" def test_lab_init_with_nonexistent_job(tmp_path, monkeypatch): @@ -105,7 +106,7 @@ def test_lab_set_config(tmp_path, monkeypatch): config = {"epochs": 10, "learning_rate": 0.001} lab.set_config(config) - job_data = lab._job.get_job_data() + job_data = lab.get_job_data() assert job_data["epochs"] == 10 assert job_data["learning_rate"] == 0.001 assert job_data["experiment_name"] == "test_exp" @@ -131,7 +132,7 @@ def test_lab_set_config_merges_existing(tmp_path, monkeypatch): # Update with new config lab.set_config({"epochs": 20}) - job_data = lab._job.get_job_data() + job_data = lab.get_job_data() assert job_data["epochs"] == 20 # Updated assert job_data["batch_size"] == 32 # Preserved @@ -153,7 +154,7 @@ def test_lab_log(tmp_path, monkeypatch): lab.log("Test message") # Verify log was written to file - log_path = lab._job.get_log_path() + log_path = asyncio.run(lab._job.get_log_path()) assert os.path.exists(log_path) with open(log_path, "r") as f: content = f.read() @@ -175,10 +176,10 @@ def test_lab_update_progress(tmp_path, monkeypatch): lab.init(experiment_id="test_exp") lab.update_progress(50) - assert lab._job.get_progress() == 50 + assert asyncio.run(lab._job.get_progress()) == 50 lab.update_progress(100) - assert lab._job.get_progress() == 100 + assert asyncio.run(lab._job.get_progress()) == 100 def test_lab_finish(tmp_path, monkeypatch): @@ -197,9 +198,9 @@ def test_lab_finish(tmp_path, monkeypatch): lab.finish(message="Job completed", score={"accuracy": 0.95}) - assert lab._job.get_status() == "COMPLETE" - assert lab._job.get_progress() == 100 - job_data = lab._job.get_job_data() + assert asyncio.run(lab._job.get_status()) == "COMPLETE" + assert asyncio.run(lab._job.get_progress()) == 100 + job_data = lab.get_job_data() assert job_data["completion_status"] == "success" assert job_data["completion_details"] == "Job completed" assert job_data["score"] == {"accuracy": 0.95} @@ -225,7 +226,7 @@ def test_lab_finish_with_paths(tmp_path, monkeypatch): plot_data_path="/path/to/plot", ) - job_data = lab._job.get_job_data() + job_data = lab.get_job_data() assert job_data["additional_output_path"] == "/path/to/output" assert job_data["plot_data_path"] == "/path/to/plot" @@ -246,8 +247,8 @@ def test_lab_error(tmp_path, monkeypatch): lab.error(message="Job failed") - assert lab._job.get_status() == "COMPLETE" - job_data = lab._job.get_job_data() + assert asyncio.run(lab._job.get_status()) == "COMPLETE" + job_data = lab.get_job_data() assert job_data["completion_status"] == "failed" assert job_data["completion_details"] == "Job failed" assert job_data["status"] == "FAILED" @@ -279,7 +280,7 @@ def test_lab_save_artifact_file(tmp_path, monkeypatch): assert f.read() == "test content" # Verify artifact is tracked in job_data - job_data = lab._job.get_job_data() + job_data = lab.get_job_data() assert "artifacts" in job_data assert dest_path in job_data["artifacts"] @@ -386,7 +387,7 @@ def test_lab_save_checkpoint(tmp_path, monkeypatch): assert os.path.isfile(dest_path) # Verify checkpoint is tracked in job_data - job_data = lab._job.get_job_data() + job_data = lab.get_job_data() assert "checkpoints" in job_data assert dest_path in job_data["checkpoints"] assert job_data["latest_checkpoint"] == dest_path @@ -469,15 +470,15 @@ def to_json(self, path_or_buf, orient, lines): # Verify dataset metadata was created from lab.dataset import Dataset - ds = Dataset.get("test_dataset") - metadata = ds.get_metadata() + ds = asyncio.run(Dataset.get("test_dataset")) + 
metadata = asyncio.run(ds.get_metadata()) assert metadata["dataset_id"] == "test_dataset" assert metadata["location"] == "local" assert metadata["json_data"]["generated"] is True assert metadata["json_data"]["sample_count"] == 2 # Verify dataset is tracked in job_data - job_data = lab._job.get_job_data() + job_data = lab.get_job_data() assert job_data["dataset_id"] == "test_dataset" @@ -519,8 +520,8 @@ def to_json(self, path_or_buf, orient, lines): from lab.dataset import Dataset - ds = Dataset.get("test_dataset_meta") - metadata = ds.get_metadata() + ds = asyncio.run(Dataset.get("test_dataset_meta")) + metadata = asyncio.run(ds.get_metadata()) assert metadata["json_data"]["description"] == "Test dataset" assert metadata["json_data"]["source"] == "synthetic" @@ -582,7 +583,7 @@ def test_lab_save_dataset_duplicate_error(tmp_path, monkeypatch): lab.init(experiment_id="test_exp") # Create dataset first - Dataset.create("existing_dataset") + asyncio.run(Dataset.create("existing_dataset")) class MockDataFrame: def __init__(self, data): @@ -591,9 +592,15 @@ def __init__(self, data): def __len__(self): return len(self.data) - def to_json(self, path, orient, lines): - with open(path, "w") as f: - json.dump(self.data, f) + def to_json(self, path_or_buf, orient, lines): + # Handle both file-like objects and path strings (like real pandas) + if hasattr(path_or_buf, "write"): + # It's a file-like object + json.dump(self.data, path_or_buf) + else: + # It's a path string + with open(path_or_buf, "w") as f: + json.dump(self.data, f) df = MockDataFrame([{"a": 1}]) @@ -756,7 +763,7 @@ def test_lab_capture_wandb_url(tmp_path, monkeypatch): wandb_url = "https://wandb.ai/test/run-123" lab.capture_wandb_url(wandb_url) - job_data = lab._job.get_job_data() + job_data = lab.get_job_data() assert job_data["wandb_run_url"] == wandb_url @@ -805,12 +812,12 @@ def test_lab_list_models(tmp_path, monkeypatch): from lab.model import Model # Create a test model - model1 = Model.create("test_model_1") - model1.set_metadata(name="Test Model 1") + model1 = asyncio.run(Model.create("test_model_1")) + asyncio.run(model1.set_metadata(name="Test Model 1")) # Create another test model - model2 = Model.create("test_model_2") - model2.set_metadata(name="Test Model 2") + model2 = asyncio.run(Model.create("test_model_2")) + asyncio.run(model2.set_metadata(name="Test Model 2")) lab = Lab() # list_models doesn't require initialization @@ -835,15 +842,15 @@ def test_lab_get_model(tmp_path, monkeypatch): from lab.model import Model # Create a test model - model = Model.create("test_model_get") - model.set_metadata(name="Test Model") + model = asyncio.run(Model.create("test_model_get")) + asyncio.run(model.set_metadata(name="Test Model")) lab = Lab() # get_model doesn't require initialization retrieved_model = lab.get_model("test_model_get") assert retrieved_model.id == "test_model_get" - metadata = retrieved_model.get_metadata() + metadata = asyncio.run(retrieved_model.get_metadata()) assert metadata["name"] == "Test Model" @@ -860,8 +867,8 @@ def test_lab_get_model_path(tmp_path, monkeypatch): from lab.model import Model # Create a test model - model = Model.create("test_model_path") - expected_path = model.get_dir() + model = asyncio.run(Model.create("test_model_path")) + expected_path = asyncio.run(model.get_dir()) lab = Lab() # get_model_path doesn't require initialization diff --git a/lab-sdk/tests/test_model.py b/lab-sdk/tests/test_model.py index 41c03ede7..97de4f624 100644 --- a/lab-sdk/tests/test_model.py +++ 
b/lab-sdk/tests/test_model.py @@ -1,8 +1,10 @@ import os import importlib +import pytest -def test_model_get_dir(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_model_get_dir(tmp_path, monkeypatch): for mod in ["lab.model", "lab.dirs"]: if mod in importlib.sys.modules: importlib.sys.modules.pop(mod) @@ -17,5 +19,5 @@ def test_model_get_dir(tmp_path, monkeypatch): from lab.model import Model m = Model("mixtral-8x7b") - d = m.get_dir() + d = await m.get_dir() assert d.endswith(os.path.join("models", "mixtral-8x7b")) diff --git a/lab-sdk/tests/test_task.py b/lab-sdk/tests/test_task.py index ae556e6ae..02281a1d1 100644 --- a/lab-sdk/tests/test_task.py +++ b/lab-sdk/tests/test_task.py @@ -1,5 +1,6 @@ import os import importlib +import pytest def _fresh(monkeypatch): @@ -8,7 +9,8 @@ def _fresh(monkeypatch): importlib.sys.modules.pop(mod) -def test_task_get_dir(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_task_get_dir(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -20,11 +22,12 @@ def test_task_get_dir(tmp_path, monkeypatch): from lab.task import Task task = Task("test-task") - d = task.get_dir() + d = await task.get_dir() assert d.endswith(os.path.join("tasks", "test-task")) -def test_task_create_and_get(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_task_create_and_get(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -36,20 +39,22 @@ def test_task_create_and_get(tmp_path, monkeypatch): from lab.task import Task # Create task and verify it exists - task = Task.create("test_task") + task = await Task.create("test_task") assert task is not None - assert os.path.isdir(task.get_dir()) - index_file = os.path.join(task.get_dir(), "index.json") + task_dir = await task.get_dir() + assert os.path.isdir(task_dir) + index_file = os.path.join(task_dir, "index.json") assert os.path.isfile(index_file) # Get the task and verify its properties - task2 = Task.get("test_task") + task2 = await Task.get("test_task") assert isinstance(task2, Task) - data = task2.get_json_data() + data = await task2.get_json_data() assert data["id"] == "test_task" -def test_task_default_json(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_task_default_json(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -60,8 +65,8 @@ def test_task_default_json(tmp_path, monkeypatch): from lab.task import Task - task = Task.create("test_task_default") - data = task.get_json_data() + task = await Task.create("test_task_default") + data = await task.get_json_data() assert data["id"] == "test_task_default" assert data["name"] == "" assert data["type"] == "" @@ -75,7 +80,8 @@ def test_task_default_json(tmp_path, monkeypatch): assert "updated_at" in data -def test_task_set_metadata(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_task_set_metadata(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -86,10 +92,10 @@ def test_task_set_metadata(tmp_path, monkeypatch): from lab.task import Task - task = Task.create("test_task_metadata") + task = await Task.create("test_task_metadata") # Test setting all metadata fields - task.set_metadata( + await task.set_metadata( name="Test Task", type="training", inputs={"input1": "value1"}, @@ -99,7 +105,7 @@ def test_task_set_metadata(tmp_path, monkeypatch): experiment_id="exp1", remote_task=True, ) - data = task.get_json_data() + data 
= await task.get_json_data() assert data["name"] == "Test Task" assert data["type"] == "training" assert data["inputs"] == {"input1": "value1"} @@ -111,15 +117,16 @@ def test_task_set_metadata(tmp_path, monkeypatch): assert "updated_at" in data # Test partial updates - task.set_metadata(name="Updated Task", type="evaluation") - data = task.get_json_data() + await task.set_metadata(name="Updated Task", type="evaluation") + data = await task.get_json_data() assert data["name"] == "Updated Task" assert data["type"] == "evaluation" # Other fields should remain unchanged assert data["plugin"] == "test_plugin" -def test_task_get_metadata(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_task_get_metadata(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -130,15 +137,16 @@ def test_task_get_metadata(tmp_path, monkeypatch): from lab.task import Task - task = Task.create("test_task_get") - task.set_metadata(name="My Task", type="training") - metadata = task.get_metadata() + task = await Task.create("test_task_get") + await task.set_metadata(name="My Task", type="training") + metadata = await task.get_metadata() assert metadata["id"] == "test_task_get" assert metadata["name"] == "My Task" assert metadata["type"] == "training" -def test_task_list_all(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_task_list_all(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -150,13 +158,13 @@ def test_task_list_all(tmp_path, monkeypatch): from lab.task import Task # Create multiple tasks - task1 = Task.create("task1") - task1.set_metadata(name="Task 1", type="training") - task2 = Task.create("task2") - task2.set_metadata(name="Task 2", type="evaluation") + task1 = await Task.create("task1") + await task1.set_metadata(name="Task 1", type="training") + task2 = await Task.create("task2") + await task2.set_metadata(name="Task 2", type="evaluation") # List all tasks - all_tasks = Task.list_all() + all_tasks = await Task.list_all() assert isinstance(all_tasks, list) assert len(all_tasks) >= 2 @@ -171,7 +179,8 @@ def test_task_list_all(tmp_path, monkeypatch): assert created_dates == sorted(created_dates, reverse=True) -def test_task_list_by_type(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_task_list_by_type(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -183,24 +192,25 @@ def test_task_list_by_type(tmp_path, monkeypatch): from lab.task import Task # Create tasks with different types - task1 = Task.create("task_training") - task1.set_metadata(type="training") - task2 = Task.create("task_eval") - task2.set_metadata(type="evaluation") - task3 = Task.create("task_training2") - task3.set_metadata(type="training") + task1 = await Task.create("task_training") + await task1.set_metadata(type="training") + task2 = await Task.create("task_eval") + await task2.set_metadata(type="evaluation") + task3 = await Task.create("task_training2") + await task3.set_metadata(type="training") # List tasks by type - training_tasks = Task.list_by_type("training") + training_tasks = await Task.list_by_type("training") assert len(training_tasks) >= 2 assert all(t["type"] == "training" for t in training_tasks) - eval_tasks = Task.list_by_type("evaluation") + eval_tasks = await Task.list_by_type("evaluation") assert len(eval_tasks) >= 1 assert all(t["type"] == "evaluation" for t in eval_tasks) -def test_task_list_by_experiment(tmp_path, 
monkeypatch): +@pytest.mark.asyncio +async def test_task_list_by_experiment(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -212,24 +222,25 @@ def test_task_list_by_experiment(tmp_path, monkeypatch): from lab.task import Task # Create tasks with different experiment IDs - task1 = Task.create("task_exp1_1") - task1.set_metadata(experiment_id=1) - task2 = Task.create("task_exp1_2") - task2.set_metadata(experiment_id=1) - task3 = Task.create("task_exp2_1") - task3.set_metadata(experiment_id=2) + task1 = await Task.create("task_exp1_1") + await task1.set_metadata(experiment_id=1) + task2 = await Task.create("task_exp1_2") + await task2.set_metadata(experiment_id=1) + task3 = await Task.create("task_exp2_1") + await task3.set_metadata(experiment_id=2) # List tasks by experiment - exp1_tasks = Task.list_by_experiment(1) + exp1_tasks = await Task.list_by_experiment(1) assert len(exp1_tasks) >= 2 assert all(t["experiment_id"] == 1 for t in exp1_tasks) - exp2_tasks = Task.list_by_experiment(2) + exp2_tasks = await Task.list_by_experiment(2) assert len(exp2_tasks) >= 1 assert all(t["experiment_id"] == 2 for t in exp2_tasks) -def test_task_list_by_type_in_experiment(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_task_list_by_type_in_experiment(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -241,24 +252,25 @@ def test_task_list_by_type_in_experiment(tmp_path, monkeypatch): from lab.task import Task # Create tasks with different combinations - task1 = Task.create("task_exp1_training") - task1.set_metadata(type="training", experiment_id=1) - task2 = Task.create("task_exp1_eval") - task2.set_metadata(type="evaluation", experiment_id=1) - task3 = Task.create("task_exp2_training") - task3.set_metadata(type="training", experiment_id=2) + task1 = await Task.create("task_exp1_training") + await task1.set_metadata(type="training", experiment_id=1) + task2 = await Task.create("task_exp1_eval") + await task2.set_metadata(type="evaluation", experiment_id=1) + task3 = await Task.create("task_exp2_training") + await task3.set_metadata(type="training", experiment_id=2) # List tasks by type and experiment - exp1_training = Task.list_by_type_in_experiment("training", 1) + exp1_training = await Task.list_by_type_in_experiment("training", 1) assert len(exp1_training) >= 1 assert all(t["type"] == "training" and t["experiment_id"] == 1 for t in exp1_training) - exp2_training = Task.list_by_type_in_experiment("training", 2) + exp2_training = await Task.list_by_type_in_experiment("training", 2) assert len(exp2_training) >= 1 assert all(t["type"] == "training" and t["experiment_id"] == 2 for t in exp2_training) -def test_task_get_by_id(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_task_get_by_id(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -270,21 +282,22 @@ def test_task_get_by_id(tmp_path, monkeypatch): from lab.task import Task # Create a task - task = Task.create("test_task_get_by_id") - task.set_metadata(name="Test Task") + task = await Task.create("test_task_get_by_id") + await task.set_metadata(name="Test Task") # Get task by ID - task_data = Task.get_by_id("test_task_get_by_id") + task_data = await Task.get_by_id("test_task_get_by_id") assert task_data is not None assert task_data["id"] == "test_task_get_by_id" assert task_data["name"] == "Test Task" # Get non-existent task - task_data_none = 
Task.get_by_id("non_existent_task") + task_data_none = await Task.get_by_id("non_existent_task") assert task_data_none is None -def test_task_delete_all(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_task_delete_all(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -296,24 +309,26 @@ def test_task_delete_all(tmp_path, monkeypatch): from lab.task import Task # Create some tasks - Task.create("task_to_delete_1") - Task.create("task_to_delete_2") + await Task.create("task_to_delete_1") + await Task.create("task_to_delete_2") # Verify they exist - assert len(Task.list_all()) >= 2 + all_tasks = await Task.list_all() + assert len(all_tasks) >= 2 # Delete all tasks - Task.delete_all() + await Task.delete_all() # Verify tasks directory is empty or doesn't exist from lab.dirs import get_tasks_dir - tasks_dir = get_tasks_dir() + tasks_dir = await get_tasks_dir() if os.path.exists(tasks_dir): assert len(os.listdir(tasks_dir)) == 0 -def test_task_list_all_empty_dir(tmp_path, monkeypatch): +@pytest.mark.asyncio +async def test_task_list_all_empty_dir(tmp_path, monkeypatch): _fresh(monkeypatch) home = tmp_path / ".tfl_home" ws = tmp_path / ".tfl_ws" @@ -325,6 +340,6 @@ def test_task_list_all_empty_dir(tmp_path, monkeypatch): from lab.task import Task # List all tasks when none exist - all_tasks = Task.list_all() + all_tasks = await Task.list_all() assert isinstance(all_tasks, list) # Should return empty list, not raise error diff --git a/package.json b/package.json index 52f027e5e..e67e9cd3c 100644 --- a/package.json +++ b/package.json @@ -35,6 +35,7 @@ "test": "jest", "format:check": "prettier --check \"src/**/*.{js,jsx,ts,tsx,json,css,scss,md}\"", "format": "prettier --write \"src/**/*.{js,jsx,ts,tsx,json,css,scss,md}\"", + "format:api": "prettier --write \"api/transformerlab/plugins/**/*.json\"", "api:install": "cd api && ./install.sh", "api:start": "cd api && ./run.sh" }, @@ -204,4 +205,4 @@ "node": ">=14.x", "npm": ">=7.x" } -} \ No newline at end of file +}
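Editorial aside, not part of the patch: a minimal sketch of the test pattern this change rolls out across the lab-sdk suite; pytest-asyncio marks an async test, the workspace env vars are pointed at tmp_path, and every SDK call that became a coroutine is awaited. TFL_WORKSPACE_DIR appears in the diff; the TFL_HOME_DIR variable name and the "demo_task" id are assumptions for illustration.

import pytest


@pytest.mark.asyncio
async def test_task_roundtrip(tmp_path, monkeypatch):
    # Point the SDK at throwaway directories, mirroring the existing fixtures.
    home = tmp_path / ".tfl_home"
    ws = tmp_path / ".tfl_ws"
    home.mkdir()
    ws.mkdir()
    monkeypatch.setenv("TFL_HOME_DIR", str(home))  # assumed env var name
    monkeypatch.setenv("TFL_WORKSPACE_DIR", str(ws))

    from lab.task import Task

    # Create, update, and read back a task; each call is awaited after this change.
    task = await Task.create("demo_task")
    await task.set_metadata(name="Demo", type="training")
    data = await task.get_metadata()
    assert data["name"] == "Demo"
    assert data["type"] == "training"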