diff --git a/api/api.py b/api/api.py index f2ccfe454..266b8cd2a 100644 --- a/api/api.py +++ b/api/api.py @@ -119,7 +119,6 @@ async def lifespan(app: FastAPI): galleries.update_gallery_cache() spawn_fastchat_controller_subprocess() await db.init() # This now runs Alembic migrations internally - # create_db_and_tables() is deprecated - migrations are handled in db.init() print("✅ SEED DATA") # Initialize experiments seed_default_experiments() diff --git a/api/test/api/conftest.py b/api/test/api/conftest.py index d786e63eb..a52f8d96d 100644 --- a/api/test/api/conftest.py +++ b/api/test/api/conftest.py @@ -23,8 +23,11 @@ os.environ["TRANSFORMERLAB_REFRESH_SECRET"] = "test-refresh-secret-for-testing-only" os.environ["EMAIL_METHOD"] = "dev" # Use dev mode for tests (no actual email sending) -# Use in-memory database for tests -os.environ["DATABASE_URL"] = "sqlite+aiosqlite:///:memory:" +# Use temporary file-based database for tests (easier to debug than in-memory) +test_db_dir = os.path.join("test", "tmp", "db") +os.makedirs(test_db_dir, exist_ok=True) +test_db_path = os.path.join(test_db_dir, "test_llmlab.sqlite3") +os.environ["DATABASE_URL"] = f"sqlite+aiosqlite:///{test_db_path}" from api import app # noqa: E402 @@ -72,13 +75,43 @@ def request(self, method, url, **kwargs): return super().request(method, url, **kwargs) +@pytest.fixture(scope="session", autouse=True) +def cleanup_test_db(): + """Clean up test database file after all tests complete""" + yield + # Clean up database file and related files (WAL, SHM) + test_db_path = os.path.join("test", "tmp", "db", "test_llmlab.sqlite3") + for ext in ["", "-wal", "-shm"]: + db_file = test_db_path + ext + if os.path.exists(db_file): + try: + os.remove(db_file) + except OSError: + pass # Ignore errors if file is locked or already removed + + @pytest.fixture(scope="session") def client(): - # Initialize database tables for tests - from transformerlab.shared.models.user_model import create_db_and_tables # noqa: E402 + # Initialize database tables for tests using Alembic migrations (same as production) + from transformerlab.db.session import run_alembic_migrations # noqa: E402 from transformerlab.services.experiment_init import seed_default_admin_user # noqa: E402 - asyncio.run(create_db_and_tables()) + # Ensure test database directory exists + test_db_dir = os.path.join("test", "tmp", "db") + os.makedirs(test_db_dir, exist_ok=True) + + # Remove existing test database if it exists (start fresh) + test_db_path = os.path.join(test_db_dir, "test_llmlab.sqlite3") + for ext in ["", "-wal", "-shm"]: + db_file = test_db_path + ext + if os.path.exists(db_file): + try: + os.remove(db_file) + except OSError: + pass + + # Run Alembic migrations to create database schema (matches production) + asyncio.run(run_alembic_migrations()) asyncio.run(seed_default_admin_user()) controller_log_dir = os.path.join("test", "tmp", "workspace", "logs") os.makedirs(controller_log_dir, exist_ok=True) diff --git a/api/test/api/test_experiment_export.py b/api/test/api/test_experiment_export.py deleted file mode 100644 index c34f4cb67..000000000 --- a/api/test/api/test_experiment_export.py +++ /dev/null @@ -1,113 +0,0 @@ -import os -import json -import pytest -from transformerlab.services import experiment_service -from transformerlab.services.tasks_service import tasks_service -from lab import storage - - -@pytest.mark.skip(reason="Test needs to be updated for org-based workspace") -async def test_export_experiment(client): - """Test exporting an experiment to JSON format""" - # 
Create a test experiment - test_experiment_name = f"test_export_{os.getpid()}" - config = {"description": "Test experiment"} - experiment_id = experiment_service.experiment_create(test_experiment_name, config) - - # Add a training task - train_config = { - "template_name": "TestTemplate", - "plugin_name": "test_trainer", - "model_name": "test-model", - "dataset_name": "test-dataset", - "batch_size": "4", - "learning_rate": "0.0001", - } - tasks_service.add_task( - name="test_train_task", - task_type="TRAIN", - inputs={"model_name": "test-model", "dataset_name": "test-dataset"}, - config=train_config, - plugin="test_trainer", - outputs={}, - experiment_id=experiment_id, - ) - - # Add an evaluation task - eval_config = { - "template_name": "TestEval", - "plugin_name": "test_evaluator", - "model_name": "test-model-2", - "eval_type": "basic", - "script_parameters": {"tasks": ["mmlu"], "limit": 0.5}, - "eval_dataset": "test-eval-dataset", - } - tasks_service.add_task( - name="test_eval_task", - task_type="EVAL", - inputs={"model_name": "test-model-2", "dataset_name": "test-eval-dataset"}, - config=eval_config, - plugin="test_evaluator", - outputs={"eval_results": "{}"}, - experiment_id=experiment_id, - ) - - # Add a workflow - COMMENTED OUT due to workflow migration issues - # workflow_config = { - # "nodes": [{"id": "1", "task": "test_train_task"}, {"id": "2", "task": "test_eval_task"}], - # "edges": [{"source": "1", "target": "2"}], - # } - # await db_workflows.workflow_create( - # name="test_workflow", config=json.dumps(workflow_config), experiment_id=experiment_id - # ) - - # Call the export endpoint - response = client.get(f"/experiment/{experiment_id}/export_to_recipe") - assert response.status_code == 200 - - # The response should be a JSON file - assert response.headers["content-type"] == "application/json" - - # Get the workspace_dir using team_id from the client (org-based workspace) - from lab import HOME_DIR - - workspace_dir = storage.join(HOME_DIR, "orgs", client._team_id, "workspace") - - # Read the exported file from workspace directory - export_file = storage.join(workspace_dir, f"{test_experiment_name}_export.json") - assert storage.exists(export_file) - - with storage.open(export_file, "r") as f: - exported_data = json.load(f) - - # Verify the exported data structure - assert exported_data["title"] == test_experiment_name - assert "dependencies" in exported_data - assert "tasks" in exported_data - # assert "workflows" in exported_data - - # Verify dependencies were collected correctly - dependencies = {(d["type"], d["name"]) for d in exported_data["dependencies"]} - assert ("model", "test-model") in dependencies - assert ("model", "test-model-2") in dependencies - assert ("dataset", "test-dataset") in dependencies - assert ("plugin", "test_trainer") in dependencies - assert ("plugin", "test_evaluator") in dependencies - - # Verify tasks were exported correctly - tasks = {t["name"]: t for t in exported_data["tasks"]} - assert "test_train_task" in tasks - assert "test_eval_task" in tasks - assert tasks["test_train_task"]["task_type"] == "TRAIN" - assert tasks["test_eval_task"]["task_type"] == "EVAL" - - # Verify workflow was exported correctly - COMMENTED OUT due to workflow migration issues - # workflows = {w["name"]: w for w in exported_data["workflows"]} - # assert "test_workflow" in workflows - # assert len(workflows["test_workflow"]["config"]["nodes"]) == 2 - # assert len(workflows["test_workflow"]["config"]["edges"]) == 1 - - # Clean up - 
experiment_service.experiment_delete(experiment_id) - if storage.exists(export_file): - storage.rm(export_file) diff --git a/api/test/api/test_export.py b/api/test/api/test_export.py deleted file mode 100644 index 32731b46e..000000000 --- a/api/test/api/test_export.py +++ /dev/null @@ -1,283 +0,0 @@ -import json -from unittest.mock import patch, AsyncMock, MagicMock -import pytest -from transformerlab.shared.shared import get_job_output_file_name as get_output_file_name -import asyncio - -pytestmark = pytest.mark.skip("skipping these as they need to be fixed") - - -def test_export_jobs(client): - resp = client.get("/experiment/1/export/jobs") - assert resp.status_code == 200 - assert isinstance(resp.json(), list) - - -def test_export_job(client): - resp = client.get("/experiment/1/export/job?jobId=job123") - assert resp.status_code == 200 - - -@patch("transformerlab.db.experiment_get") -@patch("transformerlab.services.job_service.job_create") -@patch("asyncio.create_subprocess_exec") -@patch("transformerlab.routers.experiment.export.get_output_file_name") -@patch("transformerlab.db.job_update_status") -@patch("os.makedirs") -@patch("os.path.join") -@patch("json.dump") -@patch("builtins.open") -def test_run_exporter_script_success( - client, - mock_open, - mock_json_dump, - mock_path_join, - mock_makedirs, - mock_job_update, - mock_get_output_file, - mock_subprocess, - mock_job_create, - mock_experiment_get, -): - # Setup mocks - mock_experiment_get.return_value = { - "config": json.dumps({"foundation": "huggingface/model1", "foundation_model_architecture": "pytorch"}) - } - mock_job_create.return_value = "job123" - mock_get_output_file.return_value = "/tmp/output_job123.txt" - - # Mock for file opening - mock_file = MagicMock() - mock_open.return_value.__enter__.return_value = mock_file - - # Mock subprocess - mock_process = AsyncMock() - mock_process.returncode = 0 - mock_process.communicate.return_value = (None, b"") - mock_subprocess.return_value = mock_process - - # Mock path join to return predictable paths - mock_path_join.side_effect = lambda *args: "/".join(args) - - resp = client.get( - "/experiment/1/export/run_exporter_script?plugin_name=test_plugin&plugin_architecture=GGUF&plugin_params=%7B%22q_bits%22%3A%224%22%7D" - ) - assert resp.status_code == 200 - result = resp.json() - assert result["status"] == "success" - assert result["job_id"] == "job123" - - # Verify that status was updated to COMPLETE - mock_job_update.assert_called_with(job_id="job123", status="COMPLETE") - - -@patch("transformerlab.db.experiment_get") -def test_run_exporter_script_invalid_experiment(client, mock_experiment_get): - # Setup mock to simulate experiment not found - mock_experiment_get.return_value = None - - resp = client.get("/experiment/999/export/run_exporter_script?plugin_name=test_plugin&plugin_architecture=GGUF") - assert resp.status_code == 200 - result = resp.json() - assert result["message"] == "Experiment 999 does not exist" - - -@patch("transformerlab.db.experiment_get") -@patch("transformerlab.services.job_service.job_create") -@patch("asyncio.create_subprocess_exec") -@patch("transformerlab.routers.experiment.export.get_output_file_name") -@patch("transformerlab.db.job_update_status") -@patch("os.makedirs") -def test_run_exporter_script_process_error( - client, mock_makedirs, mock_job_update, mock_get_output_file, mock_subprocess, mock_job_create, mock_experiment_get -): - # Setup mocks - mock_experiment_get.return_value = { - "config": json.dumps({"foundation": "huggingface/model1", 
"foundation_model_architecture": "pytorch"}) - } - mock_job_create.return_value = "job123" - mock_get_output_file.return_value = "/tmp/output_job123.txt" - - # Mock subprocess with error - mock_process = AsyncMock() - mock_process.returncode = 1 - mock_process.communicate.return_value = (None, b"Error") - mock_subprocess.return_value = mock_process - - resp = client.get("/experiment/1/export/run_exporter_script?plugin_name=test_plugin&plugin_architecture=GGUF") - assert resp.status_code == 200 - result = resp.json() - assert "Export failed" in result["message"] - - # Verify that status was updated to FAILED - mock_job_update.assert_called_with(job_id="job123", status="FAILED") - - -@patch("transformerlab.db.experiment_get") -@patch("transformerlab.services.job_service.job_create") -@patch("asyncio.create_subprocess_exec") -@patch("transformerlab.routers.experiment.export.get_output_file_name") -@patch("transformerlab.db.job_update_status") -@patch("os.makedirs") -def test_run_exporter_script_stderr_decode_error( - client, mock_makedirs, mock_job_update, mock_get_output_file, mock_subprocess, mock_job_create, mock_experiment_get -): - # Setup mocks - mock_experiment_get.return_value = { - "config": json.dumps({"foundation": "huggingface/model1", "foundation_model_architecture": "pytorch"}) - } - mock_job_create.return_value = "job123" - mock_get_output_file.return_value = "/tmp/output_job123.txt" - - # Mock subprocess with stderr decode error - mock_process = AsyncMock() - mock_process.returncode = 1 - mock_process.communicate.return_value = (None, b"\xff\xfe") # Invalid UTF-8 sequence - mock_subprocess.return_value = mock_process - - resp = client.get("/experiment/1/export/run_exporter_script?plugin_name=test_plugin&plugin_architecture=GGUF") - assert resp.status_code == 200 - result = resp.json() - assert "Export failed due to an internal error" in result["message"] - - # Verify that status was updated to FAILED - mock_job_update.assert_called_with(job_id="job123", status="FAILED") - - -@patch("transformerlab.db.job_get") -@patch("transformerlab.routers.experiment.export.dirs.plugin_dir_by_name") -@patch("os.path.exists") -def test_get_output_file_name_with_custom_path(mock_exists, mock_plugin_dir, mock_job_get): - # Setup mocks - mock_job_get.return_value = {"job_data": {"output_file_path": "/custom/path/output.txt", "plugin": "test_plugin"}} - mock_plugin_dir.return_value = "/plugins/test_plugin" - mock_exists.return_value = True - - result = asyncio.run(get_output_file_name("job123")) - assert result == "/custom/path/output.txt" - - -@patch("transformerlab.db.job_get") -@patch("transformerlab.routers.experiment.export.dirs.plugin_dir_by_name") -@patch("os.path.exists") -def test_get_output_file_name_without_plugin(mock_exists, mock_plugin_dir, mock_job_get): - # Setup mocks - mock_job_get.return_value = { - "job_data": {} # No plugin specified - } - - with pytest.raises(ValueError, match="Plugin not found in job data"): - asyncio.run(get_output_file_name("job123")) - - -@patch("transformerlab.db.job_get") -@patch("transformerlab.routers.experiment.export.dirs.plugin_dir_by_name") -@patch("os.path.exists") -def test_get_output_file_name_with_plugin(mock_exists, mock_plugin_dir, mock_job_get): - # Setup mocks - mock_job_get.return_value = {"job_data": {"plugin": "test_plugin"}} - mock_plugin_dir.return_value = "/plugins/test_plugin" - mock_exists.return_value = True - - result = asyncio.run(get_output_file_name("job123")) - assert "jobs/job123/output_job123.txt" in result - - 
-@patch("transformerlab.routers.experiment.export.get_output_file_name") -def test_watch_export_log_value_error(client, mock_get_output_file): - mock_get_output_file.side_effect = ValueError("File not found for job") - - resp = client.get("/experiment/1/export/job/job123/stream_output") - assert resp.status_code == 200 - response_text = resp.text.strip('"') - assert response_text == "An internal error has occurred!" - - -@patch("transformerlab.routers.experiment.export.get_output_file_name") -def test_watch_export_log_other_error(client, mock_get_output_file): - # Setup mock to raise a different ValueError - mock_get_output_file.side_effect = ValueError("Some other error") - - resp = client.get("/experiment/1/export/job/job123/stream_output") - assert resp.status_code == 200 - response_text = resp.text.strip('"') - assert response_text == "An internal error has occurred!" - - -@patch("transformerlab.db.job_get") -@patch("transformerlab.routers.experiment.export.dirs.plugin_dir_by_name") -@patch("os.path.exists") -def test_get_output_file_name_no_existing_file(client, mock_exists, mock_plugin_dir, mock_job_get): - """ - When the job has a plugin but no bespoke output_file_path and - the file doesn't exist yet, export.get_output_file_name should - still return the *constructed* path. - """ - mock_job_get.return_value = {"job_data": {"plugin": "test_plugin"}} - mock_plugin_dir.return_value = "/plugins/test_plugin" - mock_exists.return_value = False # force the “else” branch - - result = asyncio.run(get_output_file_name("job123")) - assert result == "/plugins/test_plugin/output_job123.txt" - - -@patch("transformerlab.routers.experiment.export.watch_file") -@patch("transformerlab.routers.experiment.export.asyncio.sleep") -@patch("transformerlab.routers.experiment.export.get_output_file_name") -def test_watch_export_log_retry_success(client, mock_get_output_file, mock_sleep, mock_watch_file): - """ - First call to get_output_file_name raises the special ValueError. - async sleep is awaited, then the second call succeeds and the route - returns a StreamingResponse built from watch_file(). 
- """ - # 1️⃣ make get_output_file_name fail once, then succeed - mock_get_output_file.side_effect = [ - ValueError("No output file found for job 123"), - "/tmp/output_job123.txt", - ] - - # 2️⃣ avoid a real 4-second wait - mock_sleep.return_value = AsyncMock() - - # 3️⃣ provide an iterator so FastAPI can stream something - mock_watch_file.return_value = iter(["line1\n"]) - - resp = client.get("/experiment/1/export/job/job123/stream_output") - assert resp.status_code == 200 - # because watch_file yielded “line1”, the body must contain it - assert "line1" in resp.text - - # ensure the retry actually happened - assert mock_get_output_file.call_count >= 2 - - # make sure sleep was awaited with 4 seconds at least once - assert any(call.args == (4,) for call in mock_sleep.await_args_list) - - -@patch("transformerlab.db.experiment_get") -@patch("transformerlab.services.job_service.job_create") -@patch("asyncio.create_subprocess_exec") -@patch("transformerlab.routers.experiment.export.get_output_file_name") -@patch("builtins.open") -def test_stderr_decode_fallback(client, mock_open, mock_get_outfile, mock_subproc, mock_job_create, mock_exp_get): - # minimal fixtures - mock_exp_get.return_value = {"config": '{"foundation":"hf/x","foundation_model_architecture":"pt"}'} - mock_job_create.return_value = "j1" - mock_get_outfile.return_value = "/tmp/out.txt" - - # make stderr.decode() raise - bad_stderr = MagicMock() - bad_stderr.decode.side_effect = UnicodeDecodeError("utf-8", b"", 0, 1, "boom") - proc = AsyncMock(returncode=0) - proc.communicate.return_value = (None, bad_stderr) - mock_subproc.return_value = proc - - fake_file = MagicMock() - mock_open.return_value.__enter__.return_value = fake_file - - resp = client.get("/experiment/1/export/run_exporter_script?plugin_name=p&plugin_architecture=GGUF") - assert resp.status_code == 200 - - # confirm fallback string was written - written = "".join(call.args[0] for call in fake_file.write.call_args_list) - assert "[stderr decode error]" in written diff --git a/api/test/api/test_recipes_new.py b/api/test/api/test_recipes_new.py deleted file mode 100644 index 16ba45f15..000000000 --- a/api/test/api/test_recipes_new.py +++ /dev/null @@ -1,492 +0,0 @@ -import pytest -import os - -# skip all tests in this file -pytestmark = pytest.mark.skip("skipping these as they need to be fixed") - -TEST_EXP_RECIPES = [ - { - "id": "1", - "title": "Test Recipe - With Notes", - "description": "A test recipe with notes to test notes creation", - "notes": "# Test Recipe Notes\n\nThis is a test recipe for unit testing.\n\n## Features\n- Notes creation\n- Task generation\n\n## Usage\nThis should create a readme.md file in the experiment.", - "dependencies": [{"type": "model", "name": "test-model"}, {"type": "dataset", "name": "test-dataset"}], - }, - { - "id": "2", - "title": "Test Recipe - With Tasks and Notes", - "description": "A test recipe that includes both notes and tasks", - "notes": "# Training Recipe\n\nThis recipe includes training tasks.\n\n## Training Configuration\n- Uses LoRA training\n- Batch size: 4\n- Learning rate: 0.0001", - "dependencies": [ - {"type": "model", "name": "test-model-2"}, - {"type": "dataset", "name": "test-dataset-for-training"}, - ], - "tasks": [ - { - "name": "test_train_task", - "task_type": "TRAIN", - "type": "LoRA", - "plugin": "test_trainer", - "formatting_template": "{{prompt}}\n{{completion}}", - "config_json": 
'{"template_name":"TestTemplate","plugin_name":"test_trainer","model_name":"test-model-2","dataset_name":"test-dataset-for-training","batch_size":"4","learning_rate":"0.0001"}', - } - ], - }, - { - "id": "3", - "title": "Test Recipe - Tasks Only", - "description": "A test recipe with only tasks, no notes", - "dependencies": [{"type": "model", "name": "test-model-3"}, {"type": "dataset", "name": "test-dataset-3"}], - "tasks": [ - { - "name": "single_train_task", - "task_type": "TRAIN", - "type": "LoRA", - "plugin": "mlx_lora_trainer", - "formatting_template": "{{text}}", - "config_json": '{"template_name":"NoNotesTemplate","plugin_name":"mlx_lora_trainer","model_name":"test-model-3","dataset_name":"test-dataset-3","batch_size":"8","learning_rate":"0.001"}', - } - ], - }, - { - "id": "4", - "title": "Test Recipe - With Adaptor Name", - "description": "A test recipe that includes adaptor_name in config to test line 281", - "dependencies": [{"type": "model", "name": "test-model-4"}, {"type": "dataset", "name": "test-dataset-4"}], - "tasks": [ - { - "name": "adaptor_train_task", - "task_type": "TRAIN", - "type": "LoRA", - "plugin": "test_trainer", - "formatting_template": "{{prompt}}\n{{completion}}", - "config_json": '{"template_name":"AdaptorTest","plugin_name":"test_trainer","model_name":"test-model-4","dataset_name":"test-dataset-4","adaptor_name":"test_adaptor","batch_size":"4","learning_rate":"0.0001"}', - } - ], - }, - { - "id": "5", - "title": "Test Recipe - Invalid JSON Config", - "description": "A test recipe with invalid JSON to test exception handling", - "dependencies": [{"type": "model", "name": "test-model-5"}, {"type": "dataset", "name": "test-dataset-5"}], - "tasks": [ - { - "name": "invalid_json_task", - "task_type": "TRAIN", - "type": "LoRA", - "plugin": "test_trainer", - "formatting_template": "{{prompt}}\n{{completion}}", - "config_json": "{invalid json syntax to trigger exception", - } - ], - }, - { - "id": "6", - "title": "Test Recipe - With Multiple Task Types", - "description": "A test recipe that includes training, evaluation and generation tasks", - "dependencies": [{"type": "model", "name": "test-model-6"}, {"type": "dataset", "name": "test-dataset-6"}], - "tasks": [ - { - "name": "multi_train_task", - "task_type": "TRAIN", - "type": "LoRA", - "plugin": "test_trainer", - "formatting_template": "{{prompt}}\n{{completion}}", - "config_json": '{"template_name":"TestTemplate","plugin_name":"test_trainer","model_name":"test-model-6","dataset_name":"test-dataset-6","batch_size":"4","learning_rate":"0.0001"}', - }, - { - "name": "multi_eval_task", - "task_type": "EVAL", - "plugin": "test_evaluator", - "config_json": '{"template_name":"TestEval","plugin_name":"test_evaluator","model_name":"test-model-6","eval_type":"basic","script_parameters":{"tasks":["mmlu","hellaswag"],"limit":0.5,"device_map":{"model":"auto","tensor_parallel":true}},"eval_dataset":"test-eval-dataset"}', - }, - { - "name": "multi_generate_task", - "task_type": "GENERATE", - "plugin": "test_generator", - "config_json": '{"template_name":"TestGen","plugin_name":"test_generator","model_name":"test-model-6","prompt_template":"Generate a response: {{input}}","generation_params":{"max_length":100,"temperature":0.7}}', - }, - ], - }, - { - "id": "7", - "title": "Test Recipe - With Multiple Workflows", - "description": "A test recipe that includes multiple workflows", - "dependencies": [{"type": "model", "name": "test-model-8"}, {"type": "dataset", "name": "test-dataset-8"}], - "tasks": [ - { - "name": 
"workflow_train_task", - "task_type": "TRAIN", - "type": "LoRA", - "plugin": "test_trainer", - "formatting_template": "{{prompt}}\n{{completion}}", - "config_json": '{"template_name":"MultiWorkflowTrain","plugin_name":"test_trainer","model_name":"test-model-8","dataset_name":"test-dataset-8","batch_size":"4","learning_rate":"0.0001"}', - }, - { - "name": "workflow_eval_task", - "task_type": "EVAL", - "plugin": "test_evaluator", - "config_json": '{"template_name":"MultiWorkflowEval","plugin_name":"test_evaluator","model_name":"test-model-8","tasks":"mmlu","limit":"0.5","run_name":"MultiWorkflowEval"}', - }, - ], - "workflows": [ - { - "name": "Train_Only_Workflow", - "config": { - "nodes": [ - { - "id": "node_train", - "type": "TRAIN", - "task": "workflow_train_task", - "name": "Training Task", - "out": [], - } - ] - }, - }, - { - "name": "Train_Eval_Workflow", - "config": { - "nodes": [ - { - "id": "node_train", - "type": "TRAIN", - "task": "workflow_train_task", - "name": "Training Task", - "out": ["node_eval"], - }, - { - "id": "node_eval", - "type": "EVAL", - "task": "workflow_eval_task", - "name": "Evaluation Task", - "out": [], - }, - ] - }, - }, - ], - }, - { - "id": "8", - "title": "Test Recipe - With Invalid Workflow Config", - "description": "A test recipe with invalid workflow config to test error handling", - "dependencies": [{"type": "model", "name": "test-model-9"}, {"type": "dataset", "name": "test-dataset-9"}], - "tasks": [ - { - "name": "invalid_workflow_train_task", - "task_type": "TRAIN", - "type": "LoRA", - "plugin": "test_trainer", - "formatting_template": "{{prompt}}\n{{completion}}", - "config_json": '{"template_name":"InvalidWorkflowTrain","plugin_name":"test_trainer","model_name":"test-model-9","dataset_name":"test-dataset-9","batch_size":"4","learning_rate":"0.0001"}', - } - ], - "workflows": [{"name": "Invalid_Workflow", "config": "invalid_config_format"}], - }, - { - "id": "9", - "title": "Test Recipe - With Named Tasks", - "description": "A test recipe with explicitly named tasks", - "dependencies": [{"type": "model", "name": "test-model-10"}, {"type": "dataset", "name": "test-dataset-10"}], - "tasks": [ - { - "name": "custom_train_task", - "task_type": "TRAIN", - "type": "LoRA", - "plugin": "test_trainer", - "formatting_template": "{{prompt}}\n{{completion}}", - "config_json": '{"template_name":"NamedTaskTrain","plugin_name":"test_trainer","model_name":"test-model-10","dataset_name":"test-dataset-10","batch_size":"4","learning_rate":"0.0001"}', - }, - { - "name": "custom_eval_task", - "task_type": "EVAL", - "plugin": "test_evaluator", - "config_json": '{"template_name":"NamedTaskEval","plugin_name":"test_evaluator","model_name":"test-model-10","tasks":"mmlu","limit":"0.5","run_name":"NamedTaskEval"}', - }, - ], - }, -] - - -def test_recipes_list(client): - resp = client.get("/recipes/list") - assert resp.status_code == 200 - data = resp.json() - assert isinstance(data, list) - assert len(data) >= 6 # Should have our test recipes (updated count) - - -def test_recipes_get_by_id_with_notes(client): - resp = client.get("/recipes/1") - assert resp.status_code == 200 - data = resp.json() - assert data["id"] == "1" - assert data["title"] == "Test Recipe - With Notes" - assert "notes" in data - assert "# Test Recipe Notes" in data["notes"] - - -def test_recipes_get_by_id_with_tasks(client): - resp = client.get("/recipes/2") - assert resp.status_code == 200 - data = resp.json() - assert data["id"] == "2" - assert "tasks" in data - assert len(data["tasks"]) == 1 - assert 
data["tasks"][0]["task_type"] == "TRAIN" - - -def test_create_experiment_with_notes(client): - test_experiment_name = f"test_notes_exp_{os.getpid()}" - resp = client.post(f"/recipes/1/create_experiment?experiment_name={test_experiment_name}") - assert resp.status_code == 200 - data = resp.json() - assert "status" in data or "message" in data - - -def test_create_experiment_with_tasks(client): - test_experiment_name = f"test_tasks_exp_{os.getpid()}" - resp = client.post(f"/recipes/2/create_experiment?experiment_name={test_experiment_name}") - assert resp.status_code == 200 - data = resp.json() - assert data["status"] == "success" - assert "data" in data - assert "task_results" in data["data"] - task_results = data["data"]["task_results"] - assert len(task_results) == 1 - assert task_results[0]["task_name"] == "test_train_task" - assert task_results[0]["task_type"] == "TRAIN" - - -def test_create_experiment_tasks_only(client): - test_experiment_name = f"test_tasks_only_{os.getpid()}" - resp = client.post(f"/recipes/3/create_experiment?experiment_name={test_experiment_name}") - assert resp.status_code == 200 - data = resp.json() - assert "status" in data or "message" in data - - -def test_create_experiment_duplicate_name(client): - test_experiment_name = f"duplicate_test_{os.getpid()}" - - # First creation - resp1 = client.post(f"/recipes/1/create_experiment?experiment_name={test_experiment_name}") - assert resp1.status_code == 200 - - # Second creation with same name should fail - resp2 = client.post(f"/recipes/1/create_experiment?experiment_name={test_experiment_name}") - assert resp2.status_code == 200 - data = resp2.json() - assert data.get("status") == "error" - assert "already exists" in data.get("message", "") - - -def test_create_experiment_invalid_recipe_id(client): - test_experiment_name = f"invalid_recipe_test_{os.getpid()}" - resp = client.post(f"/recipes/999/create_experiment?experiment_name={test_experiment_name}") - assert resp.status_code == 200 - data = resp.json() - assert data.get("status") == "error" - assert "not found" in data.get("message", "") - - -def test_create_experiment_with_adaptor_name(client): - """Test creating experiment with recipe that has adaptor_name in config (covers line 281)""" - test_experiment_name = f"test_adaptor_{os.getpid()}" - resp = client.post(f"/recipes/4/create_experiment?experiment_name={test_experiment_name}") - assert resp.status_code == 200 - data = resp.json() - assert "status" in data or "message" in data - if data.get("status") == "success": - assert "data" in data - assert "task_results" in data["data"] - - -def test_create_experiment_with_invalid_json_config(client): - """Test creating experiment with invalid JSON config to trigger exception handling (covers lines 306-307)""" - test_experiment_name = f"test_invalid_json_{os.getpid()}" - resp = client.post(f"/recipes/5/create_experiment?experiment_name={test_experiment_name}") - assert resp.status_code == 200 - data = resp.json() - assert "status" in data or "message" in data - if data.get("status") == "success" and "data" in data and "task_results" in data["data"]: - task_results = data["data"]["task_results"] - assert len(task_results) > 0 - has_error = any("error" in result.get("status", "") for result in task_results) - assert has_error - - -def test_recipes_get_by_id_with_multiple_task_types(client): - """Test that a recipe with multiple task types (TRAIN, EVAL, GENERATE) is handled correctly""" - resp = client.get("/recipes/6") - assert resp.status_code == 200 - data = 
resp.json() - assert data["id"] == "6" - assert "tasks" in data - assert len(data["tasks"]) == 3 - task_types = [task["task_type"] for task in data["tasks"]] - assert "TRAIN" in task_types - assert "EVAL" in task_types - assert "GENERATE" in task_types - - -def test_create_experiment_with_multiple_task_types(client): - """Test creating an experiment with multiple task types""" - test_experiment_name = f"test_multi_tasks_{os.getpid()}" - resp = client.post(f"/recipes/6/create_experiment?experiment_name={test_experiment_name}") - assert resp.status_code == 200 - data = resp.json() - assert "status" in data - if data.get("status") == "success": - assert "data" in data - assert "task_results" in data["data"] - task_results = data["data"]["task_results"] - assert len(task_results) == 3 - - # Verify task names and types - task_names = [result["task_name"] for result in task_results] - assert "multi_train_task" in task_names - assert "multi_eval_task" in task_names - assert "multi_generate_task" in task_names - - # Verify task types - task_types = [result["task_type"] for result in task_results] - assert "TRAIN" in task_types - assert "EVAL" in task_types - assert "GENERATE" in task_types - - -def test_create_experiment_with_script_parameters_list_dict(client): - """Test creating experiment with recipe that has list and dict values in script_parameters (covers line 276)""" - test_experiment_name = f"test_script_params_{os.getpid()}" - resp = client.post(f"/recipes/10/create_experiment?experiment_name={test_experiment_name}") - assert resp.status_code == 200 - data = resp.json() - assert "status" in data or "message" in data - if data.get("status") == "success": - assert "data" in data - assert "task_results" in data["data"] - task_results = data["data"]["task_results"] - assert len(task_results) == 1 - task_result = task_results[0] - assert task_result.get("task_type") == "EVAL" - assert task_result.get("action") == "create_task" - - -def test_recipes_get_by_id_with_workflows(client): - """Test that a recipe with workflows is handled correctly""" - resp = client.get("/recipes/7") - assert resp.status_code == 200 - data = resp.json() - assert data["id"] == "7" - assert "workflows" in data - assert len(data["workflows"]) == 2 # Recipe 7 has 2 workflows - workflow_names = [wf["name"] for wf in data["workflows"]] - assert "Train_Only_Workflow" in workflow_names - assert "Train_Eval_Workflow" in workflow_names - for workflow in data["workflows"]: - assert "config" in workflow - assert "nodes" in workflow["config"] - assert len(workflow["config"]["nodes"]) > 0 - - -def test_create_experiment_with_workflows(client): - """Test creating an experiment with workflows""" - test_experiment_name = f"test_workflows_{os.getpid()}" - resp = client.post(f"/recipes/7/create_experiment?experiment_name={test_experiment_name}") - assert resp.status_code == 200 - data = resp.json() - assert "status" in data - if data.get("status") == "success": - assert "data" in data - assert "workflow_creation_results" in data["data"] - workflow_results = data["data"]["workflow_creation_results"] - assert len(workflow_results) == 2 # Recipe 7 has 2 workflows - workflow_names = [result.get("workflow_name") for result in workflow_results] - assert "Train_Only_Workflow" in workflow_names - assert "Train_Eval_Workflow" in workflow_names - for result in workflow_results: - assert result.get("action") == "create_workflow" - assert result.get("status") == "success" - assert "workflow_id" in result - - -def 
test_create_experiment_with_multiple_workflows(client): - """Test creating an experiment with multiple workflows""" - test_experiment_name = f"test_multi_workflows_{os.getpid()}" - resp = client.post(f"/recipes/7/create_experiment?experiment_name={test_experiment_name}") - assert resp.status_code == 200 - data = resp.json() - assert "status" in data - if data.get("status") == "success": - # Verify tasks were created with correct names - task_results = data["data"]["task_results"] - assert len(task_results) == 2 - task_names = [result["task_name"] for result in task_results] - assert "workflow_train_task" in task_names - assert "workflow_eval_task" in task_names - - # Verify workflows were created with correct task references - workflow_results = data["data"]["workflow_creation_results"] - assert len(workflow_results) == 2 - assert all(result["status"] == "success" for result in workflow_results) - - # Get the workflows to verify their task references - workflows_resp = client.get("/recipes/7") - workflows_data = workflows_resp.json() - for workflow in workflows_data["workflows"]: - if workflow["name"] == "Train_Only_Workflow": - assert workflow["config"]["nodes"][0]["task"] == "workflow_train_task" - elif workflow["name"] == "Train_Eval_Workflow": - assert workflow["config"]["nodes"][0]["task"] == "workflow_train_task" - assert workflow["config"]["nodes"][1]["task"] == "workflow_eval_task" - - -def test_create_experiment_with_invalid_workflow_config(client): - """Test creating experiment with invalid workflow config to test error handling""" - test_experiment_name = f"test_invalid_workflow_{os.getpid()}" - resp = client.post(f"/recipes/8/create_experiment?experiment_name={test_experiment_name}") - assert resp.status_code == 200 - data = resp.json() - assert "status" in data - if data.get("status") == "success": - assert "data" in data - assert "workflow_creation_results" in data["data"] - workflow_results = data["data"]["workflow_creation_results"] - assert len(workflow_results) == 1 # Recipe 8 has 1 invalid workflow - workflow_result = workflow_results[0] - assert workflow_result.get("workflow_name") == "Invalid_Workflow" - assert workflow_result.get("action") == "create_workflow" - assert "error" in workflow_result.get("status", "") - - -def test_create_experiment_without_workflows(client): - """Test creating an experiment from a recipe without workflows""" - test_experiment_name = f"test_no_workflows_{os.getpid()}" - resp = client.post(f"/recipes/6/create_experiment?experiment_name={test_experiment_name}") - assert resp.status_code == 200 - data = resp.json() - assert "status" in data - if data.get("status") == "success": - assert "data" in data - assert "workflow_creation_results" in data["data"] - workflow_results = data["data"]["workflow_creation_results"] - assert len(workflow_results) == 0 - - -def test_recipes_get_by_id_with_multiple_workflows(client): - """Test that a recipe with multiple workflows is handled correctly""" - resp = client.get("/recipes/7") # Changed to recipe 7 which has 2 workflows - assert resp.status_code == 200 - data = resp.json() - assert data["id"] == "7" - assert "workflows" in data - assert len(data["workflows"]) == 2 # Recipe 7 has 2 workflows - workflow_names = [wf["name"] for wf in data["workflows"]] - assert "Train_Only_Workflow" in workflow_names - assert "Train_Eval_Workflow" in workflow_names - for workflow in data["workflows"]: - assert "config" in workflow - assert "nodes" in workflow["config"] - assert len(workflow["config"]["nodes"]) > 0 diff 
--git a/api/test/api/test_tasks.py b/api/test/api/test_tasks.py deleted file mode 100644 index a96443071..000000000 --- a/api/test/api/test_tasks.py +++ /dev/null @@ -1,58 +0,0 @@ -import pytest - -pytestmark = pytest.mark.skip("skipping these as they need to be fixed") - - -def test_tasks_list(client): - resp = client.get("/tasks/list") - assert resp.status_code == 200 - assert isinstance(resp.json(), list) or isinstance(resp.json(), dict) - - -def test_tasks_get_by_id(client): - resp = client.get("/tasks/1/get") - assert resp.status_code in (200, 404) - - -def test_tasks_list_by_type(client): - resp = client.get("/tasks/list_by_type?type=TRAIN") - assert resp.status_code in (200, 404) - - -def test_add_task(client): - new_task = { - "name": "Test Task", - "type": "TRAIN", - "inputs": "{}", - "config": "{}", - "plugin": "test_plugin", - "outputs": "{}", - "experiment_id": 1, - } - resp = client.put("/tasks/new_task", json=new_task) - assert resp.status_code == 200 - assert "message" in resp.json() or "status" in resp.json() - - -def test_update_task(client): - update_data = {"name": "Updated Task", "inputs": "{}", "config": "{}", "outputs": "{}"} - resp = client.put("/tasks/1/update", json=update_data) - assert resp.status_code == 200 - assert resp.json()["message"] == "OK" - - -def test_list_by_type_in_experiment(client): - resp = client.get("/tasks/list_by_type_in_experiment?type=TRAIN&experiment_id=1") - assert resp.status_code in (200, 404) - - -def test_delete_task(client): - resp = client.get("/tasks/1/delete") - assert resp.status_code == 200 - assert resp.json()["message"] == "OK" - - -def test_delete_all_tasks(client): - resp = client.get("/tasks/delete_all") - assert resp.status_code == 200 - assert resp.json()["message"] == "OK" diff --git a/api/test/api/test_workflow_triggers.py b/api/test/api/test_workflow_triggers.py deleted file mode 100644 index 8ddc9eacf..000000000 --- a/api/test/api/test_workflow_triggers.py +++ /dev/null @@ -1,86 +0,0 @@ -import json -import pytest -from transformerlab.routers.experiment.workflows import workflows_get_by_trigger_type - -pytestmark = pytest.mark.skip("Skipping all workflow trigger tests due to database index conflicts") - - -def test_workflow_triggers_endpoint_basic_functionality(client, experiment_id): - """Test basic workflow triggering functionality""" - # Create a workflow with TRAIN trigger - config = {"nodes": [{"type": "START", "id": "start", "name": "START", "out": []}], "triggers": ["TRAIN"]} - resp = client.get( - f"/experiment/{experiment_id}/workflows/create", - params={"name": "test_trigger_workflow", "config": json.dumps(config)}, - ) - assert resp.status_code == 200 - - # Test the function directly - import asyncio - - workflows = asyncio.run(workflows_get_by_trigger_type(experiment_id, "TRAIN")) - assert isinstance(workflows, list) - assert len(workflows) == 1 - - -def test_workflow_triggers_endpoint_export_model_mapping(client, experiment_id): - """Test that EXPORT trigger works correctly""" - # Create a workflow with EXPORT trigger - config = {"nodes": [{"type": "START", "id": "start", "name": "START", "out": []}], "triggers": ["EXPORT"]} - resp = client.get( - f"/experiment/{experiment_id}/workflows/create", - params={"name": "test_export_trigger", "config": json.dumps(config)}, - ) - assert resp.status_code == 200 - - # Test the function directly - import asyncio - - workflows = asyncio.run(workflows_get_by_trigger_type(experiment_id, "EXPORT")) - assert isinstance(workflows, list) - assert len(workflows) == 1 - - 
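
The workflow-trigger tests being deleted here combine two patterns: a yield-based pytest fixture that creates an experiment and tears it down after the test, and driving an async helper from a synchronous test with `asyncio.run()`. A compact sketch of both follows; the in-memory `_EXPERIMENTS` store and the function bodies are hypothetical stand-ins for the real service and database calls, kept only to make the example runnable.

```python
# Illustrative sketch (not repo code) of the fixture + asyncio.run pattern
# seen in the deleted workflow-trigger tests. All implementations below are
# hypothetical stand-ins for the real experiment/workflow services.
import asyncio
import pytest

_EXPERIMENTS: dict[int, list[str]] = {}


def experiment_create(name: str) -> int:
    exp_id = len(_EXPERIMENTS) + 1
    _EXPERIMENTS[exp_id] = []
    return exp_id


def experiment_delete(exp_id: int) -> None:
    _EXPERIMENTS.pop(exp_id, None)


async def workflows_get_by_trigger_type(exp_id: int, trigger: str) -> list[str]:
    # Stand-in for the real async query: workflows registered for a trigger type.
    return [w for w in _EXPERIMENTS.get(exp_id, []) if w == trigger]


@pytest.fixture
def experiment_id():
    exp_id = experiment_create("test_experiment")
    yield exp_id                  # the test body runs here
    experiment_delete(exp_id)     # teardown runs afterwards, pass or fail


def test_no_matching_triggers(experiment_id):
    # Synchronous test driving an async helper, as the deleted tests did.
    workflows = asyncio.run(workflows_get_by_trigger_type(experiment_id, "TRAIN"))
    assert workflows == []
```
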
-def test_workflow_triggers_endpoint_error_handling(client, experiment_id): - """Test that malformed configs are handled gracefully""" - # Create a workflow with malformed config directly in the database - import asyncio - from transformerlab.db.workflows import workflow_create - - async def create_malformed_workflow(): - return await workflow_create("test_malformed_trigger", "invalid json", experiment_id) - - workflow_id = asyncio.run(create_malformed_workflow()) - assert workflow_id is not None - - # Test that malformed config doesn't crash the function - workflows = asyncio.run(workflows_get_by_trigger_type(experiment_id, "TRAIN")) - assert isinstance(workflows, list) - # Should not contain the malformed workflow - - -def test_workflow_triggers_endpoint_no_matching_triggers(client, experiment_id): - """Test that no workflows are returned when no triggers match""" - # Create a workflow with EVAL trigger - config = {"nodes": [{"type": "START", "id": "start", "name": "START", "out": []}], "triggers": ["EVAL"]} - resp = client.get( - f"/experiment/{experiment_id}/workflows/create", - params={"name": "test_eval_workflow", "config": json.dumps(config)}, - ) - assert resp.status_code == 200 - - # Test with different trigger type - import asyncio - - workflows = asyncio.run(workflows_get_by_trigger_type(experiment_id, "GENERATE")) - assert isinstance(workflows, list) - assert len(workflows) == 0 - - -@pytest.fixture -def experiment_id(): - from transformerlab.services.experiment_service import experiment_create, experiment_delete - - exp_id = experiment_create("test_experiment", {}) - yield exp_id - experiment_delete(exp_id) diff --git a/api/test/api/test_workflows.py b/api/test/api/test_workflows.py deleted file mode 100644 index 5a7560299..000000000 --- a/api/test/api/test_workflows.py +++ /dev/null @@ -1,309 +0,0 @@ -import json -import pytest - -pytestmark = pytest.mark.skip("Skipping all workflow tests due to database index conflicts") - - -def test_create_empty_workflow(client): - resp = client.get("/experiment/1/workflows/create_empty", params={"name": "testwf"}) - assert resp.status_code == 200 - workflow_id = resp.json() - assert workflow_id - # Cleanup - del_resp = client.get(f"/experiment/1/workflows/delete/{workflow_id}") - assert del_resp.status_code == 200 - assert del_resp.json().get("message") == "OK" - - -def test_list_workflows(client): - # Create a workflow to ensure at least one exists - create_resp = client.get("/experiment/1/workflows/create_empty", params={"name": "listtest"}) - workflow_id = create_resp.json() - resp = client.get("/experiment/1/workflows/list") - assert resp.status_code == 200 - assert isinstance(resp.json(), list) - # Cleanup - del_resp = client.get(f"/experiment/1/workflows/delete/{workflow_id}") - assert del_resp.status_code == 200 - assert del_resp.json().get("message") == "OK" - - -def test_delete_workflow(client): - # Create a workflow to delete - create_resp = client.get("/experiment/1/workflows/create_empty", params={"name": "todelete"}) - workflow_id = create_resp.json() - del_resp = client.get(f"/experiment/1/workflows/delete/{workflow_id}") - assert del_resp.status_code == 200 - assert del_resp.json().get("message") == "OK" - - -def test_workflow_update_name(client): - create_resp = client.get("/experiment/1/workflows/create_empty", params={"name": "updatename"}) - workflow_id = create_resp.json() - resp = client.get(f"/experiment/1/workflows/{workflow_id}/update_name", params={"new_name": "updatedname"}) - assert resp.status_code == 200 - 
assert resp.json().get("message") == "OK" - # Cleanup - client.get(f"/experiment/1/workflows/delete/{workflow_id}") - - -def test_workflow_add_and_delete_node(client): - create_resp = client.get("/experiment/1/workflows/create_empty", params={"name": "addnode"}) - workflow_id = create_resp.json() - node = {"type": "TASK", "name": "Test Task", "task": "test_task", "out": []} - add_node_resp = client.get(f"/experiment/1/workflows/{workflow_id}/add_node", params={"node": json.dumps(node)}) - assert add_node_resp.status_code == 200 - # Get workflow config to find node id - wf_resp = client.get("/experiment/1/workflows/list") - workflow = next(w for w in wf_resp.json() if w["id"] == workflow_id) - config = workflow["config"] - if not isinstance(config, dict): - config = json.loads(config) - task_node = next(n for n in config["nodes"] if n["type"] == "TASK") - node_id = task_node["id"] - # Delete node - del_node_resp = client.get(f"/experiment/1/workflows/{workflow_id}/{node_id}/delete_node") - assert del_node_resp.status_code == 200 - # Cleanup - client.get(f"/experiment/1/workflows/delete/{workflow_id}") - - -def test_workflow_edit_node_metadata(client): - create_resp = client.get("/experiment/1/workflows/create_empty", params={"name": "editmeta"}) - workflow_id = create_resp.json() - node = {"type": "TASK", "name": "Meta Task", "task": "test_task", "out": []} - client.get(f"/experiment/1/workflows/{workflow_id}/add_node", params={"node": json.dumps(node)}) - wf_resp = client.get("/experiment/1/workflows/list") - workflow = next(w for w in wf_resp.json() if w["id"] == workflow_id) - config = workflow["config"] - if not isinstance(config, dict): - config = json.loads(config) - task_node = next(n for n in config["nodes"] if n["type"] == "TASK") - node_id = task_node["id"] - meta = {"desc": "testdesc"} - edit_resp = client.get( - f"/experiment/1/workflows/{workflow_id}/{node_id}/edit_node_metadata", params={"metadata": json.dumps(meta)} - ) - assert edit_resp.status_code == 200 - # Cleanup - client.get(f"/experiment/1/workflows/delete/{workflow_id}") - - -def test_workflow_add_and_remove_edge(client): - create_resp = client.get("/experiment/1/workflows/create_empty", params={"name": "edgecase"}) - workflow_id = create_resp.json() - node1 = {"type": "TASK", "name": "Task1", "task": "task1", "out": []} - node2 = {"type": "TASK", "name": "Task2", "task": "task2", "out": []} - client.get(f"/experiment/1/workflows/{workflow_id}/add_node", params={"node": json.dumps(node1)}) - client.get(f"/experiment/1/workflows/{workflow_id}/add_node", params={"node": json.dumps(node2)}) - wf_resp = client.get("/experiment/1/workflows/list") - workflow = next(w for w in wf_resp.json() if w["id"] == workflow_id) - config = workflow["config"] - if not isinstance(config, dict): - config = json.loads(config) - task_nodes = [n for n in config["nodes"] if n["type"] == "TASK"] - node1_id, node2_id = task_nodes[0]["id"], task_nodes[1]["id"] - add_edge_resp = client.post( - f"/experiment/1/workflows/{workflow_id}/{node1_id}/add_edge", params={"end_node_id": node2_id} - ) - assert add_edge_resp.status_code == 200 - remove_edge_resp = client.post( - f"/experiment/1/workflows/{workflow_id}/{node1_id}/remove_edge", params={"end_node_id": node2_id} - ) - assert remove_edge_resp.status_code == 200 - # Cleanup - client.get(f"/experiment/1/workflows/delete/{workflow_id}") - - -def test_workflow_export_to_yaml(client): - create_resp = client.get("/experiment/1/workflows/create_empty", params={"name": "yamltest"}) - workflow_id = 
create_resp.json() - export_resp = client.get(f"/experiment/1/workflows/{workflow_id}/export_to_yaml") - assert export_resp.status_code == 200 - # Cleanup - client.get(f"/experiment/1/workflows/delete/{workflow_id}") - - -def test_workflow_add_eval_node_and_metadata(client): - create_resp = client.get("/experiment/1/workflows/create_empty", params={"name": "evalnode"}) - workflow_id = create_resp.json() - # Add EVAL node with realistic structure - node = {"name": "hello", "task": "WarmPanda", "type": "EVAL", "metadata": {}, "out": []} - add_node_resp = client.get(f"/experiment/1/workflows/{workflow_id}/add_node", params={"node": json.dumps(node)}) - assert add_node_resp.status_code == 200 - # Get workflow config to find node id - wf_resp = client.get("/experiment/1/workflows/list") - workflow = next(w for w in wf_resp.json() if w["id"] == workflow_id) - config = workflow["config"] - if not isinstance(config, dict): - config = json.loads(config) - eval_node = next(n for n in config["nodes"] if n["type"] == "EVAL") - node_id = eval_node["id"] - # Edit metadata - meta = {"desc": "eval node test"} - edit_resp = client.get( - f"/experiment/1/workflows/{workflow_id}/{node_id}/edit_node_metadata", params={"metadata": json.dumps(meta)} - ) - assert edit_resp.status_code == 200 - # Cleanup - client.get(f"/experiment/1/workflows/delete/{workflow_id}") - - -def test_workflow_update_config(client): - # Create a workflow to update config - create_resp = client.get("/experiment/1/workflows/create_empty", params={"name": "updateconfig"}) - workflow_id = create_resp.json() - - # Test config with a custom structure - new_config = { - "nodes": [ - {"type": "START", "id": "start-123", "name": "START", "out": ["task-456"]}, - {"type": "TASK", "id": "task-456", "name": "Test Task", "task": "test_task", "out": []}, - ] - } - - # Update config using PUT endpoint - resp = client.put(f"/experiment/1/workflows/{workflow_id}/config", json=new_config) - assert resp.status_code == 200 - assert resp.json().get("message") == "OK" - - # Verify the config was updated by fetching the workflow - list_resp = client.get("/experiment/1/workflows/list") - workflow = next(w for w in list_resp.json() if w["id"] == workflow_id) - config = workflow["config"] - if not isinstance(config, dict): - config = json.loads(config) - - # Verify the config matches what we set - assert len(config["nodes"]) == 2 - assert config["nodes"][0]["type"] == "START" - assert config["nodes"][1]["type"] == "TASK" - assert config["nodes"][0]["out"] == ["task-456"] - - # Cleanup - client.get(f"/experiment/1/workflows/delete/{workflow_id}") - - -def test_workflow_task_isolation_success(client): - """Test that workflows can find tasks in their own experiment with correct type.""" - # Create a TRAIN task in experiment 1 - task_data = { - "name": "isolation_test_task", - "type": "TRAIN", - "inputs": '{"model_name": "test_model"}', - "config": '{"learning_rate": 0.001}', - "plugin": "test_plugin", - "outputs": '{"adaptor_name": "test_adaptor"}', - "experiment_id": 1, - } - task_resp = client.put("/tasks/new_task", json=task_data) - assert task_resp.status_code == 200 - - # Verify task exists in experiment 1 - tasks_resp = client.get("/tasks/list_by_type_in_experiment?type=TRAIN&experiment_id=1") - assert tasks_resp.status_code == 200 - tasks = tasks_resp.json() - test_task = next((t for t in tasks if t["name"] == "isolation_test_task"), None) - assert test_task is not None - task_id = test_task["id"] - - # Create a workflow that references this task - 
create_resp = client.get("/experiment/1/workflows/create_empty", params={"name": "isolation_test_workflow"}) - assert create_resp.status_code == 200 - workflow_id = create_resp.json() - - # Add TRAIN node - train_node = {"type": "TRAIN", "name": "Training Node", "task": "isolation_test_task", "out": []} - add_node_resp = client.get( - f"/experiment/1/workflows/{workflow_id}/add_node", params={"node": json.dumps(train_node)} - ) - assert add_node_resp.status_code == 200 - - # Start the workflow - start_resp = client.get(f"/experiment/1/workflows/{workflow_id}/start") - assert start_resp.status_code == 200 - assert start_resp.json().get("message") == "OK" - - # Test the task isolation by triggering workflow execution - from transformerlab.routers.experiment import workflows - import asyncio - - async def trigger_workflow_execution(): - return await workflows.start_next_step_in_workflow() - - # Run the workflow execution step - asyncio.run(trigger_workflow_execution()) - - # Verify workflow was processed successfully - runs_resp = client.get("/experiment/1/workflows/runs") - assert runs_resp.status_code == 200 - runs = runs_resp.json() - test_run = next((r for r in runs if r["workflow_id"] == workflow_id), None) - assert test_run is not None - assert test_run["status"] in ["QUEUED", "RUNNING", "COMPLETE", "FAILED"] - - # Cleanup - client.get(f"/tasks/{task_id}/delete") - client.get(f"/experiment/1/workflows/delete/{workflow_id}") - - -def test_workflow_task_isolation_cross_experiment_failure(client): - """Test that workflows cannot access tasks from other experiments.""" - # Create a task in experiment 2 - task_data = { - "name": "cross_experiment_task", - "type": "TRAIN", - "inputs": '{"model_name": "test_model"}', - "config": '{"learning_rate": 0.001}', - "plugin": "test_plugin", - "outputs": '{"adaptor_name": "test_adaptor"}', - "experiment_id": 2, - } - task_resp = client.put("/tasks/new_task", json=task_data) - assert task_resp.status_code == 200 - - # Verify task exists in experiment 2 - tasks_resp = client.get("/tasks/list_by_type_in_experiment?type=TRAIN&experiment_id=2") - assert tasks_resp.status_code == 200 - tasks = tasks_resp.json() - test_task = next((t for t in tasks if t["name"] == "cross_experiment_task"), None) - assert test_task is not None - task_id = test_task["id"] - - # Create a workflow in experiment 1 that tries to reference the task from experiment 2 - create_resp = client.get("/experiment/1/workflows/create_empty", params={"name": "cross_exp_workflow"}) - assert create_resp.status_code == 200 - workflow_id = create_resp.json() - - # Add a node that references the task from experiment 2 - train_node = {"type": "TRAIN", "name": "Cross Exp Node", "task": "cross_experiment_task", "out": []} - add_node_resp = client.get( - f"/experiment/1/workflows/{workflow_id}/add_node", params={"node": json.dumps(train_node)} - ) - assert add_node_resp.status_code == 200 - - # Start the workflow - start_resp = client.get(f"/experiment/1/workflows/{workflow_id}/start") - assert start_resp.status_code == 200 - - # Trigger workflow execution - from transformerlab.routers.experiment import workflows - import asyncio - - async def trigger_workflow_execution(): - return await workflows.start_next_step_in_workflow() - - asyncio.run(trigger_workflow_execution()) - - # Verify workflow failed because it couldn't find the cross-experiment task - runs_resp = client.get("/experiment/1/workflows/runs") - assert runs_resp.status_code == 200 - runs = runs_resp.json() - test_run = next((r for r in 
runs if r["workflow_id"] == workflow_id), None) - assert test_run is not None - # The workflow should fail because the task is not found in experiment 1 - assert test_run["status"] == "FAILED" - - # Cleanup - client.get(f"/tasks/{task_id}/delete") - client.get(f"/experiment/1/workflows/delete/{workflow_id}") diff --git a/api/test/api/test_workflows_old.py b/api/test/api/test_workflows_old.py deleted file mode 100644 index afee8fc3c..000000000 --- a/api/test/api/test_workflows_old.py +++ /dev/null @@ -1,1462 +0,0 @@ -import json -import pytest -from transformerlab.routers.experiment import workflows as wf - -pytestmark = pytest.mark.skip("This entire test file is currently under development.") - - -@pytest.fixture(scope="module") -def experiment_id(client): - """Create a single experiment for all workflow tests and clean up afterward""" - exp_resp = client.get("/experiment/create?name=test_workflows_experiment") - assert exp_resp.status_code == 200 - exp_id = exp_resp.json() - - yield exp_id - - # Cleanup: delete the experiment after all tests are done - client.get(f"/experiment/delete/{exp_id}") - # Don't assert on the delete response as it might fail if experiment is already gone - - -def test_workflows_list(client, experiment_id): - resp = client.get(f"/experiment/{experiment_id}/workflows/list") - assert resp.status_code == 200 - assert isinstance(resp.json(), list) or isinstance(resp.json(), dict) - - -def test_workflows_delete(client, experiment_id): - # Create a workflow to delete - create_resp = client.get(f"/experiment/{experiment_id}/workflows/create?name=workflow_to_delete") - assert create_resp.status_code == 200 - workflow_id = create_resp.json() - - # Try to delete the workflow - resp = client.get(f"/experiment/{experiment_id}/workflows/delete/{workflow_id}") - assert resp.status_code == 200 - assert resp.json() == {"message": "OK"} - - # Try to delete a non-existent workflow - resp = client.get(f"/experiment/{experiment_id}/workflows/delete/non_existent_workflow") - assert resp.status_code == 200 - assert resp.json() == {"error": "Workflow not found or does not belong to this experiment"} - - -def test_workflows_create(client, experiment_id): - import json - - # Create workflow with required fields - config = {"nodes": [{"type": "START", "id": "start", "name": "START", "out": []}], "status": "CREATED"} - resp = client.get(f"/experiment/{experiment_id}/workflows/create?name=test_workflow&config={json.dumps(config)}") - assert resp.status_code == 200 - assert resp.json() is not None # Just check that we get a valid response - - -def test_experiment_workflows_list(client, experiment_id): - """Test the new experiment workflows list endpoint""" - # Create a workflow in the experiment - workflow_resp = client.get(f"/experiment/{experiment_id}/workflows/create?name=test_workflow") - assert workflow_resp.status_code == 200 - - # Test the new experiment workflows list endpoint - resp = client.get(f"/experiment/{experiment_id}/workflows/list") - assert resp.status_code == 200 - workflows = resp.json() - assert isinstance(workflows, list) - assert len(workflows) > 0 - assert workflows[0]["experiment_id"] == experiment_id - - -def test_experiment_workflow_runs(client, experiment_id): - """Test the new experiment workflow runs endpoint""" - # Create a workflow in the experiment - workflow_resp = client.get(f"/experiment/{experiment_id}/workflows/create?name=test_workflow") - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Queue the workflow to create a 
run - queue_resp = client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/start") - assert queue_resp.status_code == 200 - - # Test the new experiment workflow runs endpoint - resp = client.get(f"/experiment/{experiment_id}/workflows/runs") - assert resp.status_code == 200 - runs = resp.json() - assert isinstance(runs, list) - assert len(runs) > 0 - assert runs[0]["experiment_id"] == experiment_id - assert runs[0]["workflow_id"] == workflow_id - - -def test_workflow_node_operations(client, experiment_id): - """Test node-related operations in a workflow""" - # Create workflow - workflow_resp = client.get(f"/experiment/{experiment_id}/workflows/create?name=test_workflow") - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Add a node - node_data = { - "type": "TASK", - "name": "Test Task", - "task": "test_task", # Required field - "out": [], # Required field - } - add_node_resp = client.get( - f"/experiment/{experiment_id}/workflows/{workflow_id}/add_node?node={json.dumps(node_data)}" - ) - assert add_node_resp.status_code == 200 - - # Get the workflow to find the node ID - workflow_resp = client.get(f"/experiment/{experiment_id}/workflows/list") - assert workflow_resp.status_code == 200 - workflows = workflow_resp.json() - workflow = next(w for w in workflows if w["id"] == workflow_id) - workflow_config = workflow["config"] - if not isinstance(workflow_config, dict): - workflow_config = json.loads(workflow_config) - - nodes = workflow_config["nodes"] - node_id = next(n["id"] for n in nodes if n["type"] == "TASK") - - # Update node metadata - metadata = {"key": "value"} - metadata_resp = client.get( - f"/experiment/{experiment_id}/workflows/{workflow_id}/{node_id}/edit_node_metadata?metadata={json.dumps(metadata)}" - ) - assert metadata_resp.status_code == 200 - - # Update node - new_node = {"id": node_id, "type": "TASK", "name": "Updated Task", "task": "test_task", "out": []} - update_resp = client.post( - f"/experiment/{experiment_id}/workflows/{workflow_id}/{node_id}/update_node", json=new_node - ) - assert update_resp.status_code == 200 - - # Add edge - edge_resp = client.post(f"/experiment/{experiment_id}/workflows/{workflow_id}/START/add_edge?end_node_id={node_id}") - assert edge_resp.status_code == 200 - - # Remove edge - remove_edge_resp = client.post( - f"/experiment/{experiment_id}/workflows/{workflow_id}/START/remove_edge?end_node_id={node_id}" - ) - assert remove_edge_resp.status_code == 200 - - # Delete node - delete_node_resp = client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/{node_id}/delete_node") - assert delete_node_resp.status_code == 200 - - -def test_workflow_name_update(client, experiment_id): - """Test updating a workflow's name""" - # Create workflow - workflow_resp = client.get(f"/experiment/{experiment_id}/workflows/create?name=old_name") - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Update name - update_resp = client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/update_name?new_name=new_name") - assert update_resp.status_code == 200 - assert update_resp.json() == {"message": "OK"} - - -def test_workflow_yaml_operations(client, experiment_id): - """Test YAML import/export operations""" - # Create workflow with required fields - config = {"nodes": [{"type": "START", "id": "start", "name": "START", "out": []}], "status": "CREATED"} - workflow_resp = client.get( - f"/experiment/{experiment_id}/workflows/create?name=test_workflow&config={json.dumps(config)}" - ) 
- assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Queue the workflow to create a workflow run with the required fields - queue_resp = client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/start") - assert queue_resp.status_code == 200 - - # Export to YAML - using the experiment-scoped path - export_resp = client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/export_to_yaml") - assert export_resp.status_code == 200 - # Check that we get a file response with the correct filename - # assert export_resp.headers.get("content-type") == "text/plain; charset=utf-8" - assert export_resp.headers.get("content-disposition") == 'attachment; filename="test_workflow.yaml"' - - -def test_workflow_run_operations(client, experiment_id): - """Test workflow run operations""" - # Create workflow - workflow_resp = client.get(f"/experiment/{experiment_id}/workflows/create?name=test_workflow") - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Start workflow - start_resp = client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/start") - assert start_resp.status_code == 200 - - # Get workflow runs - runs_resp = client.get(f"/experiment/{experiment_id}/workflows/runs") - assert runs_resp.status_code == 200 - runs = runs_resp.json() - assert isinstance(runs, list) - assert len(runs) > 0 - run_id = runs[0]["id"] - - # Get specific run - run_resp = client.get(f"/experiment/{experiment_id}/workflows/runs/{run_id}") - assert run_resp.status_code == 200 - run_data = run_resp.json() - assert "run" in run_data - assert "workflow" in run_data - assert "jobs" in run_data - - -def test_workflow_next_step(client, experiment_id): - """Test workflow progression through complete workflow execution""" - # Create workflow with simple configuration that can complete quickly - config = { - "nodes": [ - {"type": "START", "id": "start", "name": "START", "out": []}, - ] - } - workflow_resp = client.get( - f"/experiment/{experiment_id}/workflows/create?name=test_workflow&config={json.dumps(config)}" - ) - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Start workflow to create a queued run - start_resp = client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/start") - assert start_resp.status_code == 200 - - # Verify workflow run was created - runs_resp = client.get(f"/experiment/{experiment_id}/workflows/runs") - assert runs_resp.status_code == 200 - runs = runs_resp.json() - assert len(runs) > 0 - assert runs[0]["status"] in ["RUNNING", "QUEUED"] - - # Test workflow run retrieval - run_id = runs[0]["id"] - run_resp = client.get(f"/experiment/{experiment_id}/workflows/runs/{run_id}") - assert run_resp.status_code == 200 - run_data = run_resp.json() - assert "run" in run_data - assert "workflow" in run_data - assert "jobs" in run_data - - -def test_workflow_create_invalid(client, experiment_id): - """Test workflow creation with invalid config""" - # Test workflow creation without config (should still work) - resp = client.get(f"/experiment/{experiment_id}/workflows/create?name=test_workflow_no_config") - assert resp.status_code == 200 - # Just verify we get some response - assert resp.json() is not None - - -def test_workflow_run_cancel(client, experiment_id): - """Test workflow run cancellation""" - # Create workflow - workflow_resp = client.get(f"/experiment/{experiment_id}/workflows/create?name=test_workflow") - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - 
# Start workflow to create a run - start_resp = client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/start") - assert start_resp.status_code == 200 - - # Get the workflow run - runs_resp = client.get(f"/experiment/{experiment_id}/workflows/runs") - assert runs_resp.status_code == 200 - runs = runs_resp.json() - assert len(runs) > 0 - run_id = runs[0]["id"] - - # Test successful cancellation - cancel_resp = client.get(f"/experiment/{experiment_id}/workflows/{run_id}/cancel") - assert cancel_resp.status_code == 200 - cancel_data = cancel_resp.json() - assert "message" in cancel_data - assert f"Workflow run {run_id} cancellation initiated" in cancel_data["message"] - assert "cancelled_jobs" in cancel_data - assert "note" in cancel_data - - -def test_workflow_run_cancel_with_active_jobs(client, experiment_id): - """Test workflow run cancellation with actual running jobs""" - - # Create workflow - workflow_resp = client.get(f"/experiment/{experiment_id}/workflows/create?name=test_workflow_with_jobs") - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Create a job via API (more realistic) - job_resp = client.get(f"/jobs/create?type=TRAIN&status=RUNNING&experiment_id={experiment_id}") - assert job_resp.status_code == 200 - job_id = job_resp.json() - - # Start workflow to create a run - start_resp = client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/start") - assert start_resp.status_code == 200 - - # Get the workflow run - runs_resp = client.get(f"/experiment/{experiment_id}/workflows/runs") - assert runs_resp.status_code == 200 - runs = runs_resp.json() - assert len(runs) > 0 - run_id = runs[0]["id"] - - # Manually add job to workflow run to simulate active job - # This simulates what happens when a workflow step is running - import asyncio - from transformerlab.db import db - - async def add_job_to_run(): - await db.workflow_run_update_with_new_job(run_id, f'["{job_id}"]', f"[{job_id}]") - - asyncio.run(add_job_to_run()) - - # Test the cancellation via API - cancel_resp = client.get(f"/experiment/{experiment_id}/workflows/{run_id}/cancel") - assert cancel_resp.status_code == 200 - - response_data = cancel_resp.json() - assert "cancelled_jobs" in response_data - assert job_id in response_data["cancelled_jobs"] - assert len(response_data["cancelled_jobs"]) == 1 - - # Verify job was actually stopped by checking via API - job_resp = client.get(f"/jobs/{job_id}") - assert job_resp.status_code == 200 - job_data = job_resp.json() - assert job_data["job_data"]["stop"] - - -def test_workflow_run_cancel_invalid_cases(client, experiment_id): - """Test workflow run cancellation with invalid cases""" - # Test cancelling non-existent workflow run - cancel_resp = client.get(f"/experiment/{experiment_id}/workflows/non_existent_run/cancel") - assert cancel_resp.status_code == 200 - assert cancel_resp.json() == {"error": "Workflow run not found"} - - # Create workflow and run for testing status checks - workflow_resp = client.get(f"/experiment/{experiment_id}/workflows/create?name=test_workflow") - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Start workflow to create a run - start_resp = client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/start") - assert start_resp.status_code == 200 - - # Get the workflow run - runs_resp = client.get(f"/experiment/{experiment_id}/workflows/runs") - assert runs_resp.status_code == 200 - runs = runs_resp.json() - assert len(runs) > 0 - run_id = runs[0]["id"] - - # First 
cancellation should succeed - cancel_resp = client.get(f"/experiment/{experiment_id}/workflows/{run_id}/cancel") - assert cancel_resp.status_code == 200 - - -def test_workflow_run_cancel_security(client, experiment_id): - """Test workflow run cancellation security checks across experiments""" - # Create a second experiment for security testing - exp2_resp = client.get("/experiment/create?name=test_workflow_cancel_security_exp2") - assert exp2_resp.status_code == 200 - exp2_id = exp2_resp.json() - - try: - # Create workflow in original experiment - workflow_resp = client.get(f"/experiment/{experiment_id}/workflows/create?name=test_workflow") - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Start workflow to create a run - start_resp = client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/start") - assert start_resp.status_code == 200 - - # Get the workflow run from original experiment - runs_resp = client.get(f"/experiment/{experiment_id}/workflows/runs") - assert runs_resp.status_code == 200 - runs = runs_resp.json() - assert len(runs) > 0 - run_id = runs[0]["id"] - - # Try to cancel the workflow run from experiment 2 (should fail) - cancel_resp = client.get(f"/experiment/{exp2_id}/workflows/{run_id}/cancel") - assert cancel_resp.status_code == 200 - assert cancel_resp.json() == {"error": "Associated workflow not found or does not belong to this experiment"} - - # Verify cancellation works from the correct experiment - cancel_resp = client.get(f"/experiment/{experiment_id}/workflows/{run_id}/cancel") - assert cancel_resp.status_code == 200 - cancel_data = cancel_resp.json() - assert "message" in cancel_data - assert f"Workflow run {run_id} cancellation initiated" in cancel_data["message"] - finally: - # Cleanup the second experiment - client.get(f"/experiment/delete/{exp2_id}") - - -def test_workflow_run_cancel_edge_cases(client, experiment_id): - """Test workflow run cancellation edge cases""" - # Create workflow with complex configuration - config = { - "nodes": [ - {"type": "START", "id": "start", "name": "START", "out": ["task1"]}, - {"type": "TASK", "id": "task1", "name": "Task 1", "task": "test_task", "out": ["task2"]}, - {"type": "TASK", "id": "task2", "name": "Task 2", "task": "test_task", "out": []}, - ] - } - workflow_resp = client.get( - f"/experiment/{experiment_id}/workflows/create?name=test_complex_workflow&config={json.dumps(config)}" - ) - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Start workflow - start_resp = client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/start") - assert start_resp.status_code == 200 - - # Get workflow run - runs_resp = client.get(f"/experiment/{experiment_id}/workflows/runs") - assert runs_resp.status_code == 200 - runs = runs_resp.json() - assert len(runs) > 0 - run_id = runs[0]["id"] - - # Cancel the workflow run - cancel_resp = client.get(f"/experiment/{experiment_id}/workflows/{run_id}/cancel") - assert cancel_resp.status_code == 200 - cancel_data = cancel_resp.json() - - # Verify response structure - assert isinstance(cancel_data.get("cancelled_jobs"), list) - assert cancel_data.get("note") == "Workflow status will be updated to CANCELLED automatically" - - -def test_workflow_node_operations_invalid(client, experiment_id): - """Test node operations with invalid node IDs""" - # Use shared experiment instead of creating new one - workflow_resp = client.get(f"/experiment/{experiment_id}/workflows/create?name=test_workflow") - assert 
workflow_resp.status_code == 200 - - # Just test that the endpoint exists and doesn't crash - no complex operations - resp = client.get(f"/experiment/{experiment_id}/workflows/list") - assert resp.status_code == 200 - - -def test_workflow_edge_operations_invalid(client, experiment_id): - """Test edge operations with invalid node IDs""" - # Use shared experiment instead of creating new one - config = {"nodes": [{"type": "START", "id": "start", "name": "START", "out": []}], "status": "CREATED"} - workflow_resp = client.get( - f"/experiment/{experiment_id}/workflows/create?name=test_workflow&config={json.dumps(config)}" - ) - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Only test operations that are guaranteed to work - # Just verify the endpoints exist and don't crash - resp = client.post(f"/experiment/{experiment_id}/workflows/{workflow_id}/start/add_edge?end_node_id=non_existent") - assert resp.status_code == 200 - - -def test_workflow_run_operations_invalid(client, experiment_id): - """Test workflow run operations with invalid run IDs""" - # Try to get non-existent run using shared experiment - resp = client.get(f"/experiment/{experiment_id}/workflows/runs/non_existent_run") - assert resp.status_code == 200 - assert resp.json() == {"error": "Workflow run not found"} - - -def test_workflow_name_update_invalid(client, experiment_id): - """Test invalid workflow name updates""" - # Use shared experiment instead of creating new one - config = {"nodes": [{"type": "START", "id": "start", "name": "START", "out": []}], "status": "CREATED"} - workflow_resp = client.get( - f"/experiment/{experiment_id}/workflows/create?name=test_workflow&config={json.dumps(config)}" - ) - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - assert workflow_id is not None - - # Just test that the endpoint exists and doesn't crash - resp = client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/update_name?new_name=new_name") - assert resp.status_code == 200 - - -def test_find_nodes_by_ids_helper(client): - # Use shared experiment instead of creating new one - workflow_resp = client.get(f"/experiment/{experiment_id}/workflows/create?name=test_workflow") - assert workflow_resp.status_code == 200 - - # Just test that the endpoint exists and doesn't crash - resp = client.get(f"/experiment/{experiment_id}/workflows/list") - assert resp.status_code == 200 - - -@pytest.mark.asyncio -async def test_determine_next_and_start_skip_helpers(client): - """Test determine_next_tasks and handle_start_node_skip helpers with various scenarios""" - - # Test with empty workflow config - empty_config = {"nodes": []} - result = await wf.determine_next_tasks([], empty_config, 0) - assert result == [] - - # Test with current tasks that have multiple outputs - workflow_config = { - "nodes": [ - {"id": "task1", "type": "TASK", "out": ["task2", "task3"]}, - {"id": "task2", "type": "TASK", "out": []}, - {"id": "task3", "type": "TASK", "out": []}, - ] - } - - result = await wf.determine_next_tasks(["task1"], workflow_config, 0) - assert set(result) == {"task2", "task3"} # Should get both outputs - - # Test handle_start_node_skip with multiple START nodes - workflow_config = { - "nodes": [ - {"id": "start1", "type": "START", "out": ["task1"]}, - {"id": "start2", "type": "START", "out": ["task2"]}, - {"id": "task1", "type": "TASK", "out": []}, - {"id": "task2", "type": "TASK", "out": []}, - ] - } - - actual_ids, next_nodes = await wf.handle_start_node_skip(["start1", 
"start2"], workflow_config, 0) - assert set(actual_ids) == {"task1", "task2"} - assert len(next_nodes) == 2 - - -def test_extract_previous_job_outputs_and_prepare_io(): - """Test extract_previous_job_outputs and prepare_next_task_io with comprehensive scenarios""" - - # Test GENERATE job with dataset_id at top level - generate_job_top_level = { - "type": "GENERATE", - "job_data": {"dataset_id": "Top Level Dataset", "config": {"dataset_id": "Config Level Dataset"}}, - } - outputs = wf.extract_previous_job_outputs(generate_job_top_level) - # Should prefer top-level dataset_id - assert outputs["dataset_name"] == "top-level-dataset" - - # Test TRAIN job with only model_name (no adaptor_name) - train_job_model_only = {"type": "TRAIN", "job_data": {"config": {"model_name": "test-model"}}} - outputs = wf.extract_previous_job_outputs(train_job_model_only) - assert outputs["model_name"] == "test-model" - assert "adaptor_name" not in outputs - - # Test TRAIN job with adaptor but no fuse_model - train_job_adaptor_no_fuse = { - "type": "TRAIN", - "job_data": {"config": {"model_name": "test-model", "adaptor_name": "test-adaptor"}}, - } - outputs = wf.extract_previous_job_outputs(train_job_adaptor_no_fuse) - assert outputs["adaptor_name"] == "test-adaptor" - - # Test TRAIN task with existing inputs and outputs - task_def_train = { - "type": "TRAIN", - "inputs": '{"existing_input": "value", "model_name": "old_model"}', - "outputs": '{"existing_output": "result"}', - } - previous_outputs = {"model_name": "new_model", "dataset_name": "test_dataset"} - - inputs_json, outputs_json = wf.prepare_next_task_io(task_def_train, previous_outputs) - inputs = json.loads(inputs_json) - outputs = json.loads(outputs_json) - - # Should override model_name but keep existing fields - assert inputs["model_name"] == "new_model" - assert inputs["dataset_name"] == "test_dataset" - assert inputs["existing_input"] == "value" - - # Should add adaptor_name and keep existing outputs - assert "adaptor_name" in outputs - assert outputs["existing_output"] == "result" - - # Test EVAL task with partial previous outputs - task_def_eval = {"type": "EVAL", "inputs": "{}", "outputs": "{}"} - partial_outputs = { - "model_name": "test_model" - # Missing other fields - } - - inputs_json, outputs_json = wf.prepare_next_task_io(task_def_eval, partial_outputs) - inputs = json.loads(inputs_json) - - # Should only include the fields that exist in previous_outputs - assert inputs["model_name"] == "test_model" - assert "model_architecture" not in inputs - assert "adaptor_name" not in inputs - assert "dataset_name" not in inputs - - -def test_workflow_security_checks(client): - """Test security checks for workflow operations across different experiments""" - # Create two separate experiments for security testing - exp1_resp = client.get("/experiment/create?name=test_workflow_security_exp1") - assert exp1_resp.status_code == 200 - exp1_id = exp1_resp.json() - - exp2_resp = client.get("/experiment/create?name=test_workflow_security_exp2") - assert exp2_resp.status_code == 200 - exp2_id = exp2_resp.json() - - try: - # Create a workflow in experiment 1 - workflow_resp = client.get(f"/experiment/{exp1_id}/workflows/create?name=test_workflow") - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Try to delete workflow from experiment 1 using experiment 2's context - delete_resp = client.get(f"/experiment/{exp2_id}/workflows/delete/{workflow_id}") - assert delete_resp.status_code == 200 - assert delete_resp.json() == {"error": 
"Workflow not found or does not belong to this experiment"} - - # Try to edit node metadata from wrong experiment - metadata_resp = client.get( - f"/experiment/{exp2_id}/workflows/{workflow_id}/node_id/edit_node_metadata?metadata={{}}" - ) - assert metadata_resp.status_code == 200 - assert metadata_resp.json() == {"error": "Workflow not found or does not belong to this experiment"} - - # Try to update name from wrong experiment - name_resp = client.get(f"/experiment/{exp2_id}/workflows/{workflow_id}/update_name?new_name=new_name") - assert name_resp.status_code == 200 - assert name_resp.json() == {"error": "Workflow not found or does not belong to this experiment"} - - # Try to add node from wrong experiment - node_data = {"type": "TASK", "name": "Test Task", "task": "test_task", "out": []} - add_node_resp = client.get( - f"/experiment/{exp2_id}/workflows/{workflow_id}/add_node?node={json.dumps(node_data)}" - ) - assert add_node_resp.status_code == 200 - assert add_node_resp.json() == {"error": "Workflow not found or does not belong to this experiment"} - - # Try to update node from wrong experiment - new_node = {"id": "test", "type": "TASK", "name": "Updated Task", "task": "test_task", "out": []} - update_resp = client.post(f"/experiment/{exp2_id}/workflows/{workflow_id}/test/update_node", json=new_node) - assert update_resp.status_code == 200 - assert update_resp.json() == {"error": "Workflow not found or does not belong to this experiment"} - - # Try to remove edge from wrong experiment - remove_edge_resp = client.post( - f"/experiment/{exp2_id}/workflows/{workflow_id}/start/remove_edge?end_node_id=test" - ) - assert remove_edge_resp.status_code == 200 - assert remove_edge_resp.json() == {"error": "Workflow not found or does not belong to this experiment"} - - # Try to add edge from wrong experiment - add_edge_resp = client.post(f"/experiment/{exp2_id}/workflows/{workflow_id}/start/add_edge?end_node_id=test") - assert add_edge_resp.status_code == 200 - assert add_edge_resp.json() == {"error": "Workflow not found or does not belong to this experiment"} - - # Try to delete node from wrong experiment - delete_node_resp = client.get(f"/experiment/{exp2_id}/workflows/{workflow_id}/test/delete_node") - assert delete_node_resp.status_code == 200 - assert delete_node_resp.json() == {"error": "Workflow not found or does not belong to this experiment"} - - # Try to export YAML from wrong experiment - export_resp = client.get(f"/experiment/{exp2_id}/workflows/{workflow_id}/export_to_yaml") - assert export_resp.status_code == 200 - assert export_resp.json() == {"error": "Workflow not found or does not belong to this experiment"} - - # Try to start workflow from wrong experiment - start_resp = client.get(f"/experiment/{exp2_id}/workflows/{workflow_id}/start") - assert start_resp.status_code == 200 - assert start_resp.json() == {"error": "Workflow not found or does not belong to this experiment"} - finally: - # Cleanup both experiments - client.get(f"/experiment/delete/{exp1_id}") - client.get(f"/experiment/delete/{exp2_id}") - - -def test_workflow_start_node_deletion(client, experiment_id): - """Test that START nodes cannot be deleted""" - exp_id = experiment_id - - # Create workflow - workflow_resp = client.get(f"/experiment/{exp_id}/workflows/create?name=test_workflow123") - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Get the workflow to find the START node ID - workflow_resp = client.get(f"/experiment/{exp_id}/workflows/list") - assert 
workflow_resp.status_code == 200 - workflows = workflow_resp.json() - workflow = next(w for w in workflows if w["id"] == workflow_id) - nodes = json.loads(workflow["config"])["nodes"] - start_node_id = next(n["id"] for n in nodes if n["type"] == "START") - - # Try to delete the START node - delete_node_resp = client.get(f"/experiment/{exp_id}/workflows/{workflow_id}/{start_node_id}/delete_node") - assert delete_node_resp.status_code == 200 - assert delete_node_resp.json() == {"message": "Cannot delete START node"} - - # Cleanup: delete the workflow - client.get(f"/experiment/{exp_id}/workflows/delete/{workflow_id}") - - -def test_workflow_no_active_workflow(client, experiment_id): - """Test workflow system when no workflow is active""" - exp_id = experiment_id - - # Create workflow but don't start it - workflow_resp = client.get(f"/experiment/{exp_id}/workflows/create?name=test_workflow") - assert workflow_resp.status_code == 200 - - # Verify no workflow runs exist initially - runs_resp = client.get(f"/experiment/{exp_id}/workflows/runs") - assert runs_resp.status_code == 200 - runs = runs_resp.json() - assert len(runs) == 0 - - # Try to get a non-existent run - fake_run_resp = client.get(f"/experiment/{exp_id}/workflows/runs/fake_run_id") - assert fake_run_resp.status_code == 200 - assert fake_run_resp.json() == {"error": "Workflow run not found"} - - -def test_workflow_run_with_missing_associated_workflow(client, experiment_id): - """Test workflow run when associated workflow is missing (line 308)""" - exp_id = experiment_id - - # Create workflow and start it to create a run - workflow_resp = client.get(f"/experiment/{exp_id}/workflows/create?name=test_workflow") - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - start_resp = client.get(f"/experiment/{exp_id}/workflows/{workflow_id}/start") - assert start_resp.status_code == 200 - - # Get the run ID - runs_resp = client.get(f"/experiment/{exp_id}/workflows/runs") - assert runs_resp.status_code == 200 - runs = runs_resp.json() - run_id = runs[0]["id"] - - # Delete the workflow to make it "missing" - delete_resp = client.get(f"/experiment/{exp_id}/workflows/delete/{workflow_id}") - assert delete_resp.status_code == 200 - - # Try to get the run - may return either "Associated workflow not found" or run data - run_resp = client.get(f"/experiment/{exp_id}/workflows/runs/{run_id}") - assert run_resp.status_code == 200 - response_data = run_resp.json() - # Accept either error response or normal response with data - assert ("error" in response_data and "Associated workflow not found" in response_data["error"]) or ( - "run" in response_data and "workflow" in response_data - ) - - -def test_yaml_import(client, experiment_id): - """Test YAML import functionality""" - exp_id = experiment_id - - # Create a test YAML file content - import tempfile - import os - - # Create a temporary YAML file - yaml_content = """ -name: test_imported_workflow -config: - nodes: - - type: START - id: start - name: START - out: [] -""" - - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - f.write(yaml_content) - f.flush() - - # Test YAML import - with open(f.name, "rb") as yaml_file: - files = {"file": (f.name, yaml_file, "application/x-yaml")} - import_resp = client.post(f"/experiment/{exp_id}/workflows/import_from_yaml", files=files) - assert import_resp.status_code == 200 - assert import_resp.json() == {"message": "OK"} - - # Clean up - os.unlink(f.name) - - -@pytest.mark.asyncio -async def 
test_extract_previous_job_outputs_edge_cases(): - """Test extract_previous_job_outputs with various job status scenarios""" - - # Test with empty job_ids (should return None) - result = await wf.check_current_jobs_status("workflow_run_id", []) - assert result is None - - # Test logic for different status values - test_cases = [ - {"status": "FAILED", "expected_contains": "failed"}, - {"status": "CANCELLED", "expected_contains": "cancelled"}, - {"status": "DELETED", "expected_contains": "cancelled"}, - {"status": "STOPPED", "expected_contains": "cancelled"}, - {"status": "RUNNING", "expected_contains": "running"}, - {"status": "QUEUED", "expected_contains": "running"}, - {"status": "COMPLETE", "expected": None}, - ] - - for case in test_cases: - status = case["status"] - # We can't test the actual database calls, but we can verify the logic paths exist - # The function would check these statuses and return appropriate messages - if status == "FAILED": - assert "failed" in case["expected_contains"] - elif status in ["CANCELLED", "DELETED", "STOPPED"]: - assert "cancelled" in case["expected_contains"] - elif status != "COMPLETE": - assert "running" in case["expected_contains"] - - -@pytest.mark.asyncio -async def test_prepare_next_task_io_edge_cases(): - """Test prepare_next_task_io with all branches""" - - # Test TRAIN task with existing inputs and outputs - task_def_train = { - "type": "TRAIN", - "inputs": '{"existing_input": "value", "model_name": "old_model"}', - "outputs": '{"existing_output": "result"}', - } - previous_outputs = {"model_name": "new_model", "dataset_name": "test_dataset"} - - inputs_json, outputs_json = wf.prepare_next_task_io(task_def_train, previous_outputs) - inputs = json.loads(inputs_json) - outputs = json.loads(outputs_json) - - # Should override model_name but keep existing fields - assert inputs["model_name"] == "new_model" - assert inputs["dataset_name"] == "test_dataset" - assert inputs["existing_input"] == "value" - - # Should add adaptor_name and keep existing outputs - assert "adaptor_name" in outputs - assert outputs["existing_output"] == "result" - - # Test EVAL task with partial previous outputs - task_def_eval = {"type": "EVAL", "inputs": "{}", "outputs": "{}"} - partial_outputs = { - "model_name": "test_model" - # Missing other fields - } - - inputs_json, outputs_json = wf.prepare_next_task_io(task_def_eval, partial_outputs) - inputs = json.loads(inputs_json) - - # Should only include the fields that exist in previous_outputs - assert inputs["model_name"] == "test_model" - assert "model_architecture" not in inputs - assert "adaptor_name" not in inputs - assert "dataset_name" not in inputs - - -@pytest.mark.asyncio -async def test_handle_start_node_skip_edge_cases(): - """Test handle_start_node_skip with various scenarios""" - - # Test with empty workflow config - empty_config = {"nodes": []} - result = await wf.handle_start_node_skip([], empty_config, 0) - assert result == [] - - # Test with current tasks that have multiple outputs - workflow_config = { - "nodes": [ - {"id": "task1", "type": "TASK", "out": ["task2", "task3"]}, - {"id": "task2", "type": "TASK", "out": []}, - {"id": "task3", "type": "TASK", "out": []}, - ] - } - - result = await wf.determine_next_tasks(["task1"], workflow_config, 0) - assert set(result) == {"task2", "task3"} # Should get both outputs - - # Test handle_start_node_skip with multiple START nodes - workflow_config = { - "nodes": [ - {"id": "start1", "type": "START", "out": ["task1"]}, - {"id": "start2", "type": "START", 
"out": ["task2"]}, - {"id": "task1", "type": "TASK", "out": []}, - {"id": "task2", "type": "TASK", "out": []}, - ] - } - - actual_ids, next_nodes = await wf.handle_start_node_skip(["start1", "start2"], workflow_config, 0) - assert set(actual_ids) == {"task1", "task2"} - assert len(next_nodes) == 2 - - -def test_find_previous_node_and_job_logic(client): - """Test find_previous_node and queue_job_for_node logic""" - # Create workflow - workflow_resp = client.get(f"/experiment/{experiment_id}/workflows/create?name=test_workflow") - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Get the workflow to find the node ID - workflow_resp = client.get(f"/experiment/{experiment_id}/workflows/list") - assert workflow_resp.status_code == 200 - workflows = workflow_resp.json() - workflow = next(w for w in workflows if w["id"] == workflow_id) - nodes = json.loads(workflow["config"])["nodes"] - task_node_id = next(n["id"] for n in nodes if n["type"] == "TASK") - - # Test find_previous_node - prev_node = wf.find_previous_node(task_node_id, workflow) - assert prev_node is None # No previous node for single-task workflow - - # Test queue_job_for_node logic - job_id = wf.queue_job_for_node(task_node_id, workflow_id) - assert job_id is not None - - -def test_workflow_active_run_security(client): - """Test workflow execution security and isolation across experiments""" - # Create two experiments for this specific test - exp1_resp = client.get("/experiment/create?name=test_active_run_security1") - assert exp1_resp.status_code == 200 - exp1_id = exp1_resp.json() - - exp2_resp = client.get("/experiment/create?name=test_active_run_security2") - assert exp2_resp.status_code == 200 - exp2_id = exp2_resp.json() - - try: - # Create workflow in experiment 1 - config = { - "nodes": [ - {"type": "START", "id": "start", "name": "START", "out": []}, - ] - } - workflow_resp = client.get( - f"/experiment/{exp1_id}/workflows/create?name=test_workflow&config={json.dumps(config)}" - ) - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Start workflow in experiment 1 - start_resp = client.get(f"/experiment/{exp1_id}/workflows/{workflow_id}/start") - assert start_resp.status_code == 200 - - # Verify workflow runs are properly isolated per experiment - runs1_resp = client.get(f"/experiment/{exp1_id}/workflows/runs") - runs2_resp = client.get(f"/experiment/{exp2_id}/workflows/runs") - - assert runs1_resp.status_code == 200 - assert runs2_resp.status_code == 200 - - runs1 = runs1_resp.json() - runs2 = runs2_resp.json() - - # Experiment 1 should have runs, experiment 2 should have none - assert len(runs1) > 0 - assert len(runs2) == 0 - - # Test cross-experiment access - try to access exp1's run from exp2's context - run_id = runs1[0]["id"] - cross_run_resp = client.get(f"/experiment/{exp2_id}/workflows/runs/{run_id}") - assert cross_run_resp.status_code == 200 - # Should get error because workflow doesn't belong to experiment 2 - assert cross_run_resp.json() == {"error": "Associated workflow not found or does not belong to this experiment"} - finally: - # Cleanup - client.get(f"/experiment/delete/{exp1_id}") - client.get(f"/experiment/delete/{exp2_id}") - - -def test_workflow_run_security_checks(client): - """Test security checks for workflow run operations""" - # Create two experiments for this specific test - exp1_resp = client.get("/experiment/create?name=test_workflow_run_security1") - assert exp1_resp.status_code == 200 - exp1_id = exp1_resp.json() - - exp2_resp = 
client.get("/experiment/create?name=test_workflow_run_security2") - assert exp2_resp.status_code == 200 - exp2_id = exp2_resp.json() - - try: - # Create workflow in experiment 1 - workflow_resp = client.get(f"/experiment/{exp1_id}/workflows/create?name=test_workflow") - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Start workflow to create a run - start_resp = client.get(f"/experiment/{exp1_id}/workflows/{workflow_id}/start") - assert start_resp.status_code == 200 - - # Get runs from experiment 1 - runs_resp = client.get(f"/experiment/{exp1_id}/workflows/runs") - assert runs_resp.status_code == 200 - runs = runs_resp.json() - assert len(runs) > 0 - run_id = runs[0]["id"] - - # Try to access run from experiment 2 (should fail security check) - run_resp = client.get(f"/experiment/{exp2_id}/workflows/runs/{run_id}") - assert run_resp.status_code == 200 - response_data = run_resp.json() - assert response_data == {"error": "Associated workflow not found or does not belong to this experiment"} - finally: - # Cleanup - client.get(f"/experiment/delete/{exp1_id}") - client.get(f"/experiment/delete/{exp2_id}") - - -@pytest.mark.asyncio -async def test_check_current_jobs_status_edge_cases(): - """Test check_current_jobs_status with various job status scenarios""" - - # Test with empty job_ids (should return None) - result = await wf.check_current_jobs_status("workflow_run_id", []) - assert result is None - - # Test logic for different status values - test_cases = [ - {"status": "FAILED", "expected_contains": "failed"}, - {"status": "CANCELLED", "expected_contains": "cancelled"}, - {"status": "DELETED", "expected_contains": "cancelled"}, - {"status": "STOPPED", "expected_contains": "cancelled"}, - {"status": "RUNNING", "expected_contains": "running"}, - {"status": "QUEUED", "expected_contains": "running"}, - {"status": "COMPLETE", "expected": None}, - ] - - for case in test_cases: - status = case["status"] - # We can't test the actual database calls, but we can verify the logic paths exist - # The function would check these statuses and return appropriate messages - if status == "FAILED": - assert "failed" in case["expected_contains"] - elif status in ["CANCELLED", "DELETED", "STOPPED"]: - assert "cancelled" in case["expected_contains"] - elif status != "COMPLETE": - assert "running" in case["expected_contains"] - - -@pytest.mark.asyncio -async def test_determine_next_tasks_edge_cases(): - """Test determine_next_tasks with edge cases""" - - # Test with empty workflow config - empty_config = {"nodes": []} - result = await wf.determine_next_tasks([], empty_config, 0) - assert result == [] - - # Test with current tasks that have multiple outputs - workflow_config = { - "nodes": [ - {"id": "task1", "type": "TASK", "out": ["task2", "task3"]}, - {"id": "task2", "type": "TASK", "out": []}, - {"id": "task3", "type": "TASK", "out": []}, - ] - } - - result = await wf.determine_next_tasks(["task1"], workflow_config, 0) - assert set(result) == {"task2", "task3"} # Should get both outputs - - # Test handle_start_node_skip with multiple START nodes - workflow_config = { - "nodes": [ - {"id": "start1", "type": "START", "out": ["task1"]}, - {"id": "start2", "type": "START", "out": ["task2"]}, - {"id": "task1", "type": "TASK", "out": []}, - {"id": "task2", "type": "TASK", "out": []}, - ] - } - - actual_ids, next_nodes = await wf.handle_start_node_skip(["start1", "start2"], workflow_config, 0) - assert set(actual_ids) == {"task1", "task2"} - assert len(next_nodes) == 2 - - -def 
test_extract_previous_job_outputs_complete_coverage(): - """Test extract_previous_job_outputs with comprehensive scenarios""" - - # Test GENERATE job with dataset_id at top level - generate_job_top_level = { - "type": "GENERATE", - "job_data": {"dataset_id": "Top Level Dataset", "config": {"dataset_id": "Config Level Dataset"}}, - } - outputs = wf.extract_previous_job_outputs(generate_job_top_level) - # Should prefer top-level dataset_id - assert outputs["dataset_name"] == "top-level-dataset" - - # Test TRAIN job with only model_name (no adaptor_name) - train_job_model_only = {"type": "TRAIN", "job_data": {"config": {"model_name": "test-model"}}} - outputs = wf.extract_previous_job_outputs(train_job_model_only) - assert outputs["model_name"] == "test-model" - assert "adaptor_name" not in outputs - - # Test TRAIN job with adaptor but no fuse_model - train_job_adaptor_no_fuse = { - "type": "TRAIN", - "job_data": {"config": {"model_name": "test-model", "adaptor_name": "test-adaptor"}}, - } - outputs = wf.extract_previous_job_outputs(train_job_adaptor_no_fuse) - assert outputs["adaptor_name"] == "test-adaptor" - - -def test_prepare_next_task_io_complete_coverage(): - """Test prepare_next_task_io with all branches""" - - # Test TRAIN task with existing inputs and outputs - task_def_train = { - "type": "TRAIN", - "inputs": '{"existing_input": "value", "model_name": "old_model"}', - "outputs": '{"existing_output": "result"}', - } - previous_outputs = {"model_name": "new_model", "dataset_name": "test_dataset"} - - inputs_json, outputs_json = wf.prepare_next_task_io(task_def_train, previous_outputs) - inputs = json.loads(inputs_json) - outputs = json.loads(outputs_json) - - # Should override model_name but keep existing fields - assert inputs["model_name"] == "new_model" - assert inputs["dataset_name"] == "test_dataset" - assert inputs["existing_input"] == "value" - - # Should add adaptor_name and keep existing outputs - assert "adaptor_name" in outputs - assert outputs["existing_output"] == "result" - - # Test EVAL task with partial previous outputs - task_def_eval = {"type": "EVAL", "inputs": "{}", "outputs": "{}"} - partial_outputs = { - "model_name": "test_model" - # Missing other fields - } - - inputs_json, outputs_json = wf.prepare_next_task_io(task_def_eval, partial_outputs) - inputs = json.loads(inputs_json) - - # Should only include the fields that exist in previous_outputs - assert inputs["model_name"] == "test_model" - assert "model_architecture" not in inputs - assert "adaptor_name" not in inputs - assert "dataset_name" not in inputs - - -@pytest.mark.asyncio -async def test_handle_start_node_skip_multiple_starts(): - """Test handle_start_node_skip with multiple START nodes""" - - workflow_config = { - "nodes": [ - {"id": "start1", "type": "START", "out": ["task1"]}, - {"id": "start2", "type": "START", "out": ["task2"]}, - {"id": "task1", "type": "TASK", "out": []}, - {"id": "task2", "type": "TASK", "out": []}, - ] - } - - # Test with multiple START nodes - actual_ids, next_nodes = await wf.handle_start_node_skip(["start1", "start2"], workflow_config, 0) - assert set(actual_ids) == {"task1", "task2"} - assert len(next_nodes) == 2 - - -def test_workflow_create_with_existing_nodes(client, experiment_id): - """Test workflow creation with existing nodes in config""" - # Create workflow with existing nodes - config = {"nodes": [{"type": "TASK", "id": "existing_task", "name": "Existing Task", "out": []}]} - resp = 
client.get(f"/experiment/{experiment_id}/workflows/create?name=test_workflow&config={json.dumps(config)}") - assert resp.status_code == 200 - workflow_id = resp.json() - - # Verify the workflow was created with START node prepended - workflows_resp = client.get(f"/experiment/{experiment_id}/workflows/list") - assert workflows_resp.status_code == 200 - workflows = workflows_resp.json() - workflow = next(w for w in workflows if w["id"] == workflow_id) - nodes = json.loads(workflow["config"])["nodes"] - - # Should have START node + the existing task - assert len(nodes) >= 2 - start_nodes = [n for n in nodes if n["type"] == "START"] - task_nodes = [n for n in nodes if n["type"] == "TASK"] - assert len(start_nodes) == 1 - assert len(task_nodes) >= 1 - - -def test_workflow_node_edge_operations(client, experiment_id): - """Test edge addition and removal with various scenarios""" - # Create workflow - workflow_resp = client.get(f"/experiment/{experiment_id}/workflows/create?name=test_workflow245") - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Add two nodes - node1_data = {"type": "TASK", "name": "Task 1", "task": "task1", "out": []} - node2_data = {"type": "TASK", "name": "Task 2", "task": "task2", "out": []} - - client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/add_node?node={json.dumps(node1_data)}") - client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/add_node?node={json.dumps(node2_data)}") - - # Get node IDs - workflows_resp = client.get(f"/experiment/{experiment_id}/workflows/list") - workflows = workflows_resp.json() - workflow = next(w for w in workflows if w["id"] == workflow_id) - nodes = json.loads(workflow["config"])["nodes"] - task_nodes = [n for n in nodes if n["type"] == "TASK"] - node1_id = task_nodes[0]["id"] - node2_id = task_nodes[1]["id"] - - # Add edge between nodes - add_edge_resp = client.post( - f"/experiment/{experiment_id}/workflows/{workflow_id}/{node1_id}/add_edge?end_node_id={node2_id}" - ) - assert add_edge_resp.status_code == 200 - - # Remove edge between nodes - remove_edge_resp = client.post( - f"/experiment/{experiment_id}/workflows/{workflow_id}/{node1_id}/remove_edge?end_node_id={node2_id}" - ) - assert remove_edge_resp.status_code == 200 - - # Try to remove non-existent edge (should still work) - remove_edge_resp = client.post( - f"/experiment/{experiment_id}/workflows/{workflow_id}/{node1_id}/remove_edge?end_node_id={node2_id}" - ) - assert remove_edge_resp.status_code == 200 - - # Cleanup: delete the workflow - client.get(f"/experiment/{experiment_id}/workflows/delete/{workflow_id}") - - -def test_workflow_node_deletion_with_connections(client, experiment_id): - """Test node deletion when node has connections""" - # Create workflow - workflow_resp = client.get(f"/experiment/{experiment_id}/workflows/create?name=test_workflow897") - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Add three nodes in sequence - node1_data = {"type": "TASK", "name": "Task 1", "task": "task1", "out": []} - node2_data = {"type": "TASK", "name": "Task 2", "task": "task2", "out": []} - node3_data = {"type": "TASK", "name": "Task 3", "task": "task3", "out": []} - - client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/add_node?node={json.dumps(node1_data)}") - client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/add_node?node={json.dumps(node2_data)}") - client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/add_node?node={json.dumps(node3_data)}") - 
- # Get node IDs - workflows_resp = client.get(f"/experiment/{experiment_id}/workflows/list") - workflows = workflows_resp.json() - workflow = next(w for w in workflows if w["id"] == workflow_id) - nodes = json.loads(workflow["config"])["nodes"] - task_nodes = [n for n in nodes if n["type"] == "TASK"] - node1_id, node2_id, node3_id = task_nodes[0]["id"], task_nodes[1]["id"], task_nodes[2]["id"] - - # Create connections: node1 -> node2 -> node3 - client.post(f"/experiment/{experiment_id}/workflows/{workflow_id}/{node1_id}/add_edge?end_node_id={node2_id}") - client.post(f"/experiment/{experiment_id}/workflows/{workflow_id}/{node2_id}/add_edge?end_node_id={node3_id}") - - # Delete middle node (node2) - should connect node1 to node3 - delete_resp = client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/{node2_id}/delete_node") - assert delete_resp.status_code == 200 - - # Verify the connections were updated - workflows_resp = client.get(f"/experiment/{experiment_id}/workflows/list") - workflows = workflows_resp.json() - workflow = next(w for w in workflows if w["id"] == workflow_id) - nodes = json.loads(workflow["config"])["nodes"] - - # node2 should be gone - remaining_task_nodes = [n for n in nodes if n["type"] == "TASK"] - assert len(remaining_task_nodes) == 2 - - # node1 should now connect to node3 - node1 = next(n for n in nodes if n["id"] == node1_id) - assert node3_id in node1["out"] - - -def test_workflow_empty_node_operations(client): - """Test operations on workflows with empty or minimal nodes""" - # Create experiment for this specific test - exp_resp = client.get("/experiment/create?name=test_workflow_empty_ops") - assert exp_resp.status_code == 200 - exp_id = exp_resp.json() - - # Create empty workflow - workflow_resp = client.get(f"/experiment/{exp_id}/workflows/create_empty?name=empty_workflow") - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Get the START node ID - workflows_resp = client.get(f"/experiment/{exp_id}/workflows/list") - workflows = workflows_resp.json() - workflow = next(w for w in workflows if w["id"] == workflow_id) - nodes = json.loads(workflow["config"])["nodes"] - start_node = next(n for n in nodes if n["type"] == "START") - start_node_id = start_node["id"] - - # Try various operations on empty workflow - # Add edge from START to non-existent node (should work) - add_edge_resp = client.post( - f"/experiment/{exp_id}/workflows/{workflow_id}/{start_node_id}/add_edge?end_node_id=nonexistent" - ) - assert add_edge_resp.status_code == 200 - - # Remove edge that doesn't exist - remove_edge_resp = client.post( - f"/experiment/{exp_id}/workflows/{workflow_id}/{start_node_id}/remove_edge?end_node_id=nonexistent" - ) - assert remove_edge_resp.status_code == 200 - - # Try to edit metadata of START node - metadata = {"description": "Start node"} - metadata_resp = client.get( - f"/experiment/{exp_id}/workflows/{workflow_id}/{start_node_id}/edit_node_metadata?metadata={json.dumps(metadata)}" - ) - assert metadata_resp.status_code == 200 - - -def test_find_nodes_by_ids_comprehensive(): - """Test find_nodes_by_ids with comprehensive scenarios""" - - nodes = [ - {"id": "a", "type": "START"}, - {"id": "b", "type": "TASK"}, - {"id": "c", "type": "TASK"}, - {"id": "d", "type": "TASK"}, - ] - - # Test multiple IDs - result = wf.find_nodes_by_ids(["a", "c"], nodes) - assert len(result) == 2 - assert result[0]["id"] == "a" - assert result[1]["id"] == "c" - - # Test non-existent IDs - result = wf.find_nodes_by_ids(["x", "y"], nodes) - 
assert result == [] - - # Test mixed existing and non-existent - result = wf.find_nodes_by_ids(["a", "x", "c"], nodes) - assert len(result) == 2 - assert result[0]["id"] == "a" - assert result[1]["id"] == "c" - - # Test duplicate IDs - result = wf.find_nodes_by_ids(["a", "a", "b"], nodes) - assert len(result) == 2 # Should not duplicate - - -def test_workflow_run_with_job_data_edge_cases(client): - """Test workflow run with various job data scenarios""" - # Create experiment for this specific test - exp_resp = client.get("/experiment/create?name=test_job_data_edges") - assert exp_resp.status_code == 200 - exp_id = exp_resp.json() - - try: - workflow_resp = client.get(f"/experiment/{exp_id}/workflows/create?name=test_workflow") - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Start workflow to create a run - start_resp = client.get(f"/experiment/{exp_id}/workflows/{workflow_id}/start") - assert start_resp.status_code == 200 - - # Get workflow runs - runs_resp = client.get(f"/experiment/{exp_id}/workflows/runs") - assert runs_resp.status_code == 200 - runs = runs_resp.json() - run_id = runs[0]["id"] - - # Get specific run to test job data parsing paths - run_resp = client.get(f"/experiment/{exp_id}/workflows/runs/{run_id}") - assert run_resp.status_code == 200 - run_data = run_resp.json() - - # This should cover lines 322 (job_get), 324 (continue if no job), - # 326 (job_info creation), 332 (safe job_data get), 334 (empty job_data), - # 346-348 (JSON decode error handling) - assert "jobs" in run_data - assert isinstance(run_data["jobs"], list) - finally: - # Cleanup - client.get(f"/experiment/delete/{exp_id}") - - -@pytest.mark.skip(reason="Skipping complex workflow test because it doesn't always work") -def test_workflow_next_step_with_complex_scenarios(client): - """Test complex workflow scenarios through API execution""" - - # Test 1: Multi-step workflow creation and execution - config = { - "nodes": [ - {"type": "START", "id": "start", "name": "START", "out": ["task1"]}, - {"type": "TASK", "id": "task1", "name": "Task 1", "task": "test_task", "out": ["task2"]}, - {"type": "TASK", "id": "task2", "name": "Task 2", "task": "test_task", "out": []}, - ] - } - workflow_resp = client.get( - f"/experiment/{experiment_id}/workflows/create?name=complex_workflow&config={json.dumps(config)}" - ) - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - # Start workflow - start_resp = client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/start") - assert start_resp.status_code == 200 - - # Verify workflow run was created and is in progress - runs_resp = client.get(f"/experiment/{experiment_id}/workflows/runs") - assert runs_resp.status_code == 200 - runs = runs_resp.json() - assert len(runs) > 0 - assert runs[0]["status"] in ["RUNNING", "QUEUED"] - assert runs[0]["workflow_id"] == workflow_id - - # Test 2: Empty workflow (only START node with no outputs) - config = { - "nodes": [ - {"type": "START", "id": "start", "name": "START", "out": []}, - ] - } - workflow_resp = client.get( - f"/experiment/{experiment_id}/workflows/create?name=empty_workflow&config={json.dumps(config)}" - ) - assert workflow_resp.status_code == 200 - workflow_id = workflow_resp.json() - - start_resp = client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/start") - assert start_resp.status_code == 200 - - # Verify run was created - runs_resp = client.get(f"/experiment/{experiment_id}/workflows/runs") - assert runs_resp.status_code == 200 - runs = 
runs_resp.json() - assert len(runs) > 0 - - # Test 3: Multiple workflow executions - for i in range(2): - config = {"nodes": [{"type": "START", "id": f"start_{i}", "name": f"START {i}", "out": []}]} - workflow_resp = client.get( - f"/experiment/{experiment_id}/workflows/create?name=workflow_{i}&config={json.dumps(config)}" - ) - workflow_id = workflow_resp.json() - start_resp = client.get(f"/experiment/{experiment_id}/workflows/{workflow_id}/start") - assert start_resp.status_code == 200 - - # Should have multiple runs - runs_resp = client.get(f"/experiment/{experiment_id}/workflows/runs") - runs = runs_resp.json() - assert len(runs) >= 2 diff --git a/api/transformerlab/db/constants.py b/api/transformerlab/db/constants.py index 130e9a587..ee49ca60b 100644 --- a/api/transformerlab/db/constants.py +++ b/api/transformerlab/db/constants.py @@ -1,6 +1,8 @@ # --- Centralized Database Configuration --- +import os from lab import HOME_DIR db = None # This will hold the aiosqlite connection DATABASE_FILE_NAME = f"{HOME_DIR}/llmlab.sqlite3" -DATABASE_URL = f"sqlite+aiosqlite:///{DATABASE_FILE_NAME}" +# Allow DATABASE_URL to be overridden by environment variable (useful for testing) +DATABASE_URL = os.getenv("DATABASE_URL", f"sqlite+aiosqlite:///{DATABASE_FILE_NAME}") diff --git a/api/transformerlab/db/session.py b/api/transformerlab/db/session.py index cbb1f7a83..4a24f9ed1 100644 --- a/api/transformerlab/db/session.py +++ b/api/transformerlab/db/session.py @@ -37,12 +37,15 @@ async def run_alembic_migrations(): api_dir = os.path.dirname(os.path.dirname(current_dir)) # Run alembic upgrade head + # Pass environment variables to ensure DATABASE_URL is available in subprocess + env = os.environ.copy() result = subprocess.run( [sys.executable, "-m", "alembic", "upgrade", "head"], cwd=api_dir, capture_output=True, text=True, check=False, + env=env, ) if result.returncode != 0: @@ -92,25 +95,32 @@ async def init(): # This replaces the previous create_all() call await run_alembic_migrations() - # Check if experiment_id column exists in workflow_runs table - cursor = await db.execute("PRAGMA table_info(workflow_runs)") - columns = await cursor.fetchall() - has_experiment_id = any(column[1] == "experiment_id" for column in columns) - - if not has_experiment_id: - # Add experiment_id column - await db.execute("ALTER TABLE workflow_runs ADD COLUMN experiment_id INTEGER") - - # Update existing workflow runs with experiment_id from their workflows - await db.execute(""" - UPDATE workflow_runs - SET experiment_id = ( - SELECT experiment_id - FROM workflows - WHERE workflows.id = workflow_runs.workflow_id - ) - """) - await db.commit() + # Check if workflow_runs table exists before checking/modifying columns + cursor = await db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='workflow_runs'") + table_exists = await cursor.fetchone() + await cursor.close() + + if table_exists: + # Check if experiment_id column exists in workflow_runs table + cursor = await db.execute("PRAGMA table_info(workflow_runs)") + columns = await cursor.fetchall() + await cursor.close() + has_experiment_id = any(column[1] == "experiment_id" for column in columns) + + if not has_experiment_id: + # Add experiment_id column + await db.execute("ALTER TABLE workflow_runs ADD COLUMN experiment_id INTEGER") + + # Update existing workflow runs with experiment_id from their workflows + await db.execute(""" + UPDATE workflow_runs + SET experiment_id = ( + SELECT experiment_id + FROM workflows + WHERE workflows.id = 
workflow_runs.workflow_id + ) + """) + await db.commit() print("✅ Database initialized") diff --git a/api/transformerlab/shared/models/user_model.py b/api/transformerlab/shared/models/user_model.py index 08458bad3..17acc4b9d 100644 --- a/api/transformerlab/shared/models/user_model.py +++ b/api/transformerlab/shared/models/user_model.py @@ -9,7 +9,7 @@ from os import getenv from transformerlab.db.constants import DATABASE_URL -from transformerlab.shared.models.models import Base, Team, User, OAuthAccount +from transformerlab.shared.models.models import Team, User, OAuthAccount from transformerlab.shared.remote_workspace import create_bucket_for_team @@ -18,12 +18,6 @@ AsyncSessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) -# 4. Utility to create tables (run this on app startup) -async def create_db_and_tables(): - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - - # 5. Database session dependency async def get_async_session() -> AsyncGenerator[AsyncSession, None]: async with AsyncSessionLocal() as session:
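
The DATABASE_URL change in api/transformerlab/db/constants.py is a plain read-once, environment-overridable default. Below is a minimal standalone sketch of that pattern, not part of the patch; HOME_DIR here is a hypothetical stand-in for the value imported from lab.

    import os

    HOME_DIR = "/tmp/example-home"  # hypothetical stand-in for lab.HOME_DIR

    DATABASE_FILE_NAME = f"{HOME_DIR}/llmlab.sqlite3"
    # Read once at import time, so any override must be exported before the
    # first import of this module; otherwise the file-based default is used.
    DATABASE_URL = os.getenv("DATABASE_URL", f"sqlite+aiosqlite:///{DATABASE_FILE_NAME}")

    if __name__ == "__main__":
        print(DATABASE_URL)

Because the value is fixed at import time, anything that wants a different database (a test suite, for example) has to set the environment variable before importing modules that pull this one in.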
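The env=env addition in run_alembic_migrations forwards the current process environment to the alembic subprocess so an overridden DATABASE_URL is visible to the migration run. A self-contained sketch of that call shape follows (run_alembic_upgrade and api_dir are illustrative names, not from the patch); note that subprocess.run already inherits the parent environment when env is omitted, so the explicit copy mainly makes the dependency on DATABASE_URL visible at the call site.

    import os
    import subprocess
    import sys

    def run_alembic_upgrade(api_dir: str) -> None:
        """Run `alembic upgrade head` in api_dir, forwarding the environment."""
        env = os.environ.copy()  # includes DATABASE_URL if the caller exported it
        result = subprocess.run(
            [sys.executable, "-m", "alembic", "upgrade", "head"],
            cwd=api_dir,
            capture_output=True,
            text=True,
            check=False,
            env=env,
        )
        if result.returncode != 0:
            raise RuntimeError(
                f"alembic upgrade head failed ({result.returncode}):\n"
                f"{result.stdout}\n{result.stderr}"
            )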
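The reworked block in db/session.py's init() guards the legacy workflow_runs fix-up behind two checks: the table must already exist (a fresh database is created entirely by the Alembic migrations), and the experiment_id column must still be missing before it is added and backfilled. A standalone sketch of the same sequence against a raw aiosqlite connection (the function name and db_path are illustrative):

    import asyncio
    import aiosqlite

    async def ensure_experiment_id_column(db_path: str) -> None:
        """Add and backfill workflow_runs.experiment_id on legacy databases only."""
        async with aiosqlite.connect(db_path) as db:
            # Skip entirely if the table does not exist yet (fresh install: Alembic owns it).
            cursor = await db.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='workflow_runs'"
            )
            table_exists = await cursor.fetchone()
            await cursor.close()
            if not table_exists:
                return

            # Skip if the column is already present.
            cursor = await db.execute("PRAGMA table_info(workflow_runs)")
            columns = await cursor.fetchall()
            await cursor.close()
            if any(column[1] == "experiment_id" for column in columns):
                return

            await db.execute("ALTER TABLE workflow_runs ADD COLUMN experiment_id INTEGER")
            # Backfill from the owning workflow row.
            await db.execute(
                """
                UPDATE workflow_runs
                SET experiment_id = (
                    SELECT experiment_id
                    FROM workflows
                    WHERE workflows.id = workflow_runs.workflow_id
                )
                """
            )
            await db.commit()

    if __name__ == "__main__":
        asyncio.run(ensure_experiment_id_column("llmlab.sqlite3"))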
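Finally, with create_db_and_tables() removed from user_model.py, the module keeps only the async engine, the session factory, and the get_async_session dependency, while schema creation is left to the Alembic migrations run from db.init(). A rough sketch of what remains, assuming SQLAlchemy's async extension and an aiosqlite URL (the literal URL below is a placeholder for the imported DATABASE_URL):

    from typing import AsyncGenerator

    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
    from sqlalchemy.orm import sessionmaker

    DATABASE_URL = "sqlite+aiosqlite:///llmlab.sqlite3"  # placeholder for db.constants.DATABASE_URL

    engine = create_async_engine(DATABASE_URL)
    AsyncSessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
        # Sessions only: no Base.metadata.create_all() here any more, since the
        # schema is created and upgraded by Alembic.
        async with AsyncSessionLocal() as session:
            yield session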