Skip to content

Commit ba6cdf7

Browse files
Merge pull request #4139 from zenml-io/bug/4120-fix-skip-migration-logic
bug:4120 Fix skip migration logic
2 parents c882555 + 3c14732 commit ba6cdf7

File tree

3 files changed

+69
-5
lines changed

3 files changed

+69
-5
lines changed

src/zenml/zen_stores/sql_zen_store.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,6 +1248,14 @@ def filter_and_paginate(
12481248
# Initialization and configuration
12491249
# --------------------------------
12501250

1251+
def _run_migrations(self) -> None:
1252+
if self.skip_migrations or handle_bool_env_var(
1253+
ENV_ZENML_DISABLE_DATABASE_MIGRATION
1254+
):
1255+
logger.debug("Skipping database migration.")
1256+
else:
1257+
self.migrate_database()
1258+
12511259
def _initialize(self) -> None:
12521260
"""Initialize the SQL store."""
12531261
logger.debug("Initializing SqlZenStore at %s", self.config.url)
@@ -1284,11 +1292,7 @@ def _initialize(self) -> None:
12841292

12851293
self._alembic = Alembic(self.engine)
12861294

1287-
if (
1288-
not self.skip_migrations
1289-
and ENV_ZENML_DISABLE_DATABASE_MIGRATION not in os.environ
1290-
):
1291-
self.migrate_database()
1295+
self._run_migrations()
12921296

12931297
if self.config.driver == SQLDatabaseDriver.SQLITE:
12941298
# Enable foreign key checks at the SQLite database level, but only

tests/unit/zen_stores/__init__.py

Whitespace-only changes.
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from unittest.mock import MagicMock
2+
3+
import pytest
4+
5+
from zenml.client import Client
6+
from zenml.constants import ENV_ZENML_DISABLE_DATABASE_MIGRATION
7+
from zenml.zen_stores.sql_zen_store import SqlZenStore
8+
9+
10+
def test_run_migrations_helper_func(monkeypatch):
11+
store = Client().zen_store
12+
13+
if not isinstance(store, SqlZenStore):
14+
pytest.skip(
15+
"Run migration helper function is testable only for SQL ZenML store"
16+
)
17+
18+
fake_migrate = MagicMock(return_value=None)
19+
20+
with monkeypatch.context() as m:
21+
m.setattr(SqlZenStore, "migrate_database", fake_migrate)
22+
store.skip_migrations = False
23+
m.setenv(ENV_ZENML_DISABLE_DATABASE_MIGRATION, "false")
24+
25+
store._run_migrations()
26+
27+
fake_migrate.assert_called_once()
28+
29+
30+
def test_run_run_migrations_skipped(monkeypatch):
31+
store = Client().zen_store
32+
33+
if not isinstance(store, SqlZenStore):
34+
pytest.skip(
35+
"Run migration helper function is testable only for SQL ZenML store"
36+
)
37+
38+
fake_migrate = MagicMock(return_value=None)
39+
40+
# check skip migrations via store.skip_migrations works
41+
42+
with monkeypatch.context() as m:
43+
m.setattr(SqlZenStore, "migrate_database", fake_migrate)
44+
store.skip_migrations = True
45+
m.setenv(ENV_ZENML_DISABLE_DATABASE_MIGRATION, "false")
46+
47+
store._run_migrations()
48+
49+
fake_migrate.assert_not_called()
50+
51+
# check skip migrations via env var works
52+
53+
with monkeypatch.context() as m:
54+
m.setattr(SqlZenStore, "migrate_database", fake_migrate)
55+
store.skip_migrations = False
56+
m.setenv(ENV_ZENML_DISABLE_DATABASE_MIGRATION, "true")
57+
58+
store._run_migrations()
59+
60+
fake_migrate.assert_not_called()

0 commit comments

Comments
 (0)