Skip to content

Commit 1d4bfe0

Browse files
authored
MutableProxy: wrap dataclass and BaseModel methods (#5979)
When calling a method on an arbitrary wrapped object, rebind it's `self` as the mutable proxy so changes made inside the method are also tracked. (Previously only wrapped `Base` instances had in-method tracking).
1 parent a679316 commit 1d4bfe0

File tree

2 files changed

+77
-10
lines changed

2 files changed

+77
-10
lines changed

reflex/istate/proxy.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -553,18 +553,17 @@ def __getattr__(self, __name: str) -> Any:
553553
value = wrapt.FunctionWrapper(value, self._mark_dirty)
554554

555555
if __name in self.__wrap_mutable_attrs__:
556-
# Wrap methods that may return mutable objects tied to the state.
556+
# Wrap special methods that may return mutable objects tied to the state.
557557
value = wrapt.FunctionWrapper(
558558
value,
559559
self._wrap_recursive_decorator, # pyright: ignore[reportArgumentType]
560560
)
561561

562562
if (
563-
isinstance(self.__wrapped__, Base)
564-
and __name not in NEVER_WRAP_BASE_ATTRS
565-
and hasattr(value, "__func__")
566-
):
567-
# Wrap methods called on Base subclasses, which might do _anything_
563+
not isinstance(self.__wrapped__, Base)
564+
or __name not in NEVER_WRAP_BASE_ATTRS
565+
) and hasattr(value, "__func__"):
566+
# Wrap methods which might do _anything_
568567
return wrapt.FunctionWrapper(
569568
functools.partial(value.__func__, self), # pyright: ignore [reportFunctionMemberAccess, reportAttributeAccessIssue]
570569
self._wrap_recursive_decorator, # pyright: ignore[reportArgumentType]

tests/units/test_state.py

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2055,6 +2055,22 @@ class ModelDC:
20552055
foo: str = "bar"
20562056
ls: list[dict] = dataclasses.field(default_factory=list)
20572057

2058+
def set_foo(self, val: str):
2059+
"""Set the attribute foo.
2060+
2061+
Args:
2062+
val: The value to set.
2063+
"""
2064+
self.foo = val
2065+
2066+
def double_foo(self) -> str:
2067+
"""Concatenate foo with foo.
2068+
2069+
Returns:
2070+
foo + foo
2071+
"""
2072+
return self.foo + self.foo
2073+
20582074

20592075
@pytest.mark.asyncio
20602076
async def test_state_proxy(
@@ -3806,12 +3822,44 @@ class ModelV1(BaseModelV1):
38063822

38073823
foo: str = "bar"
38083824

3825+
def set_foo(self, val: str):
3826+
"""Set the attribute foo.
3827+
3828+
Args:
3829+
val: The value to set.
3830+
"""
3831+
self.foo = val
3832+
3833+
def double_foo(self) -> str:
3834+
"""Concatenate foo with foo.
3835+
3836+
Returns:
3837+
foo + foo
3838+
"""
3839+
return self.foo + self.foo
3840+
38093841

38103842
class ModelV2(BaseModelV2):
38113843
"""A pydantic BaseModel v2."""
38123844

38133845
foo: str = "bar"
38143846

3847+
def set_foo(self, val: str):
3848+
"""Set the attribute foo.
3849+
3850+
Args:
3851+
val: The value to set.
3852+
"""
3853+
self.foo = val
3854+
3855+
def double_foo(self) -> str:
3856+
"""Concatenate foo with foo.
3857+
3858+
Returns:
3859+
foo + foo
3860+
"""
3861+
return self.foo + self.foo
3862+
38153863

38163864
class PydanticState(rx.State):
38173865
"""A state with pydantic BaseModel vars."""
@@ -3828,26 +3876,46 @@ def test_mutable_models():
38283876
state.v1.foo = "baz"
38293877
assert state.dirty_vars == {"v1"}
38303878
state.dirty_vars.clear()
3879+
state.v1.set_foo("quuc")
3880+
assert state.dirty_vars == {"v1"}
3881+
state.dirty_vars.clear()
3882+
assert state.v1.double_foo() == "quucquuc"
3883+
assert state.dirty_vars == set()
3884+
state.v1.copy(update={"foo": "larp"})
3885+
assert state.dirty_vars == set()
38313886

38323887
assert isinstance(state.v2, MutableProxy)
38333888
state.v2.foo = "baz"
38343889
assert state.dirty_vars == {"v2"}
38353890
state.dirty_vars.clear()
3891+
state.v2.set_foo("quuc")
3892+
assert state.dirty_vars == {"v2"}
3893+
state.dirty_vars.clear()
3894+
assert state.v2.double_foo() == "quucquuc"
3895+
assert state.dirty_vars == set()
3896+
state.v2.model_copy(update={"foo": "larp"})
3897+
assert state.dirty_vars == set()
38363898

38373899
assert isinstance(state.dc, MutableProxy)
38383900
state.dc.foo = "baz"
38393901
assert state.dirty_vars == {"dc"}
38403902
state.dirty_vars.clear()
38413903
assert state.dirty_vars == set()
3904+
state.dc.set_foo("quuc")
3905+
assert state.dirty_vars == {"dc"}
3906+
state.dirty_vars.clear()
3907+
assert state.dirty_vars == set()
3908+
assert state.dc.double_foo() == "quucquuc"
3909+
assert state.dirty_vars == set()
38423910
state.dc.ls.append({"hi": "reflex"})
38433911
assert state.dirty_vars == {"dc"}
38443912
state.dirty_vars.clear()
38453913
assert state.dirty_vars == set()
3846-
assert dataclasses.asdict(state.dc) == {"foo": "baz", "ls": [{"hi": "reflex"}]}
3847-
assert dataclasses.astuple(state.dc) == ("baz", [{"hi": "reflex"}])
3914+
assert dataclasses.asdict(state.dc) == {"foo": "quuc", "ls": [{"hi": "reflex"}]}
3915+
assert dataclasses.astuple(state.dc) == ("quuc", [{"hi": "reflex"}])
38483916
# creating a new instance shouldn't mark the state dirty
3849-
assert dataclasses.replace(state.dc, foo="quuc") == ModelDC(
3850-
foo="quuc", ls=[{"hi": "reflex"}]
3917+
assert dataclasses.replace(state.dc, foo="larp") == ModelDC(
3918+
foo="larp", ls=[{"hi": "reflex"}]
38513919
)
38523920
assert state.dirty_vars == set()
38533921

0 commit comments

Comments
 (0)