33import asyncio
44import contextlib
55from collections .abc import AsyncIterator
6+ from typing import Self , TypeVar
67
78from reflex .event import Event , get_hydrate_event
89from reflex .state import BaseState , State , _override_base_method , _substate_key
910from reflex .utils .exceptions import ReflexRuntimeError
1011
1112UPDATE_OTHER_CLIENT_TASKS : set [asyncio .Task ] = set ()
13+ LINKED_STATE = TypeVar ("LINKED_STATE" , bound = "SharedStateBaseInternal" )
1214
1315
1416def _do_update_other_tokens (
@@ -92,15 +94,19 @@ async def _patch_state(
9294class SharedStateBaseInternal (State ):
9395 """The private base state for all shared states."""
9496
97+ _exit_stack : contextlib .AsyncExitStack | None = None
98+ _held_locks : dict [str , dict [type [BaseState ], BaseState ]] = {}
99+
95100 def __getstate__ (self ):
96101 """Override redis serialization to remove temporary fields.
97102
98103 Returns:
99104 The state dictionary without temporary fields.
100105 """
101106 s = super ().__getstate__ ()
102- # Don't want to persist the cached substates
103107 s .pop ("_previous_dirty_vars" , None )
108+ s .pop ("_exit_stack" , None )
109+ s .pop ("_held_locks" , None )
104110 return s
105111
106112 @_override_base_method
@@ -124,6 +130,8 @@ def _mark_dirty(self):
124130 state to be considered dirty either.
125131 """
126132 self .dirty_vars .discard ("_previous_dirty_vars" )
133+ self .dirty_vars .discard ("_exit_stack" )
134+ self .dirty_vars .discard ("_held_locks" )
127135 # Only mark dirty if there are still dirty vars, or any substate is dirty
128136 if self .dirty_vars or any (
129137 substate .dirty_vars for substate in self .substates .values ()
@@ -144,7 +152,7 @@ def _rehydrate(self):
144152 State .set_is_hydrated (True ),
145153 ]
146154
147- async def _link_to (self , token : str ):
155+ async def _link_to (self , token : str ) -> Self :
148156 """Link this shared state to a token.
149157
150158 After linking, subsequent access to this shared state will affect the
@@ -155,15 +163,16 @@ async def _link_to(self, token: str):
155163 token: The token to link to.
156164
157165 Returns:
158- The events to rehydrate the state after linking (these should be returned/yielded).
159- """
160- from reflex .istate .manager import get_state_manager
166+ The newly linked state.
161167
168+ Raises:
169+ ReflexRuntimeError: If linking fails or token is invalid.
170+ """
162171 if not token :
163172 msg = "Cannot link shared state to empty token."
164173 raise ReflexRuntimeError (msg )
165174 if self ._linked_to == token :
166- return None # already linked to this token
175+ return self # already linked to this token
167176 if self ._linked_to and self ._linked_to != token :
168177 # Disassociate from previous linked token since unlink will not be called.
169178 self ._linked_from .discard (self .router .session .client_token )
@@ -174,21 +183,10 @@ async def _link_to(self, token: str):
174183
175184 # Associate substate with the given link token.
176185 state_name = self .get_full_name ()
186+ if self ._reflex_internal_links is None :
187+ self ._reflex_internal_links = {}
177188 self ._reflex_internal_links [state_name ] = token
178-
179- # Get the newly linked state and update pointers/delta for subsequent events.
180- async with get_state_manager ().modify_state (
181- _substate_key (token , type (self ))
182- ) as linked_root_state :
183- linked_state = await linked_root_state .get_state (type (self ))
184- linked_state ._linked_from .add (self .router .session .client_token )
185- linked_state ._linked_to = token
186- async with _patch_state (
187- original_state = self ,
188- linked_state = linked_state ,
189- full_delta = True ,
190- ):
191- return self ._rehydrate ()
189+ return await self ._internal_patch_linked_state (token , full_delta = True )
192190
193191 async def _unlink (self ):
194192 """Unlink this shared state from its linked token.
@@ -199,7 +197,10 @@ async def _unlink(self):
199197 from reflex .istate .manager import get_state_manager
200198
201199 state_name = self .get_full_name ()
202- if state_name not in self ._reflex_internal_links :
200+ if (
201+ not self ._reflex_internal_links
202+ or state_name not in self ._reflex_internal_links
203+ ):
203204 msg = f"State { state_name } is not linked and cannot be unlinked."
204205 raise ReflexRuntimeError (msg )
205206
@@ -219,6 +220,68 @@ async def _unlink(self):
219220 ):
220221 return self ._rehydrate ()
221222
223+ async def _internal_patch_linked_state (
224+ self , token : str , full_delta : bool = False
225+ ) -> Self :
226+ """Load and replace this state with the linked state for a given token.
227+
228+ Must be called inside a `_modify_linked_states` context, to ensure locks are
229+ released after the event is done processing.
230+
231+ Args:
232+ token: The token of the linked state.
233+ full_delta: If True, mark all Vars in linked_state dirty and resolve
234+ delta to update cached computed vars
235+
236+ Returns:
237+ The state that was linked into the tree.
238+ """
239+ from reflex .istate .manager import get_state_manager
240+
241+ if self ._exit_stack is None :
242+ msg = "Cannot link shared state outside of _modify_linked_states context."
243+ raise ReflexRuntimeError (msg )
244+
245+ # Get the newly linked state and update pointers/delta for subsequent events.
246+ if token not in self ._held_locks :
247+ linked_root_state = await self ._exit_stack .enter_async_context (
248+ get_state_manager ().modify_state (_substate_key (token , type (self )))
249+ )
250+ self ._held_locks .setdefault (token , {})
251+ else :
252+ linked_root_state = await get_state_manager ().get_state (
253+ _substate_key (token , type (self ))
254+ )
255+ linked_state = await linked_root_state .get_state (type (self ))
256+ # Avoid unnecessary dirtiness of shared state when there are no changes.
257+ if type (self ) not in self ._held_locks [token ]:
258+ self ._held_locks [token ][type (self )] = linked_state
259+ if self .router .session .client_token not in linked_state ._linked_from :
260+ linked_state ._linked_from .add (self .router .session .client_token )
261+ if linked_state ._linked_to != token :
262+ linked_state ._linked_to = token
263+ await self ._exit_stack .enter_async_context (
264+ _patch_state (
265+ original_state = self ,
266+ linked_state = linked_state ,
267+ full_delta = full_delta ,
268+ )
269+ )
270+ return linked_state
271+
272+ def _held_locks_linked_states (self ) -> list ["SharedState" ]:
273+ """Get all linked states currently held by this state.
274+
275+ Returns:
276+ The list of linked states currently held.
277+ """
278+ return [
279+ linked_state
280+ for linked_state_cls_to_instance in self ._held_locks .values ()
281+ for linked_state in linked_state_cls_to_instance .values ()
282+ if isinstance (linked_state , SharedState )
283+ ]
284+
222285 @contextlib .asynccontextmanager
223286 async def _modify_linked_states (
224287 self , previous_dirty_vars : dict [str , set [str ]] | None = None
@@ -236,67 +299,54 @@ async def _modify_linked_states(
236299 Yields:
237300 None.
238301 """
239- from reflex .istate .manager import get_state_manager
240-
241- exit_stack = contextlib .AsyncExitStack ()
242- held_locks : set [str ] = set ()
243- linked_states : list [BaseState ] = []
302+ if self ._exit_stack is not None :
303+ msg = "Cannot nest _modify_linked_states contexts."
304+ raise ReflexRuntimeError (msg )
305+ if self ._reflex_internal_links is None :
306+ msg = "No linked states to modify."
307+ raise ReflexRuntimeError (msg )
308+ self ._exit_stack = contextlib .AsyncExitStack ()
309+ self ._held_locks = {}
244310 current_dirty_vars : dict [str , set [str ]] = {}
245311 affected_tokens : set [str ] = set ()
246312 # Go through all linked states and patch them in if they are present in the tree
247313 for linked_state_name , linked_token in self ._reflex_internal_links .items ():
248- linked_state_cls = self .get_root_state ().get_class_substate (
249- linked_state_name
314+ linked_state_cls : type [SharedState ] = (
315+ self .get_root_state ().get_class_substate ( # pyright: ignore[reportAssignmentType]
316+ linked_state_name
317+ )
250318 )
251319 # TODO: Avoid always fetched linked states, it should be based on
252320 # whether the state is accessed, however then `get_state` would need
253321 # to know how to fetch in a linked state.
254322 original_state = await self .get_state (linked_state_cls )
255- if linked_token not in held_locks :
256- linked_root_state = await exit_stack .enter_async_context (
257- get_state_manager ().modify_state (
258- _substate_key (linked_token , linked_state_cls )
259- )
260- )
261- held_locks .add (linked_token )
262- else :
263- linked_root_state = await get_state_manager ().get_state (
264- _substate_key (linked_token , linked_state_cls )
265- )
266- linked_state = await linked_root_state .get_state (linked_state_cls )
267- linked_states .append (linked_state )
268- linked_state ._linked_to = linked_token
269- linked_state ._linked_from .add (self .router .session .client_token )
270- await exit_stack .enter_async_context (
271- _patch_state (original_state , linked_state )
323+ linked_state = await original_state ._internal_patch_linked_state (
324+ linked_token
272325 )
273326 if (
274327 previous_dirty_vars
275328 and (dv := previous_dirty_vars .get (linked_state_name )) is not None
276329 ):
277330 linked_state .dirty_vars .update (dv )
278331 linked_state ._mark_dirty ()
279- async with exit_stack :
332+ async with self . _exit_stack :
280333 yield None
281334 # Collect dirty vars and other affected clients that need to be updated.
282- for linked_state in linked_states :
283- if (
284- linked_state_previous_dirty_vars := getattr (
285- linked_state , "_previous_dirty_vars" , None
286- )
287- ) is not None :
335+ for linked_state in self ._held_locks_linked_states ():
336+ if linked_state ._previous_dirty_vars is not None :
288337 current_dirty_vars [linked_state .get_full_name ()] = set (
289- linked_state_previous_dirty_vars
338+ linked_state . _previous_dirty_vars
290339 )
291340 if (
292341 linked_state ._get_was_touched ()
293- or linked_state_previous_dirty_vars is not None
342+ or linked_state . _previous_dirty_vars is not None
294343 ):
295344 affected_tokens .update (
296345 token
297346 for token in linked_state ._linked_from
298347 if token != self .router .session .client_token
299348 )
349+ self ._exit_stack = None
300350
301351 # Only propagate dirty vars when we are not already propagating from another state.
302352 if previous_dirty_vars is None :
@@ -324,3 +374,6 @@ def __init_subclass__(cls, **kwargs):
324374 kwargs ["mixin" ] = False
325375 cls ._mixin = False
326376 super ().__init_subclass__ (** kwargs )
377+ root_state = cls .get_root_state ()
378+ if root_state .backend_vars ["_reflex_internal_links" ] is None :
379+ root_state .backend_vars ["_reflex_internal_links" ] = {}
0 commit comments