66from importlib .util import find_spec
77from typing import TYPE_CHECKING , Any , ClassVar
88
9- from reflex .base import Base
109from reflex .config import get_config
1110from reflex .environment import environment
1211from reflex .utils import console
12+ from reflex .utils .compat import sqlmodel_field_has_primary_key
1313
1414if TYPE_CHECKING :
1515 import sqlalchemy
2020 )
2121
2222
23+ def _safe_db_url_for_logging (url : str ) -> str :
24+ """Remove username and password from the database URL for logging.
25+
26+ Args:
27+ url: The database URL.
28+
29+ Returns:
30+ The database URL with the username and password removed.
31+ """
32+ return re .sub (r"://[^@]+@" , "://<username>:<password>@" , url )
33+
34+
2335def _print_db_not_available (* args , ** kwargs ):
2436 msg = (
2537 "Database is not available. Please install the required packages: "
@@ -35,6 +47,108 @@ def __init__(self, *args, **kwargs):
3547
3648if find_spec ("sqlalchemy" ):
3749 import sqlalchemy
50+ import sqlalchemy .exc
51+ import sqlalchemy .ext .asyncio
52+ import sqlalchemy .orm
53+
54+ _ENGINE : dict [str , sqlalchemy .engine .Engine ] = {}
55+ _ASYNC_ENGINE : dict [str , sqlalchemy .ext .asyncio .AsyncEngine ] = {}
56+
57+ def get_engine_args (url : str | None = None ) -> dict [str , Any ]:
58+ """Get the database engine arguments.
59+
60+ Args:
61+ url: The database url.
62+
63+ Returns:
64+ The database engine arguments as a dict.
65+ """
66+ kwargs : dict [str , Any ] = {
67+ # Print the SQL queries if the log level is INFO or lower.
68+ "echo" : environment .SQLALCHEMY_ECHO .get (),
69+ # Check connections before returning them.
70+ "pool_pre_ping" : environment .SQLALCHEMY_POOL_PRE_PING .get (),
71+ }
72+ conf = get_config ()
73+ url = url or conf .db_url
74+ if url is not None and url .startswith ("sqlite" ):
75+ # Needed for the admin dash on sqlite.
76+ kwargs ["connect_args" ] = {"check_same_thread" : False }
77+ return kwargs
78+
79+ def get_engine (url : str | None = None ) -> sqlalchemy .engine .Engine :
80+ """Get the database engine.
81+
82+ Args:
83+ url: the DB url to use.
84+
85+ Returns:
86+ The database engine.
87+
88+ Raises:
89+ ValueError: If the database url is None.
90+ """
91+ conf = get_config ()
92+ url = url or conf .db_url
93+ if url is None :
94+ msg = "No database url configured"
95+ raise ValueError (msg )
96+
97+ global _ENGINE
98+ if url in _ENGINE :
99+ return _ENGINE [url ]
100+
101+ if not environment .ALEMBIC_CONFIG .get ().exists ():
102+ console .warn (
103+ "Database is not initialized, run [bold]reflex db init[/bold] first."
104+ )
105+ _ENGINE [url ] = sqlalchemy .engine .create_engine (
106+ url ,
107+ ** get_engine_args (url ),
108+ )
109+ return _ENGINE [url ]
110+
111+ def get_async_engine (url : str | None ) -> sqlalchemy .ext .asyncio .AsyncEngine :
112+ """Get the async database engine.
113+
114+ Args:
115+ url: The database url.
116+
117+ Returns:
118+ The async database engine.
119+
120+ Raises:
121+ ValueError: If the async database url is None.
122+ """
123+ if url is None :
124+ conf = get_config ()
125+ url = conf .async_db_url
126+ if url is not None and conf .db_url is not None :
127+ async_db_url_tail = url .partition ("://" )[2 ]
128+ db_url_tail = conf .db_url .partition ("://" )[2 ]
129+ if async_db_url_tail != db_url_tail :
130+ console .warn (
131+ f"async_db_url `{ _safe_db_url_for_logging (url )} ` "
132+ "should reference the same database as "
133+ f"db_url `{ _safe_db_url_for_logging (conf .db_url )} `."
134+ )
135+ if url is None :
136+ msg = "No async database url configured"
137+ raise ValueError (msg )
138+
139+ global _ASYNC_ENGINE
140+ if url in _ASYNC_ENGINE :
141+ return _ASYNC_ENGINE [url ]
142+
143+ if not environment .ALEMBIC_CONFIG .get ().exists ():
144+ console .warn (
145+ "Database is not initialized, run [bold]reflex db init[/bold] first."
146+ )
147+ _ASYNC_ENGINE [url ] = sqlalchemy .ext .asyncio .create_async_engine (
148+ url ,
149+ ** get_engine_args (url ),
150+ )
151+ return _ASYNC_ENGINE [url ]
38152
39153 def sqla_session (url : str | None = None ) -> sqlalchemy .orm .Session :
40154 """Get a bare sqlalchemy session to interact with the database.
@@ -124,6 +238,9 @@ def get_metadata(cls) -> sqlalchemy.MetaData:
124238 return metadata
125239
126240else :
241+ get_engine_args = _print_db_not_available
242+ get_engine = _print_db_not_available
243+ get_async_engine = _print_db_not_available
127244 sqla_session = _print_db_not_available
128245 ModelRegistry = _ClassThatErrorsOnInit # pyright: ignore [reportAssignmentType]
129246
@@ -134,38 +251,15 @@ def get_metadata(cls) -> sqlalchemy.MetaData:
134251 import alembic .operations .ops
135252 import alembic .runtime .environment
136253 import alembic .script
137- import alembic .util
138- import sqlalchemy
139- import sqlalchemy .exc
140- import sqlalchemy .ext .asyncio
141- import sqlalchemy .orm
254+ import sqlmodel
142255 from alembic .runtime .migration import MigrationContext
143256 from alembic .script .base import Script
257+ from pydantic import ConfigDict
258+ from sqlmodel ._compat import IS_PYDANTIC_V2
259+ from sqlmodel .ext .asyncio .session import AsyncSession
144260
145- from reflex .utils .compat import sqlmodel
146-
147- def _sqlmodel_field_has_primary_key (field : Any ) -> bool :
148- """Determines if a field is a primary.
149-
150- Args:
151- field: a rx.model field
152-
153- Returns:
154- If field is a primary key (Bool)
155- """
156- if getattr (field .field_info , "primary_key" , None ) is True :
157- return True
158- if getattr (field .field_info , "sa_column" , None ) is None :
159- return False
160- return bool (getattr (field .field_info .sa_column , "primary_key" , None ))
161-
162- _ENGINE : dict [str , sqlalchemy .engine .Engine ] = {}
163- _ASYNC_ENGINE : dict [str , sqlalchemy .ext .asyncio .AsyncEngine ] = {}
164261 _AsyncSessionLocal : dict [str | None , sqlalchemy .ext .asyncio .async_sessionmaker ] = {}
165262
166- # Import AsyncSession _after_ reflex.utils.compat
167- from sqlmodel .ext .asyncio .session import AsyncSession
168-
169263 def format_revision (
170264 rev : Script ,
171265 current_rev : str | None ,
@@ -200,113 +294,6 @@ def format_revision(
200294 # Format output with message
201295 return f" [{ status_icon } ] { current } { head_marker } , { message } "
202296
203- def _safe_db_url_for_logging (url : str ) -> str :
204- """Remove username and password from the database URL for logging.
205-
206- Args:
207- url: The database URL.
208-
209- Returns:
210- The database URL with the username and password removed.
211- """
212- return re .sub (r"://[^@]+@" , "://<username>:<password>@" , url )
213-
214- def get_engine_args (url : str | None = None ) -> dict [str , Any ]:
215- """Get the database engine arguments.
216-
217- Args:
218- url: The database url.
219-
220- Returns:
221- The database engine arguments as a dict.
222- """
223- kwargs : dict [str , Any ] = {
224- # Print the SQL queries if the log level is INFO or lower.
225- "echo" : environment .SQLALCHEMY_ECHO .get (),
226- # Check connections before returning them.
227- "pool_pre_ping" : environment .SQLALCHEMY_POOL_PRE_PING .get (),
228- }
229- conf = get_config ()
230- url = url or conf .db_url
231- if url is not None and url .startswith ("sqlite" ):
232- # Needed for the admin dash on sqlite.
233- kwargs ["connect_args" ] = {"check_same_thread" : False }
234- return kwargs
235-
236- def get_engine (url : str | None = None ) -> sqlalchemy .engine .Engine :
237- """Get the database engine.
238-
239- Args:
240- url: the DB url to use.
241-
242- Returns:
243- The database engine.
244-
245- Raises:
246- ValueError: If the database url is None.
247- """
248- conf = get_config ()
249- url = url or conf .db_url
250- if url is None :
251- msg = "No database url configured"
252- raise ValueError (msg )
253-
254- global _ENGINE
255- if url in _ENGINE :
256- return _ENGINE [url ]
257-
258- if not environment .ALEMBIC_CONFIG .get ().exists ():
259- console .warn (
260- "Database is not initialized, run [bold]reflex db init[/bold] first."
261- )
262- _ENGINE [url ] = sqlmodel .create_engine (
263- url ,
264- ** get_engine_args (url ),
265- )
266- return _ENGINE [url ]
267-
268- def get_async_engine (url : str | None ) -> sqlalchemy .ext .asyncio .AsyncEngine :
269- """Get the async database engine.
270-
271- Args:
272- url: The database url.
273-
274- Returns:
275- The async database engine.
276-
277- Raises:
278- ValueError: If the async database url is None.
279- """
280- if url is None :
281- conf = get_config ()
282- url = conf .async_db_url
283- if url is not None and conf .db_url is not None :
284- async_db_url_tail = url .partition ("://" )[2 ]
285- db_url_tail = conf .db_url .partition ("://" )[2 ]
286- if async_db_url_tail != db_url_tail :
287- console .warn (
288- f"async_db_url `{ _safe_db_url_for_logging (url )} ` "
289- "should reference the same database as "
290- f"db_url `{ _safe_db_url_for_logging (conf .db_url )} `."
291- )
292- if url is None :
293- msg = "No async database url configured"
294- raise ValueError (msg )
295-
296- global _ASYNC_ENGINE
297- if url in _ASYNC_ENGINE :
298- return _ASYNC_ENGINE [url ]
299-
300- if not environment .ALEMBIC_CONFIG .get ().exists ():
301- console .warn (
302- "Database is not initialized, run [bold]reflex db init[/bold] first."
303- )
304- _ASYNC_ENGINE [url ] = sqlalchemy .ext .asyncio .create_async_engine (
305- url ,
306- ** get_engine_args (url ),
307- )
308- return _ASYNC_ENGINE [url ]
309-
310297 async def get_db_status () -> dict [str , bool ]:
311298 """Checks the status of the database connection.
312299
@@ -325,18 +312,35 @@ async def get_db_status() -> dict[str, bool]:
325312
326313 return {"db" : status }
327314
328- class Model (Base , sqlmodel .SQLModel ): # pyright: ignore [reportGeneralTypeIssues,reportIncompatibleVariableOverride]
315+ class Model (sqlmodel .SQLModel ):
329316 """Base class to define a table in the database."""
330317
331318 # The primary key for the table.
332319 id : int | None = sqlmodel .Field (default = None , primary_key = True )
333320
321+ if IS_PYDANTIC_V2 :
322+ model_config = ConfigDict ( # pyright: ignore [reportAssignmentType]
323+ arbitrary_types_allowed = True ,
324+ extra = "allow" ,
325+ use_enum_values = True ,
326+ from_attributes = True ,
327+ )
328+ else :
329+
330+ class Config : # pyright: ignore [reportIncompatibleVariableOverride]
331+ """Pydantic V1 config."""
332+
333+ arbitrary_types_allowed = True
334+ use_enum_values = True
335+ extra = "allow"
336+ orm_mode = True
337+
334338 def __init_subclass__ (cls ):
335339 """Drop the default primary key field if any primary key field is defined."""
336340 non_default_primary_key_fields = [
337341 field_name
338342 for field_name , field in cls .__fields__ .items ()
339- if field_name != "id" and _sqlmodel_field_has_primary_key (field )
343+ if field_name != "id" and sqlmodel_field_has_primary_key (field )
340344 ]
341345 if non_default_primary_key_fields :
342346 cls .__fields__ .pop ("id" , None )
@@ -381,6 +385,19 @@ def dict(self, **kwargs):
381385 ** relationships ,
382386 }
383387
388+ def json (self ) -> str :
389+ """Convert the object to a json string.
390+
391+ Returns:
392+ The object as a json string.
393+ """
394+ from reflex .utils .serializers import serialize
395+
396+ return self .__config__ .json_dumps (
397+ self .dict (),
398+ default = serialize ,
399+ )
400+
384401 @staticmethod
385402 def create_all ():
386403 """Create all the tables."""
@@ -643,9 +660,6 @@ def asession(url: str | None = None) -> AsyncSession:
643660 return _AsyncSessionLocal [url ]()
644661
645662else :
646- get_engine_args = _print_db_not_available
647- get_engine = _print_db_not_available
648- get_async_engine = _print_db_not_available
649663 get_db_status = _print_db_not_available
650664 session = _print_db_not_available
651665 asession = _print_db_not_available
0 commit comments