Skip to content

Commit 9d8f406

Browse files
committed
move sqlmodel to pydantic v2 and remove compat code
1 parent b2ef5c9 commit 9d8f406

File tree

6 files changed

+62
-88
lines changed

6 files changed

+62
-88
lines changed

pyi_hashes.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"reflex/__init__.pyi": "cc4f461d8244f0f372b7607eb1edd146",
2+
"reflex/__init__.pyi": "2fa0051c43f2d3d10114283480c666fb",
33
"reflex/components/__init__.pyi": "ac05995852baa81062ba3d18fbc489fb",
44
"reflex/components/base/__init__.pyi": "16e47bf19e0d62835a605baa3d039c5a",
55
"reflex/components/base/app_wrap.pyi": "22e94feaa9fe675bcae51c412f5b67f1",

reflex/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@
323323
"SessionStorage",
324324
],
325325
"middleware": ["middleware", "Middleware"],
326-
"model": ["asession", "session", "Model"],
326+
"model": ["asession", "session", "Model", "ModelRegistry"],
327327
"page": ["page"],
328328
"state": [
329329
"var",

reflex/istate/proxy.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from reflex.base import Base
1919
from reflex.utils import prerequisites
2020
from reflex.utils.exceptions import ImmutableStateError
21-
from reflex.utils.serializers import serializer
21+
from reflex.utils.serializers import can_serialize, serialize, serializer
2222
from reflex.vars.base import Var
2323

2424
if TYPE_CHECKING:
@@ -689,7 +689,10 @@ def serialize_mutable_proxy(mp: MutableProxy):
689689
Returns:
690690
The wrapped object.
691691
"""
692-
return mp.__wrapped__
692+
obj = mp.__wrapped__
693+
if can_serialize(type(obj)):
694+
return serialize(obj)
695+
return obj
693696

694697

695698
_orig_json_encoder_default = json.JSONEncoder.default

reflex/model.py

Lines changed: 43 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from reflex.environment import environment
1111
from reflex.utils import console
1212
from reflex.utils.compat import sqlmodel_field_has_primary_key
13+
from reflex.utils.serializers import serializer
1314

1415
if TYPE_CHECKING:
1516
import sqlalchemy
@@ -254,8 +255,6 @@ def get_metadata(cls) -> sqlalchemy.MetaData:
254255
import sqlmodel
255256
from alembic.runtime.migration import MigrationContext
256257
from alembic.script.base import Script
257-
from pydantic import ConfigDict
258-
from sqlmodel._compat import IS_PYDANTIC_V2
259258
from sqlmodel.ext.asyncio.session import AsyncSession
260259

261260
_AsyncSessionLocal: dict[str | None, sqlalchemy.ext.asyncio.async_sessionmaker] = {}
@@ -312,91 +311,60 @@ async def get_db_status() -> dict[str, bool]:
312311

313312
return {"db": status}
314313

314+
@serializer
315+
def serialize_sqlmodel(m: sqlmodel.SQLModel) -> dict[str, Any]:
316+
"""Serialize a SQLModel object to a dictionary.
317+
318+
Args:
319+
m: The SQLModel object to serialize.
320+
321+
Returns:
322+
The serialized object as a dictionary.
323+
"""
324+
base_fields = m.model_dump()
325+
relationships = {}
326+
# SQLModel relationships do not appear in __fields__, but should be included if present.
327+
for name in m.__sqlmodel_relationships__:
328+
with suppress(
329+
sqlalchemy.orm.exc.DetachedInstanceError # This happens when the relationship was never loaded and the session is closed.
330+
):
331+
relationships[name] = getattr(m, name)
332+
return {
333+
**base_fields,
334+
**relationships,
335+
}
336+
315337
class Model(sqlmodel.SQLModel):
316338
"""Base class to define a table in the database."""
317339

318340
# The primary key for the table.
319341
id: int | None = sqlmodel.Field(default=None, primary_key=True)
320342

321-
if IS_PYDANTIC_V2:
322-
model_config = ConfigDict( # pyright: ignore [reportAssignmentType]
323-
arbitrary_types_allowed=True,
324-
extra="allow",
325-
use_enum_values=True,
326-
from_attributes=True,
327-
)
328-
else:
329-
330-
class Config: # pyright: ignore [reportIncompatibleVariableOverride]
331-
"""Pydantic V1 config."""
332-
333-
arbitrary_types_allowed = True
334-
use_enum_values = True
335-
extra = "allow"
336-
orm_mode = True
343+
model_config = { # pyright: ignore [reportAssignmentType]
344+
"arbitrary_types_allowed": True,
345+
"use_enum_values": True,
346+
"extra": "allow",
347+
}
337348

338-
def __init_subclass__(cls):
349+
@classmethod
350+
def __pydantic_init_subclass__(cls):
339351
"""Drop the default primary key field if any primary key field is defined."""
340352
non_default_primary_key_fields = [
341353
field_name
342-
for field_name, field in cls.__fields__.items()
343-
if field_name != "id" and sqlmodel_field_has_primary_key(field)
354+
for field_name, field_info in cls.model_fields.items()
355+
if field_name != "id" and sqlmodel_field_has_primary_key(field_info)
344356
]
345357
if non_default_primary_key_fields:
346-
cls.__fields__.pop("id", None)
347-
348-
super().__init_subclass__()
349-
350-
@classmethod
351-
def _dict_recursive(cls, value: Any):
352-
"""Recursively serialize the relationship object(s).
353-
354-
Args:
355-
value: The value to serialize.
356-
357-
Returns:
358-
The serialized value.
359-
"""
360-
if hasattr(value, "dict"):
361-
return value.dict()
362-
if isinstance(value, list):
363-
return [cls._dict_recursive(item) for item in value]
364-
return value
365-
366-
def dict(self, **kwargs):
367-
"""Convert the object to a dictionary.
368-
369-
Args:
370-
kwargs: Ignored but needed for compatibility.
371-
372-
Returns:
373-
The object as a dictionary.
374-
"""
375-
base_fields = {name: getattr(self, name) for name in self.__fields__}
376-
relationships = {}
377-
# SQLModel relationships do not appear in __fields__, but should be included if present.
378-
for name in self.__sqlmodel_relationships__:
379-
with suppress(
380-
sqlalchemy.orm.exc.DetachedInstanceError # This happens when the relationship was never loaded and the session is closed.
381-
):
382-
relationships[name] = self._dict_recursive(getattr(self, name))
383-
return {
384-
**base_fields,
385-
**relationships,
386-
}
387-
388-
def json(self) -> str:
389-
"""Convert the object to a json string.
390-
391-
Returns:
392-
The object as a json string.
393-
"""
394-
from reflex.utils.serializers import serialize
395-
396-
return self.__config__.json_dumps(
397-
self.dict(),
398-
default=serialize,
399-
)
358+
cls.model_fields.pop("id", None)
359+
console.deprecate(
360+
feature_name="Overriding default primary key",
361+
reason=(
362+
"Register sqlmodel.SQLModel classes with `@rx.ModelRegistry.register`"
363+
),
364+
deprecation_version="0.8.0",
365+
removal_version="0.9.0",
366+
)
367+
super().__pydantic_init_subclass__()
400368

401369
@staticmethod
402370
def create_all():

reflex/utils/compat.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
"""Compatibility hacks and helpers."""
22

3-
from typing import Any
3+
from typing import TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from pydantic.fields import FieldInfo
47

58

69
async def windows_hot_reload_lifespan_hack():
@@ -27,17 +30,17 @@ async def windows_hot_reload_lifespan_hack():
2730
pass
2831

2932

30-
def sqlmodel_field_has_primary_key(field: Any) -> bool:
33+
def sqlmodel_field_has_primary_key(field_info: "FieldInfo") -> bool:
3134
"""Determines if a field is a primary.
3235
3336
Args:
34-
field: a rx.model field
37+
field_info: a rx.model field
3538
3639
Returns:
37-
If field is a primary key (Bool)
40+
If field_info is a primary key (Bool)
3841
"""
39-
if getattr(field.field_info, "primary_key", None) is True:
42+
if getattr(field_info, "primary_key", None) is True:
4043
return True
41-
if getattr(field.field_info, "sa_column", None) is None:
44+
if getattr(field_info, "sa_column", None) is None:
4245
return False
43-
return bool(getattr(field.field_info.sa_column, "primary_key", None))
46+
return bool(getattr(field_info.sa_column, "primary_key", None)) # pyright: ignore[reportAttributeAccessIssue]

tests/units/test_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def test_default_primary_key(model_default_primary: Model):
5353
Args:
5454
model_default_primary: Fixture.
5555
"""
56-
assert "id" in type(model_default_primary).__fields__
56+
assert "id" in type(model_default_primary).model_fields
5757

5858

5959
def test_custom_primary_key(model_custom_primary: Model):
@@ -62,7 +62,7 @@ def test_custom_primary_key(model_custom_primary: Model):
6262
Args:
6363
model_custom_primary: Fixture.
6464
"""
65-
assert "id" not in type(model_custom_primary).__fields__
65+
assert "id" not in type(model_custom_primary).model_fields
6666

6767

6868
@pytest.mark.filterwarnings(

0 commit comments

Comments
 (0)