Skip to content

Commit a6b324b

Browse files
masenfadhami3310
andauthored
[ENG-3953] Support pydantic BaseModel (v1 and v2) as state var (#4338)
* [ENG-3953] Support pydantic BaseModel (v1 and v2) as state var Provide serializers and mutable proxy tracking for pydantic models directly. * conditionally define v2 serializer Co-authored-by: Khaleel Al-Adhami <[email protected]> * Add `MutableProxy._is_mutable_value` to avoid duplicate logic * Conditionally import BaseModel to handle older pydantic v1 versions * pre-commit fu --------- Co-authored-by: Khaleel Al-Adhami <[email protected]>
1 parent 5702a18 commit a6b324b

File tree

3 files changed

+128
-6
lines changed

3 files changed

+128
-6
lines changed

reflex/state.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@
6262
except ModuleNotFoundError:
6363
import pydantic
6464

65+
from pydantic import BaseModel as BaseModelV2
66+
67+
try:
68+
from pydantic.v1 import BaseModel as BaseModelV1
69+
except ModuleNotFoundError:
70+
BaseModelV1 = BaseModelV2
71+
6572
import wrapt
6673
from redis.asyncio import Redis
6774
from redis.exceptions import ResponseError
@@ -1250,7 +1257,7 @@ def __getattribute__(self, name: str) -> Any:
12501257
if parent_state is not None:
12511258
return getattr(parent_state, name)
12521259

1253-
if isinstance(value, MutableProxy.__mutable_types__) and (
1260+
if MutableProxy._is_mutable_type(value) and (
12541261
name in super().__getattribute__("base_vars") or name in backend_vars
12551262
):
12561263
# track changes in mutable containers (list, dict, set, etc)
@@ -3558,7 +3565,16 @@ class MutableProxy(wrapt.ObjectProxy):
35583565
pydantic.BaseModel.__dict__
35593566
)
35603567

3561-
__mutable_types__ = (list, dict, set, Base, DeclarativeBase)
3568+
# These types will be wrapped in MutableProxy
3569+
__mutable_types__ = (
3570+
list,
3571+
dict,
3572+
set,
3573+
Base,
3574+
DeclarativeBase,
3575+
BaseModelV2,
3576+
BaseModelV1,
3577+
)
35623578

35633579
def __init__(self, wrapped: Any, state: BaseState, field_name: str):
35643580
"""Create a proxy for a mutable object that tracks changes.
@@ -3598,6 +3614,18 @@ def _mark_dirty(
35983614
if wrapped is not None:
35993615
return wrapped(*args, **(kwargs or {}))
36003616

3617+
@classmethod
3618+
def _is_mutable_type(cls, value: Any) -> bool:
3619+
"""Check if a value is of a mutable type and should be wrapped.
3620+
3621+
Args:
3622+
value: The value to check.
3623+
3624+
Returns:
3625+
Whether the value is of a mutable type.
3626+
"""
3627+
return isinstance(value, cls.__mutable_types__)
3628+
36013629
def _wrap_recursive(self, value: Any) -> Any:
36023630
"""Wrap a value recursively if it is mutable.
36033631

@@ -3608,9 +3636,7 @@ def _wrap_recursive(self, value: Any) -> Any:
36083636
The wrapped value.
36093637
"""
36103638
# Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
3611-
if isinstance(value, self.__mutable_types__) and not isinstance(
3612-
value, MutableProxy
3613-
):
3639+
if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
36143640
return type(self)(
36153641
wrapped=value,
36163642
state=self._self_state,
@@ -3668,7 +3694,7 @@ def __getattr__(self, __name: str) -> Any:
36683694
self._wrap_recursive_decorator,
36693695
)
36703696

3671-
if isinstance(value, self.__mutable_types__) and __name not in (
3697+
if self._is_mutable_type(value) and __name not in (
36723698
"__wrapped__",
36733699
"_self_state",
36743700
):

reflex/utils/serializers.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,53 @@ def serialize_base(value: Base) -> dict:
270270
}
271271

272272

273+
try:
274+
from pydantic.v1 import BaseModel as BaseModelV1
275+
276+
@serializer(to=dict)
277+
def serialize_base_model_v1(model: BaseModelV1) -> dict:
278+
"""Serialize a pydantic v1 BaseModel instance.
279+
280+
Args:
281+
model: The BaseModel to serialize.
282+
283+
Returns:
284+
The serialized BaseModel.
285+
"""
286+
return model.dict()
287+
288+
from pydantic import BaseModel as BaseModelV2
289+
290+
if BaseModelV1 is not BaseModelV2:
291+
292+
@serializer(to=dict)
293+
def serialize_base_model_v2(model: BaseModelV2) -> dict:
294+
"""Serialize a pydantic v2 BaseModel instance.
295+
296+
Args:
297+
model: The BaseModel to serialize.
298+
299+
Returns:
300+
The serialized BaseModel.
301+
"""
302+
return model.model_dump()
303+
except ImportError:
304+
# Older pydantic v1 import
305+
from pydantic import BaseModel as BaseModelV1
306+
307+
@serializer(to=dict)
308+
def serialize_base_model_v1(model: BaseModelV1) -> dict:
309+
"""Serialize a pydantic v1 BaseModel instance.
310+
311+
Args:
312+
model: The BaseModel to serialize.
313+
314+
Returns:
315+
The serialized BaseModel.
316+
"""
317+
return model.dict()
318+
319+
273320
@serializer
274321
def serialize_set(value: Set) -> list:
275322
"""Serialize a set to a JSON serializable list.

tests/units/test_state.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import pytest
1717
import pytest_asyncio
1818
from plotly.graph_objects import Figure
19+
from pydantic import BaseModel as BaseModelV2
20+
from pydantic.v1 import BaseModel as BaseModelV1
1921

2022
import reflex as rx
2123
import reflex.config
@@ -3413,6 +3415,53 @@ class TypedState(rx.State):
34133415
_ = TypedState(field="str")
34143416

34153417

3418+
class ModelV1(BaseModelV1):
3419+
"""A pydantic BaseModel v1."""
3420+
3421+
foo: str = "bar"
3422+
3423+
3424+
class ModelV2(BaseModelV2):
3425+
"""A pydantic BaseModel v2."""
3426+
3427+
foo: str = "bar"
3428+
3429+
3430+
@dataclasses.dataclass
3431+
class ModelDC:
3432+
"""A dataclass."""
3433+
3434+
foo: str = "bar"
3435+
3436+
3437+
class PydanticState(rx.State):
3438+
"""A state with pydantic BaseModel vars."""
3439+
3440+
v1: ModelV1 = ModelV1()
3441+
v2: ModelV2 = ModelV2()
3442+
dc: ModelDC = ModelDC()
3443+
3444+
3445+
def test_mutable_models():
3446+
"""Test that dataclass and pydantic BaseModel v1 and v2 use dep tracking."""
3447+
state = PydanticState()
3448+
assert isinstance(state.v1, MutableProxy)
3449+
state.v1.foo = "baz"
3450+
assert state.dirty_vars == {"v1"}
3451+
state.dirty_vars.clear()
3452+
3453+
assert isinstance(state.v2, MutableProxy)
3454+
state.v2.foo = "baz"
3455+
assert state.dirty_vars == {"v2"}
3456+
state.dirty_vars.clear()
3457+
3458+
# Not yet supported ENG-4083
3459+
# assert isinstance(state.dc, MutableProxy)
3460+
# state.dc.foo = "baz"
3461+
# assert state.dirty_vars == {"dc"}
3462+
# state.dirty_vars.clear()
3463+
3464+
34163465
def test_get_value():
34173466
class GetValueState(rx.State):
34183467
foo: str = "FOO"

0 commit comments

Comments
 (0)