Skip to content

Commit b2ef5c9

Browse files
committed
rx.Model no longer inherits from rx.Base (and friends)
* Move sqlalchemy-only helpers out of the sqlmodel section * Remove sqlmodel pydantic v2 hacks -- just use pydantic v2 now * Import guard other parts of the code that were importing sqlalchemy
1 parent a33e8fb commit b2ef5c9

File tree

3 files changed

+171
-190
lines changed

3 files changed

+171
-190
lines changed

reflex/model.py

Lines changed: 154 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
from importlib.util import find_spec
77
from typing import TYPE_CHECKING, Any, ClassVar
88

9-
from reflex.base import Base
109
from reflex.config import get_config
1110
from reflex.environment import environment
1211
from reflex.utils import console
12+
from reflex.utils.compat import sqlmodel_field_has_primary_key
1313

1414
if TYPE_CHECKING:
1515
import sqlalchemy
@@ -20,6 +20,18 @@
2020
)
2121

2222

23+
def _safe_db_url_for_logging(url: str) -> str:
24+
"""Remove username and password from the database URL for logging.
25+
26+
Args:
27+
url: The database URL.
28+
29+
Returns:
30+
The database URL with the username and password removed.
31+
"""
32+
return re.sub(r"://[^@]+@", "://<username>:<password>@", url)
33+
34+
2335
def _print_db_not_available(*args, **kwargs):
2436
msg = (
2537
"Database is not available. Please install the required packages: "
@@ -35,6 +47,108 @@ def __init__(self, *args, **kwargs):
3547

3648
if find_spec("sqlalchemy"):
3749
import sqlalchemy
50+
import sqlalchemy.exc
51+
import sqlalchemy.ext.asyncio
52+
import sqlalchemy.orm
53+
54+
_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
55+
_ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {}
56+
57+
def get_engine_args(url: str | None = None) -> dict[str, Any]:
58+
"""Get the database engine arguments.
59+
60+
Args:
61+
url: The database url.
62+
63+
Returns:
64+
The database engine arguments as a dict.
65+
"""
66+
kwargs: dict[str, Any] = {
67+
# Print the SQL queries if the log level is INFO or lower.
68+
"echo": environment.SQLALCHEMY_ECHO.get(),
69+
# Check connections before returning them.
70+
"pool_pre_ping": environment.SQLALCHEMY_POOL_PRE_PING.get(),
71+
}
72+
conf = get_config()
73+
url = url or conf.db_url
74+
if url is not None and url.startswith("sqlite"):
75+
# Needed for the admin dash on sqlite.
76+
kwargs["connect_args"] = {"check_same_thread": False}
77+
return kwargs
78+
79+
def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
80+
"""Get the database engine.
81+
82+
Args:
83+
url: the DB url to use.
84+
85+
Returns:
86+
The database engine.
87+
88+
Raises:
89+
ValueError: If the database url is None.
90+
"""
91+
conf = get_config()
92+
url = url or conf.db_url
93+
if url is None:
94+
msg = "No database url configured"
95+
raise ValueError(msg)
96+
97+
global _ENGINE
98+
if url in _ENGINE:
99+
return _ENGINE[url]
100+
101+
if not environment.ALEMBIC_CONFIG.get().exists():
102+
console.warn(
103+
"Database is not initialized, run [bold]reflex db init[/bold] first."
104+
)
105+
_ENGINE[url] = sqlalchemy.engine.create_engine(
106+
url,
107+
**get_engine_args(url),
108+
)
109+
return _ENGINE[url]
110+
111+
def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine:
112+
"""Get the async database engine.
113+
114+
Args:
115+
url: The database url.
116+
117+
Returns:
118+
The async database engine.
119+
120+
Raises:
121+
ValueError: If the async database url is None.
122+
"""
123+
if url is None:
124+
conf = get_config()
125+
url = conf.async_db_url
126+
if url is not None and conf.db_url is not None:
127+
async_db_url_tail = url.partition("://")[2]
128+
db_url_tail = conf.db_url.partition("://")[2]
129+
if async_db_url_tail != db_url_tail:
130+
console.warn(
131+
f"async_db_url `{_safe_db_url_for_logging(url)}` "
132+
"should reference the same database as "
133+
f"db_url `{_safe_db_url_for_logging(conf.db_url)}`."
134+
)
135+
if url is None:
136+
msg = "No async database url configured"
137+
raise ValueError(msg)
138+
139+
global _ASYNC_ENGINE
140+
if url in _ASYNC_ENGINE:
141+
return _ASYNC_ENGINE[url]
142+
143+
if not environment.ALEMBIC_CONFIG.get().exists():
144+
console.warn(
145+
"Database is not initialized, run [bold]reflex db init[/bold] first."
146+
)
147+
_ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine(
148+
url,
149+
**get_engine_args(url),
150+
)
151+
return _ASYNC_ENGINE[url]
38152

39153
def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
40154
"""Get a bare sqlalchemy session to interact with the database.
@@ -124,6 +238,9 @@ def get_metadata(cls) -> sqlalchemy.MetaData:
124238
return metadata
125239

126240
else:
241+
get_engine_args = _print_db_not_available
242+
get_engine = _print_db_not_available
243+
get_async_engine = _print_db_not_available
127244
sqla_session = _print_db_not_available
128245
ModelRegistry = _ClassThatErrorsOnInit # pyright: ignore [reportAssignmentType]
129246

@@ -134,38 +251,15 @@ def get_metadata(cls) -> sqlalchemy.MetaData:
134251
import alembic.operations.ops
135252
import alembic.runtime.environment
136253
import alembic.script
137-
import alembic.util
138-
import sqlalchemy
139-
import sqlalchemy.exc
140-
import sqlalchemy.ext.asyncio
141-
import sqlalchemy.orm
254+
import sqlmodel
142255
from alembic.runtime.migration import MigrationContext
143256
from alembic.script.base import Script
257+
from pydantic import ConfigDict
258+
from sqlmodel._compat import IS_PYDANTIC_V2
259+
from sqlmodel.ext.asyncio.session import AsyncSession
144260

145-
from reflex.utils.compat import sqlmodel
146-
147-
def _sqlmodel_field_has_primary_key(field: Any) -> bool:
148-
"""Determines if a field is a primary.
149-
150-
Args:
151-
field: a rx.model field
152-
153-
Returns:
154-
If field is a primary key (Bool)
155-
"""
156-
if getattr(field.field_info, "primary_key", None) is True:
157-
return True
158-
if getattr(field.field_info, "sa_column", None) is None:
159-
return False
160-
return bool(getattr(field.field_info.sa_column, "primary_key", None))
161-
162-
_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
163-
_ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {}
164261
_AsyncSessionLocal: dict[str | None, sqlalchemy.ext.asyncio.async_sessionmaker] = {}
165262

166-
# Import AsyncSession _after_ reflex.utils.compat
167-
from sqlmodel.ext.asyncio.session import AsyncSession
168-
169263
def format_revision(
170264
rev: Script,
171265
current_rev: str | None,
@@ -200,113 +294,6 @@ def format_revision(
200294
# Format output with message
201295
return f" [{status_icon}] {current}{head_marker}, {message}"
202296

203-
def _safe_db_url_for_logging(url: str) -> str:
204-
"""Remove username and password from the database URL for logging.
205-
206-
Args:
207-
url: The database URL.
208-
209-
Returns:
210-
The database URL with the username and password removed.
211-
"""
212-
return re.sub(r"://[^@]+@", "://<username>:<password>@", url)
213-
214-
def get_engine_args(url: str | None = None) -> dict[str, Any]:
215-
"""Get the database engine arguments.
216-
217-
Args:
218-
url: The database url.
219-
220-
Returns:
221-
The database engine arguments as a dict.
222-
"""
223-
kwargs: dict[str, Any] = {
224-
# Print the SQL queries if the log level is INFO or lower.
225-
"echo": environment.SQLALCHEMY_ECHO.get(),
226-
# Check connections before returning them.
227-
"pool_pre_ping": environment.SQLALCHEMY_POOL_PRE_PING.get(),
228-
}
229-
conf = get_config()
230-
url = url or conf.db_url
231-
if url is not None and url.startswith("sqlite"):
232-
# Needed for the admin dash on sqlite.
233-
kwargs["connect_args"] = {"check_same_thread": False}
234-
return kwargs
235-
236-
def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
237-
"""Get the database engine.
238-
239-
Args:
240-
url: the DB url to use.
241-
242-
Returns:
243-
The database engine.
244-
245-
Raises:
246-
ValueError: If the database url is None.
247-
"""
248-
conf = get_config()
249-
url = url or conf.db_url
250-
if url is None:
251-
msg = "No database url configured"
252-
raise ValueError(msg)
253-
254-
global _ENGINE
255-
if url in _ENGINE:
256-
return _ENGINE[url]
257-
258-
if not environment.ALEMBIC_CONFIG.get().exists():
259-
console.warn(
260-
"Database is not initialized, run [bold]reflex db init[/bold] first."
261-
)
262-
_ENGINE[url] = sqlmodel.create_engine(
263-
url,
264-
**get_engine_args(url),
265-
)
266-
return _ENGINE[url]
267-
268-
def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine:
269-
"""Get the async database engine.
270-
271-
Args:
272-
url: The database url.
273-
274-
Returns:
275-
The async database engine.
276-
277-
Raises:
278-
ValueError: If the async database url is None.
279-
"""
280-
if url is None:
281-
conf = get_config()
282-
url = conf.async_db_url
283-
if url is not None and conf.db_url is not None:
284-
async_db_url_tail = url.partition("://")[2]
285-
db_url_tail = conf.db_url.partition("://")[2]
286-
if async_db_url_tail != db_url_tail:
287-
console.warn(
288-
f"async_db_url `{_safe_db_url_for_logging(url)}` "
289-
"should reference the same database as "
290-
f"db_url `{_safe_db_url_for_logging(conf.db_url)}`."
291-
)
292-
if url is None:
293-
msg = "No async database url configured"
294-
raise ValueError(msg)
295-
296-
global _ASYNC_ENGINE
297-
if url in _ASYNC_ENGINE:
298-
return _ASYNC_ENGINE[url]
299-
300-
if not environment.ALEMBIC_CONFIG.get().exists():
301-
console.warn(
302-
"Database is not initialized, run [bold]reflex db init[/bold] first."
303-
)
304-
_ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine(
305-
url,
306-
**get_engine_args(url),
307-
)
308-
return _ASYNC_ENGINE[url]
309-
310297
async def get_db_status() -> dict[str, bool]:
311298
"""Checks the status of the database connection.
312299
@@ -325,18 +312,35 @@ async def get_db_status() -> dict[str, bool]:
325312

326313
return {"db": status}
327314

328-
class Model(Base, sqlmodel.SQLModel): # pyright: ignore [reportGeneralTypeIssues,reportIncompatibleVariableOverride]
315+
class Model(sqlmodel.SQLModel):
329316
"""Base class to define a table in the database."""
330317

331318
# The primary key for the table.
332319
id: int | None = sqlmodel.Field(default=None, primary_key=True)
333320

321+
if IS_PYDANTIC_V2:
322+
model_config = ConfigDict( # pyright: ignore [reportAssignmentType]
323+
arbitrary_types_allowed=True,
324+
extra="allow",
325+
use_enum_values=True,
326+
from_attributes=True,
327+
)
328+
else:
329+
330+
class Config: # pyright: ignore [reportIncompatibleVariableOverride]
331+
"""Pydantic V1 config."""
332+
333+
arbitrary_types_allowed = True
334+
use_enum_values = True
335+
extra = "allow"
336+
orm_mode = True
337+
334338
def __init_subclass__(cls):
335339
"""Drop the default primary key field if any primary key field is defined."""
336340
non_default_primary_key_fields = [
337341
field_name
338342
for field_name, field in cls.__fields__.items()
339-
if field_name != "id" and _sqlmodel_field_has_primary_key(field)
343+
if field_name != "id" and sqlmodel_field_has_primary_key(field)
340344
]
341345
if non_default_primary_key_fields:
342346
cls.__fields__.pop("id", None)
@@ -381,6 +385,19 @@ def dict(self, **kwargs):
381385
**relationships,
382386
}
383387

388+
def json(self) -> str:
389+
"""Convert the object to a json string.
390+
391+
Returns:
392+
The object as a json string.
393+
"""
394+
from reflex.utils.serializers import serialize
395+
396+
return self.__config__.json_dumps(
397+
self.dict(),
398+
default=serialize,
399+
)
400+
384401
@staticmethod
385402
def create_all():
386403
"""Create all the tables."""
@@ -643,9 +660,6 @@ def asession(url: str | None = None) -> AsyncSession:
643660
return _AsyncSessionLocal[url]()
644661

645662
else:
646-
get_engine_args = _print_db_not_available
647-
get_engine = _print_db_not_available
648-
get_async_engine = _print_db_not_available
649663
get_db_status = _print_db_not_available
650664
session = _print_db_not_available
651665
asession = _print_db_not_available

0 commit comments

Comments
 (0)