@@ -52,12 +52,46 @@ async def _update_client(token: str):
5252 return tasks
5353
5454
55+ @contextlib .asynccontextmanager
56+ async def _patch_state (
57+ original_state : BaseState , linked_state : BaseState , full_delta : bool = False
58+ ):
59+ """Patch the linked state into the original state's tree, restoring it afterward.
60+
61+ Args:
62+ original_state: The original shared state.
63+ linked_state: The linked shared state.
64+ full_delta: If True, mark all Vars in linked_state dirty and resolve
65+ the delta from the root. This option is used when linking or unlinking
66+ to ensure that other computed vars in the tree pick up the newly
67+ linked/unlinked values.
68+ """
69+ if (original_parent_state := original_state .parent_state ) is None :
70+ msg = "Cannot patch root state as linked state."
71+ raise ReflexRuntimeError (msg )
72+
73+ state_name = original_state .get_name ()
74+ original_parent_state .substates [state_name ] = linked_state
75+ linked_parent_state = linked_state .parent_state
76+ linked_state .parent_state = original_parent_state
77+ try :
78+ if full_delta :
79+ linked_state .dirty_vars .update (linked_state .base_vars )
80+ linked_state .dirty_vars .update (linked_state .backend_vars )
81+ linked_state .dirty_vars .update (linked_state .computed_vars )
82+ linked_state ._mark_dirty ()
83+ # Apply the updates into the existing state tree for rehydrate.
84+ root_state = original_state ._get_root_state ()
85+ await root_state ._get_resolved_delta ()
86+ yield
87+ finally :
88+ original_parent_state .substates [state_name ] = original_state
89+ linked_state .parent_state = linked_parent_state
90+
91+
5592class SharedStateBaseInternal (State ):
5693 """The private base state for all shared states."""
5794
58- # While _modify_linked_states is active, this holds the original substates for the client's tree.
59- _original_substates : dict [str , tuple [BaseState , BaseState | None ]] = {}
60-
6195 def __getstate__ (self ):
6296 """Override redis serialization to remove temporary fields.
6397
@@ -66,7 +100,6 @@ def __getstate__(self):
66100 """
67101 s = super ().__getstate__ ()
68102 # Don't want to persist the cached substates
69- s .pop ("_original_substates" , None )
70103 s .pop ("_previous_dirty_vars" , None )
71104 return s
72105
@@ -90,7 +123,6 @@ def _mark_dirty(self):
90123 Since these internal fields are not persisted to redis, they shouldn't cause the
91124 state to be considered dirty either.
92125 """
93- self .dirty_vars .discard ("_original_substates" )
94126 self .dirty_vars .discard ("_previous_dirty_vars" )
95127 # Only mark dirty if there are still dirty vars, or any substate is dirty
96128 if self .dirty_vars or any (
@@ -145,60 +177,47 @@ async def _link_to(self, token: str):
145177 self ._reflex_internal_links [state_name ] = token
146178
147179 # Get the newly linked state and update pointers/delta for subsequent events.
148- original_substate = self
149180 async with get_state_manager ().modify_state (
150181 _substate_key (token , type (self ))
151182 ) as linked_root_state :
152183 linked_state = await linked_root_state .get_state (type (self ))
153- linked_parent_state = linked_state .parent_state
154184 linked_state ._linked_from .add (self .router .session .client_token )
155185 linked_state ._linked_to = token
156- try :
157- if (parent_state := self .parent_state ) is not None :
158- parent_state .substates [self .get_name ()] = linked_state
159- linked_state .parent_state = parent_state
160- linked_state .dirty_vars .update (self .base_vars )
161- linked_state .dirty_vars .update (self .backend_vars )
162- linked_state .dirty_vars .update (self .computed_vars )
163- linked_state ._mark_dirty ()
164- # Apply the updates into the existing state tree, then rehydrate.
165- root_state = self ._get_root_state ()
166- await root_state ._get_resolved_delta ()
167- finally :
168- # Put the tree back together for now, since we're about to drop the lock.
169- if self .parent_state is not None :
170- self .parent_state .substates [self .get_name ()] = original_substate
171- linked_state .parent_state = linked_parent_state
172- return self ._rehydrate ()
186+ async with _patch_state (
187+ original_state = self ,
188+ linked_state = linked_state ,
189+ full_delta = True ,
190+ ):
191+ return self ._rehydrate ()
173192
174193 async def _unlink (self ):
175194 """Unlink this shared state from its linked token.
176195
177196 Returns:
178197 The events to rehydrate the state after unlinking (these should be returned/yielded
179198 """
199+ from reflex .istate .manager import get_state_manager
200+
180201 state_name = self .get_full_name ()
181202 if state_name not in self ._reflex_internal_links :
182203 msg = f"State { state_name } is not linked and cannot be unlinked."
183204 raise ReflexRuntimeError (msg )
205+
206+ # Break the linkage for future events.
184207 self ._reflex_internal_links .pop (state_name )
185208 self ._linked_from .discard (self .router .session .client_token )
186- # Rehydrate after unlinking to restore original values.
187- return self ._rehydrate ()
188-
189- async def _restore_original_substates (self , * _exc_info ) -> None :
190- """Restore the original substates that were linked."""
191- root_state = self ._get_root_state ()
192- for linked_state_name , (
193- original_state ,
194- linked_parent_state ,
195- ) in self ._original_substates .items ():
196- linked_state_cls = root_state .get_class_substate (linked_state_name )
197- linked_state = await root_state .get_state (linked_state_cls )
198- if (parent_state := linked_state .parent_state ) is not None :
199- parent_state .substates [original_state .get_name ()] = original_state
200- linked_state .parent_state = linked_parent_state
201- self ._original_substates = {}
209+
210+ # Patch in the original state, apply updates, then rehydrate.
211+ private_root_state = await get_state_manager ().get_state (
212+ _substate_key (self .router .session .client_token , type (self ))
213+ )
214+ private_state = await private_root_state .get_state (type (self ))
215+ async with _patch_state (
216+ original_state = self ,
217+ linked_state = private_state ,
218+ full_delta = True ,
219+ ):
220+ return self ._rehydrate ()
202221
203222 @contextlib .asynccontextmanager
204223 async def _modify_linked_states (
@@ -245,25 +264,18 @@ async def _modify_linked_states(
245264 _substate_key (linked_token , linked_state_cls )
246265 )
247266 linked_state = await linked_root_state .get_state (linked_state_cls )
267+ linked_states .append (linked_state )
248268 linked_state ._linked_to = linked_token
249269 linked_state ._linked_from .add (self .router .session .client_token )
250- self ._original_substates [linked_state_name ] = (
251- original_state ,
252- linked_state .parent_state ,
270+ await exit_stack .enter_async_context (
271+ _patch_state (original_state , linked_state )
253272 )
254- if (parent_state := original_state .parent_state ) is not None :
255- parent_state .substates [original_state .get_name ()] = linked_state
256- linked_state .parent_state = parent_state
257- linked_states .append (linked_state )
258273 if (
259274 previous_dirty_vars
260275 and (dv := previous_dirty_vars .get (linked_state_name )) is not None
261276 ):
262277 linked_state .dirty_vars .update (dv )
263278 linked_state ._mark_dirty ()
264- # Make sure to restore the non-linked substates after exiting the context.
265- if self ._original_substates :
266- exit_stack .push_async_exit (self ._restore_original_substates )
267279 async with exit_stack :
268280 yield None
269281 # Collect dirty vars and other affected clients that need to be updated.
0 commit comments