@@ -56,7 +56,7 @@ class SharedStateBaseInternal(State):
5656 """The private base state for all shared states."""
5757
5858 # While _modify_linked_states is active, this holds the original substates for the client's tree.
59- _original_substates : dict [str , tuple [BaseState , BaseState | None ]]
59+ _original_substates : dict [str , tuple [BaseState , BaseState | None ]] = {}
6060
6161 def __getstate__ (self ):
6262 """Override redis serialization to remove temporary fields.
@@ -125,23 +125,50 @@ async def _link_to(self, token: str):
125125 Returns:
126126 The events to rehydrate the state after linking (these should be returned/yielded).
127127 """
128+ from reflex .istate .manager import get_state_manager
129+
130+ if not token :
131+ msg = "Cannot link shared state to empty token."
132+ raise ReflexRuntimeError (msg )
133+ if self ._linked_to == token :
134+ return None # already linked to this token
135+ if self ._linked_to and self ._linked_to != token :
136+ # Disassociate from previous linked token since unlink will not be called.
137+ self ._linked_from .discard (self .router .session .client_token )
128138 # TODO: Change StateManager to accept token + class instead of combining them in a string.
129139 if "_" in token :
130140 msg = f"Invalid token { token } for linking state { self .get_full_name ()} , cannot use underscore (_) in the token name."
131141 raise ReflexRuntimeError (msg )
142+
143+ # Associate substate with the given link token.
132144 state_name = self .get_full_name ()
133145 self ._reflex_internal_links [state_name ] = token
134- async with self ._modify_linked_states () as _ :
135- linked_state = await self .get_state (type (self ))
146+
147+ # Get the newly linked state and update pointers/delta for subsequent events.
148+ original_substate = self
149+ async with get_state_manager ().modify_state (
150+ _substate_key (token , type (self ))
151+ ) as linked_root_state :
152+ linked_state = await linked_root_state .get_state (type (self ))
153+ linked_parent_state = linked_state .parent_state
136154 linked_state ._linked_from .add (self .router .session .client_token )
137155 linked_state ._linked_to = token
138- linked_state .dirty_vars .update (self .base_vars )
139- linked_state .dirty_vars .update (self .backend_vars )
140- linked_state .dirty_vars .update (self .computed_vars )
141- linked_state ._mark_dirty ()
142- # Apply the updates into the existing state tree, then rehydrate.
143- root_state = self ._get_root_state ()
144- await root_state ._get_resolved_delta ()
156+ try :
157+ if (parent_state := self .parent_state ) is not None :
158+ parent_state .substates [self .get_name ()] = linked_state
159+ linked_state .parent_state = parent_state
160+ linked_state .dirty_vars .update (self .base_vars )
161+ linked_state .dirty_vars .update (self .backend_vars )
162+ linked_state .dirty_vars .update (self .computed_vars )
163+ linked_state ._mark_dirty ()
164+ # Apply the updates into the existing state tree, then rehydrate.
165+ root_state = self ._get_root_state ()
166+ await root_state ._get_resolved_delta ()
167+ finally :
168+ # Put the tree back together for now, since we're about to drop the lock.
169+ if self .parent_state is not None :
170+ self .parent_state .substates [self .get_name ()] = original_substate
171+ linked_state .parent_state = linked_parent_state
145172 return self ._rehydrate ()
146173
147174 async def _unlink (self ):
@@ -218,6 +245,8 @@ async def _modify_linked_states(
218245 _substate_key (linked_token , linked_state_cls )
219246 )
220247 linked_state = await linked_root_state .get_state (linked_state_cls )
248+ linked_state ._linked_to = linked_token
249+ linked_state ._linked_from .add (self .router .session .client_token )
221250 self ._original_substates [linked_state_name ] = (
222251 original_state ,
223252 linked_state .parent_state ,
0 commit comments