Skip to content

Commit 6984750

Browse files
committed
abstract out _patch_state and use it when linking or unlinking
1 parent ec3e36c commit 6984750

File tree

1 file changed

+62
-50
lines changed

1 file changed

+62
-50
lines changed

reflex/istate/shared.py

Lines changed: 62 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,46 @@ async def _update_client(token: str):
5252
return tasks
5353

5454

55+
@contextlib.asynccontextmanager
56+
async def _patch_state(
57+
original_state: BaseState, linked_state: BaseState, full_delta: bool = False
58+
):
59+
"""Patch the linked state into the original state's tree, restoring it afterward.
60+
61+
Args:
62+
original_state: The original shared state.
63+
linked_state: The linked shared state.
64+
full_delta: If True, mark all Vars in linked_state dirty and resolve
65+
the delta from the root. This option is used when linking or unlinking
66+
to ensure that other computed vars in the tree pick up the newly
67+
linked/unlinked values.
68+
"""
69+
if (original_parent_state := original_state.parent_state) is None:
70+
msg = "Cannot patch root state as linked state."
71+
raise ReflexRuntimeError(msg)
72+
73+
state_name = original_state.get_name()
74+
original_parent_state.substates[state_name] = linked_state
75+
linked_parent_state = linked_state.parent_state
76+
linked_state.parent_state = original_parent_state
77+
try:
78+
if full_delta:
79+
linked_state.dirty_vars.update(linked_state.base_vars)
80+
linked_state.dirty_vars.update(linked_state.backend_vars)
81+
linked_state.dirty_vars.update(linked_state.computed_vars)
82+
linked_state._mark_dirty()
83+
# Apply the updates into the existing state tree for rehydrate.
84+
root_state = original_state._get_root_state()
85+
await root_state._get_resolved_delta()
86+
yield
87+
finally:
88+
original_parent_state.substates[state_name] = original_state
89+
linked_state.parent_state = linked_parent_state
90+
91+
5592
class SharedStateBaseInternal(State):
5693
"""The private base state for all shared states."""
5794

58-
# While _modify_linked_states is active, this holds the original substates for the client's tree.
59-
_original_substates: dict[str, tuple[BaseState, BaseState | None]] = {}
60-
6195
def __getstate__(self):
6296
"""Override redis serialization to remove temporary fields.
6397
@@ -66,7 +100,6 @@ def __getstate__(self):
66100
"""
67101
s = super().__getstate__()
68102
# Don't want to persist the cached substates
69-
s.pop("_original_substates", None)
70103
s.pop("_previous_dirty_vars", None)
71104
return s
72105

@@ -90,7 +123,6 @@ def _mark_dirty(self):
90123
Since these internal fields are not persisted to redis, they shouldn't cause the
91124
state to be considered dirty either.
92125
"""
93-
self.dirty_vars.discard("_original_substates")
94126
self.dirty_vars.discard("_previous_dirty_vars")
95127
# Only mark dirty if there are still dirty vars, or any substate is dirty
96128
if self.dirty_vars or any(
@@ -145,60 +177,47 @@ async def _link_to(self, token: str):
145177
self._reflex_internal_links[state_name] = token
146178

147179
# Get the newly linked state and update pointers/delta for subsequent events.
148-
original_substate = self
149180
async with get_state_manager().modify_state(
150181
_substate_key(token, type(self))
151182
) as linked_root_state:
152183
linked_state = await linked_root_state.get_state(type(self))
153-
linked_parent_state = linked_state.parent_state
154184
linked_state._linked_from.add(self.router.session.client_token)
155185
linked_state._linked_to = token
156-
try:
157-
if (parent_state := self.parent_state) is not None:
158-
parent_state.substates[self.get_name()] = linked_state
159-
linked_state.parent_state = parent_state
160-
linked_state.dirty_vars.update(self.base_vars)
161-
linked_state.dirty_vars.update(self.backend_vars)
162-
linked_state.dirty_vars.update(self.computed_vars)
163-
linked_state._mark_dirty()
164-
# Apply the updates into the existing state tree, then rehydrate.
165-
root_state = self._get_root_state()
166-
await root_state._get_resolved_delta()
167-
finally:
168-
# Put the tree back together for now, since we're about to drop the lock.
169-
if self.parent_state is not None:
170-
self.parent_state.substates[self.get_name()] = original_substate
171-
linked_state.parent_state = linked_parent_state
172-
return self._rehydrate()
186+
async with _patch_state(
187+
original_state=self,
188+
linked_state=linked_state,
189+
full_delta=True,
190+
):
191+
return self._rehydrate()
173192

174193
async def _unlink(self):
175194
"""Unlink this shared state from its linked token.
176195
177196
Returns:
178197
The events to rehydrate the state after unlinking (these should be returned/yielded
179198
"""
199+
from reflex.istate.manager import get_state_manager
200+
180201
state_name = self.get_full_name()
181202
if state_name not in self._reflex_internal_links:
182203
msg = f"State {state_name} is not linked and cannot be unlinked."
183204
raise ReflexRuntimeError(msg)
205+
206+
# Break the linkage for future events.
184207
self._reflex_internal_links.pop(state_name)
185208
self._linked_from.discard(self.router.session.client_token)
186-
# Rehydrate after unlinking to restore original values.
187-
return self._rehydrate()
188-
189-
async def _restore_original_substates(self, *_exc_info) -> None:
190-
"""Restore the original substates that were linked."""
191-
root_state = self._get_root_state()
192-
for linked_state_name, (
193-
original_state,
194-
linked_parent_state,
195-
) in self._original_substates.items():
196-
linked_state_cls = root_state.get_class_substate(linked_state_name)
197-
linked_state = await root_state.get_state(linked_state_cls)
198-
if (parent_state := linked_state.parent_state) is not None:
199-
parent_state.substates[original_state.get_name()] = original_state
200-
linked_state.parent_state = linked_parent_state
201-
self._original_substates = {}
209+
210+
# Patch in the original state, apply updates, then rehydrate.
211+
private_root_state = await get_state_manager().get_state(
212+
_substate_key(self.router.session.client_token, type(self))
213+
)
214+
private_state = await private_root_state.get_state(type(self))
215+
async with _patch_state(
216+
original_state=self,
217+
linked_state=private_state,
218+
full_delta=True,
219+
):
220+
return self._rehydrate()
202221

203222
@contextlib.asynccontextmanager
204223
async def _modify_linked_states(
@@ -245,25 +264,18 @@ async def _modify_linked_states(
245264
_substate_key(linked_token, linked_state_cls)
246265
)
247266
linked_state = await linked_root_state.get_state(linked_state_cls)
267+
linked_states.append(linked_state)
248268
linked_state._linked_to = linked_token
249269
linked_state._linked_from.add(self.router.session.client_token)
250-
self._original_substates[linked_state_name] = (
251-
original_state,
252-
linked_state.parent_state,
270+
await exit_stack.enter_async_context(
271+
_patch_state(original_state, linked_state)
253272
)
254-
if (parent_state := original_state.parent_state) is not None:
255-
parent_state.substates[original_state.get_name()] = linked_state
256-
linked_state.parent_state = parent_state
257-
linked_states.append(linked_state)
258273
if (
259274
previous_dirty_vars
260275
and (dv := previous_dirty_vars.get(linked_state_name)) is not None
261276
):
262277
linked_state.dirty_vars.update(dv)
263278
linked_state._mark_dirty()
264-
# Make sure to restore the non-linked substates after exiting the context.
265-
if self._original_substates:
266-
exit_stack.push_async_exit(self._restore_original_substates)
267279
async with exit_stack:
268280
yield None
269281
# Collect dirty vars and other affected clients that need to be updated.

0 commit comments

Comments
 (0)