88from reflex .constants import ROUTER_DATA
99from reflex .event import Event , get_hydrate_event
1010from reflex .state import BaseState , State , _override_base_method , _substate_key
11+ from reflex .utils import console
1112from reflex .utils .exceptions import ReflexRuntimeError
1213
1314UPDATE_OTHER_CLIENT_TASKS : set [asyncio .Task ] = set ()
1415LINKED_STATE = TypeVar ("LINKED_STATE" , bound = "SharedStateBaseInternal" )
1516
1617
18+ def _log_update_client_errors (task : asyncio .Task ):
19+ """Log errors from updating other clients.
20+
21+ Args:
22+ task: The asyncio task to check for errors.
23+ """
24+ try :
25+ task .result ()
26+ except Exception as e :
27+ console .warn (f"Error updating linked client: { e } " )
28+ finally :
29+ UPDATE_OTHER_CLIENT_TASKS .discard (task )
30+
31+
1732def _do_update_other_tokens (
1833 affected_tokens : set [str ],
1934 previous_dirty_vars : dict [str , set [str ]],
@@ -47,10 +62,10 @@ async def _update_client(token: str):
4762 # Don't send updates for disconnected clients.
4863 if affected_token not in app .event_namespace ._token_manager .token_to_socket :
4964 continue
50- # TODO: remove disconnected client's after some time.
65+ # TODO: remove disconnected clients after some time.
5166 t = asyncio .create_task (_update_client (affected_token ))
5267 UPDATE_OTHER_CLIENT_TASKS .add (t )
53- t .add_done_callback (UPDATE_OTHER_CLIENT_TASKS . discard )
68+ t .add_done_callback (_log_update_client_errors )
5469 tasks .append (t )
5570 return tasks
5671
@@ -99,7 +114,7 @@ class SharedStateBaseInternal(State):
99114 """The private base state for all shared states."""
100115
101116 _exit_stack : contextlib .AsyncExitStack | None = None
102- _held_locks : dict [str , dict [type [BaseState ], BaseState ]] = {}
117+ _held_locks : dict [str , dict [type [BaseState ], BaseState ]] | None = None
103118
104119 def __getstate__ (self ):
105120 """Override redis serialization to remove temporary fields.
@@ -164,7 +179,7 @@ async def _link_to(self, token: str) -> Self:
164179 clients linked to that token.
165180
166181 Args:
167- token: The token to link to.
182+ token: The token to link to (Cannot contain underscore characters) .
168183
169184 Returns:
170185 The newly linked state.
@@ -196,7 +211,7 @@ async def _unlink(self):
196211 """Unlink this shared state from its linked token.
197212
198213 Returns:
199- The events to rehydrate the state after unlinking (these should be returned/yielded
214+ The events to rehydrate the state after unlinking (these should be returned/yielded).
200215 """
201216 from reflex .istate .manager import get_state_manager
202217
@@ -242,7 +257,7 @@ async def _internal_patch_linked_state(
242257 """
243258 from reflex .istate .manager import get_state_manager
244259
245- if self ._exit_stack is None :
260+ if self ._exit_stack is None or self . _held_locks is None :
246261 msg = "Cannot link shared state outside of _modify_linked_states context."
247262 raise ReflexRuntimeError (msg )
248263
@@ -279,6 +294,8 @@ def _held_locks_linked_states(self) -> list["SharedState"]:
279294 Returns:
280295 The list of linked states currently held.
281296 """
297+ if self ._held_locks is None :
298+ return []
282299 return [
283300 linked_state
284301 for linked_state_cls_to_instance in self ._held_locks .values ()
@@ -313,44 +330,46 @@ async def _modify_linked_states(
313330 self ._held_locks = {}
314331 current_dirty_vars : dict [str , set [str ]] = {}
315332 affected_tokens : set [str ] = set ()
316- # Go through all linked states and patch them in if they are present in the tree
317- for linked_state_name , linked_token in self ._reflex_internal_links .items ():
318- linked_state_cls : type [SharedState ] = (
319- self .get_root_state ().get_class_substate ( # pyright: ignore[reportAssignmentType]
320- linked_state_name
321- )
322- )
323- # TODO: Avoid always fetched linked states, it should be based on
324- # whether the state is accessed, however then `get_state` would need
325- # to know how to fetch in a linked state.
326- original_state = await self .get_state (linked_state_cls )
327- linked_state = await original_state ._internal_patch_linked_state (
328- linked_token
329- )
330- if (
331- previous_dirty_vars
332- and (dv := previous_dirty_vars .get (linked_state_name )) is not None
333- ):
334- linked_state .dirty_vars .update (dv )
335- linked_state ._mark_dirty ()
336- async with self ._exit_stack :
337- yield None
338- # Collect dirty vars and other affected clients that need to be updated.
339- for linked_state in self ._held_locks_linked_states ():
340- if linked_state ._previous_dirty_vars is not None :
341- current_dirty_vars [linked_state .get_full_name ()] = set (
342- linked_state ._previous_dirty_vars
333+ try :
334+ # Go through all linked states and patch them in if they are present in the tree
335+ for linked_state_name , linked_token in self ._reflex_internal_links .items ():
336+ linked_state_cls : type [SharedState ] = (
337+ self .get_root_state ().get_class_substate ( # pyright: ignore[reportAssignmentType]
338+ linked_state_name
343339 )
340+ )
341+ # TODO: Avoid always fetched linked states, it should be based on
342+ # whether the state is accessed, however then `get_state` would need
343+ # to know how to fetch in a linked state.
344+ original_state = await self .get_state (linked_state_cls )
345+ linked_state = await original_state ._internal_patch_linked_state (
346+ linked_token
347+ )
344348 if (
345- linked_state . _get_was_touched ()
346- or linked_state . _previous_dirty_vars is not None
349+ previous_dirty_vars
350+ and ( dv := previous_dirty_vars . get ( linked_state_name )) is not None
347351 ):
348- affected_tokens .update (
349- token
350- for token in linked_state ._linked_from
351- if token != self .router .session .client_token
352- )
353- self ._exit_stack = None
352+ linked_state .dirty_vars .update (dv )
353+ linked_state ._mark_dirty ()
354+ async with self ._exit_stack :
355+ yield None
356+ # Collect dirty vars and other affected clients that need to be updated.
357+ for linked_state in self ._held_locks_linked_states ():
358+ if linked_state ._previous_dirty_vars is not None :
359+ current_dirty_vars [linked_state .get_full_name ()] = set (
360+ linked_state ._previous_dirty_vars
361+ )
362+ if (
363+ linked_state ._get_was_touched ()
364+ or linked_state ._previous_dirty_vars is not None
365+ ):
366+ affected_tokens .update (
367+ token
368+ for token in linked_state ._linked_from
369+ if token != self .router .session .client_token
370+ )
371+ finally :
372+ self ._exit_stack = None
354373
355374 # Only propagate dirty vars when we are not already propagating from another state.
356375 if previous_dirty_vars is None :
@@ -364,9 +383,9 @@ async def _modify_linked_states(
364383class SharedState (SharedStateBaseInternal , mixin = True ):
365384 """Mixin for defining new shared states."""
366385
367- _linked_from : set [str ]
368- _linked_to : str
369- _previous_dirty_vars : set [str ]
386+ _linked_from : set [str ] = set ()
387+ _linked_to : str = ""
388+ _previous_dirty_vars : set [str ] = set ()
370389
371390 @classmethod
372391 def __init_subclass__ (cls , ** kwargs ):
0 commit comments