|
10 | 10 | from reflex.environment import environment |
11 | 11 | from reflex.utils import console |
12 | 12 | from reflex.utils.compat import sqlmodel_field_has_primary_key |
| 13 | +from reflex.utils.serializers import serializer |
13 | 14 |
|
14 | 15 | if TYPE_CHECKING: |
15 | 16 | import sqlalchemy |
@@ -254,8 +255,6 @@ def get_metadata(cls) -> sqlalchemy.MetaData: |
254 | 255 | import sqlmodel |
255 | 256 | from alembic.runtime.migration import MigrationContext |
256 | 257 | from alembic.script.base import Script |
257 | | - from pydantic import ConfigDict |
258 | | - from sqlmodel._compat import IS_PYDANTIC_V2 |
259 | 258 | from sqlmodel.ext.asyncio.session import AsyncSession |
260 | 259 |
|
261 | 260 | _AsyncSessionLocal: dict[str | None, sqlalchemy.ext.asyncio.async_sessionmaker] = {} |
@@ -312,91 +311,60 @@ async def get_db_status() -> dict[str, bool]: |
312 | 311 |
|
313 | 312 | return {"db": status} |
314 | 313 |
|
| 314 | + @serializer |
| 315 | + def serialize_sqlmodel(m: sqlmodel.SQLModel) -> dict[str, Any]: |
| 316 | + """Serialize a SQLModel object to a dictionary. |
| 317 | +
|
| 318 | + Args: |
| 319 | + m: The SQLModel object to serialize. |
| 320 | +
|
| 321 | + Returns: |
| 322 | + The serialized object as a dictionary. |
| 323 | + """ |
| 324 | + base_fields = m.model_dump() |
| 325 | + relationships = {} |
| 326 | + # SQLModel relationships do not appear in __fields__, but should be included if present. |
| 327 | + for name in m.__sqlmodel_relationships__: |
| 328 | + with suppress( |
| 329 | + sqlalchemy.orm.exc.DetachedInstanceError # This happens when the relationship was never loaded and the session is closed. |
| 330 | + ): |
| 331 | + relationships[name] = getattr(m, name) |
| 332 | + return { |
| 333 | + **base_fields, |
| 334 | + **relationships, |
| 335 | + } |
| 336 | + |
315 | 337 | class Model(sqlmodel.SQLModel): |
316 | 338 | """Base class to define a table in the database.""" |
317 | 339 |
|
318 | 340 | # The primary key for the table. |
319 | 341 | id: int | None = sqlmodel.Field(default=None, primary_key=True) |
320 | 342 |
|
321 | | - if IS_PYDANTIC_V2: |
322 | | - model_config = ConfigDict( # pyright: ignore [reportAssignmentType] |
323 | | - arbitrary_types_allowed=True, |
324 | | - extra="allow", |
325 | | - use_enum_values=True, |
326 | | - from_attributes=True, |
327 | | - ) |
328 | | - else: |
329 | | - |
330 | | - class Config: # pyright: ignore [reportIncompatibleVariableOverride] |
331 | | - """Pydantic V1 config.""" |
332 | | - |
333 | | - arbitrary_types_allowed = True |
334 | | - use_enum_values = True |
335 | | - extra = "allow" |
336 | | - orm_mode = True |
| 343 | + model_config = { # pyright: ignore [reportAssignmentType] |
| 344 | + "arbitrary_types_allowed": True, |
| 345 | + "use_enum_values": True, |
| 346 | + "extra": "allow", |
| 347 | + } |
337 | 348 |
|
338 | | - def __init_subclass__(cls): |
| 349 | + @classmethod |
| 350 | + def __pydantic_init_subclass__(cls): |
339 | 351 | """Drop the default primary key field if any primary key field is defined.""" |
340 | 352 | non_default_primary_key_fields = [ |
341 | 353 | field_name |
342 | | - for field_name, field in cls.__fields__.items() |
343 | | - if field_name != "id" and sqlmodel_field_has_primary_key(field) |
| 354 | + for field_name, field_info in cls.model_fields.items() |
| 355 | + if field_name != "id" and sqlmodel_field_has_primary_key(field_info) |
344 | 356 | ] |
345 | 357 | if non_default_primary_key_fields: |
346 | | - cls.__fields__.pop("id", None) |
347 | | - |
348 | | - super().__init_subclass__() |
349 | | - |
350 | | - @classmethod |
351 | | - def _dict_recursive(cls, value: Any): |
352 | | - """Recursively serialize the relationship object(s). |
353 | | -
|
354 | | - Args: |
355 | | - value: The value to serialize. |
356 | | -
|
357 | | - Returns: |
358 | | - The serialized value. |
359 | | - """ |
360 | | - if hasattr(value, "dict"): |
361 | | - return value.dict() |
362 | | - if isinstance(value, list): |
363 | | - return [cls._dict_recursive(item) for item in value] |
364 | | - return value |
365 | | - |
366 | | - def dict(self, **kwargs): |
367 | | - """Convert the object to a dictionary. |
368 | | -
|
369 | | - Args: |
370 | | - kwargs: Ignored but needed for compatibility. |
371 | | -
|
372 | | - Returns: |
373 | | - The object as a dictionary. |
374 | | - """ |
375 | | - base_fields = {name: getattr(self, name) for name in self.__fields__} |
376 | | - relationships = {} |
377 | | - # SQLModel relationships do not appear in __fields__, but should be included if present. |
378 | | - for name in self.__sqlmodel_relationships__: |
379 | | - with suppress( |
380 | | - sqlalchemy.orm.exc.DetachedInstanceError # This happens when the relationship was never loaded and the session is closed. |
381 | | - ): |
382 | | - relationships[name] = self._dict_recursive(getattr(self, name)) |
383 | | - return { |
384 | | - **base_fields, |
385 | | - **relationships, |
386 | | - } |
387 | | - |
388 | | - def json(self) -> str: |
389 | | - """Convert the object to a json string. |
390 | | -
|
391 | | - Returns: |
392 | | - The object as a json string. |
393 | | - """ |
394 | | - from reflex.utils.serializers import serialize |
395 | | - |
396 | | - return self.__config__.json_dumps( |
397 | | - self.dict(), |
398 | | - default=serialize, |
399 | | - ) |
| 358 | + cls.model_fields.pop("id", None) |
| 359 | + console.deprecate( |
| 360 | + feature_name="Overriding default primary key", |
| 361 | + reason=( |
| 362 | + "Register sqlmodel.SQLModel classes with `@rx.ModelRegistry.register`" |
| 363 | + ), |
| 364 | + deprecation_version="0.8.0", |
| 365 | + removal_version="0.9.0", |
| 366 | + ) |
| 367 | + super().__pydantic_init_subclass__() |
400 | 368 |
|
401 | 369 | @staticmethod |
402 | 370 | def create_all(): |
|
0 commit comments