|
23 | 23 | os.environ["TRANSFORMERLAB_REFRESH_SECRET"] = "test-refresh-secret-for-testing-only" |
24 | 24 | os.environ["EMAIL_METHOD"] = "dev" # Use dev mode for tests (no actual email sending) |
25 | 25 |
|
26 | | -# Use in-memory database for tests |
27 | | -os.environ["DATABASE_URL"] = "sqlite+aiosqlite:///:memory:" |
| 26 | +# Use temporary file-based database for tests (easier to debug than in-memory) |
| 27 | +test_db_dir = os.path.join("test", "tmp", "db") |
| 28 | +os.makedirs(test_db_dir, exist_ok=True) |
| 29 | +test_db_path = os.path.join(test_db_dir, "test_llmlab.sqlite3") |
| 30 | +os.environ["DATABASE_URL"] = f"sqlite+aiosqlite:///{test_db_path}" |
28 | 31 |
|
29 | 32 | from api import app # noqa: E402 |
30 | 33 |
|
@@ -72,13 +75,43 @@ def request(self, method, url, **kwargs): |
72 | 75 | return super().request(method, url, **kwargs) |
73 | 76 |
|
74 | 77 |
|
| 78 | +@pytest.fixture(scope="session", autouse=True) |
| 79 | +def cleanup_test_db(): |
| 80 | + """Clean up test database file after all tests complete""" |
| 81 | + yield |
| 82 | + # Clean up database file and related files (WAL, SHM) |
| 83 | + test_db_path = os.path.join("test", "tmp", "db", "test_llmlab.sqlite3") |
| 84 | + for ext in ["", "-wal", "-shm"]: |
| 85 | + db_file = test_db_path + ext |
| 86 | + if os.path.exists(db_file): |
| 87 | + try: |
| 88 | + os.remove(db_file) |
| 89 | + except OSError: |
| 90 | + pass # Ignore errors if file is locked or already removed |
| 91 | + |
| 92 | + |
75 | 93 | @pytest.fixture(scope="session") |
76 | 94 | def client(): |
77 | | - # Initialize database tables for tests |
78 | | - from transformerlab.shared.models.user_model import create_db_and_tables # noqa: E402 |
| 95 | + # Initialize database tables for tests using Alembic migrations (same as production) |
| 96 | + from transformerlab.db.session import run_alembic_migrations # noqa: E402 |
79 | 97 | from transformerlab.services.experiment_init import seed_default_admin_user # noqa: E402 |
80 | 98 |
|
81 | | - asyncio.run(create_db_and_tables()) |
| 99 | + # Ensure test database directory exists |
| 100 | + test_db_dir = os.path.join("test", "tmp", "db") |
| 101 | + os.makedirs(test_db_dir, exist_ok=True) |
| 102 | + |
| 103 | + # Remove existing test database if it exists (start fresh) |
| 104 | + test_db_path = os.path.join(test_db_dir, "test_llmlab.sqlite3") |
| 105 | + for ext in ["", "-wal", "-shm"]: |
| 106 | + db_file = test_db_path + ext |
| 107 | + if os.path.exists(db_file): |
| 108 | + try: |
| 109 | + os.remove(db_file) |
| 110 | + except OSError: |
| 111 | + pass |
| 112 | + |
| 113 | + # Run Alembic migrations to create database schema (matches production) |
| 114 | + asyncio.run(run_alembic_migrations()) |
82 | 115 | asyncio.run(seed_default_admin_user()) |
83 | 116 | controller_log_dir = os.path.join("test", "tmp", "workspace", "logs") |
84 | 117 | os.makedirs(controller_log_dir, exist_ok=True) |
|
0 commit comments