Skip to content

Commit faf620b

Browse files
committed
_link_to returns the newly linked state
_modify_linked_states context can now release the locks of newly linked states and send updates for changes in newly linked states. rehydrating after linking is no longer necessary.
1 parent 6984750 commit faf620b

File tree

4 files changed

+144
-60
lines changed

4 files changed

+144
-60
lines changed

reflex/istate/manager/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ async def modify_state_with_links(
132132
The state for the token with linked states patched in.
133133
"""
134134
async with self.modify_state(token, **context) as root_state:
135-
if getattr(root_state, "_reflex_internal_links", None):
135+
if getattr(root_state, "_reflex_internal_links", None) is not None:
136136
from reflex.istate.shared import SharedStateBaseInternal
137137

138138
shared_state = await root_state.get_state(SharedStateBaseInternal)

reflex/istate/shared.py

Lines changed: 107 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import asyncio
44
import contextlib
55
from collections.abc import AsyncIterator
6+
from typing import Self, TypeVar
67

78
from reflex.event import Event, get_hydrate_event
89
from reflex.state import BaseState, State, _override_base_method, _substate_key
910
from reflex.utils.exceptions import ReflexRuntimeError
1011

1112
UPDATE_OTHER_CLIENT_TASKS: set[asyncio.Task] = set()
13+
LINKED_STATE = TypeVar("LINKED_STATE", bound="SharedStateBaseInternal")
1214

1315

1416
def _do_update_other_tokens(
@@ -92,15 +94,19 @@ async def _patch_state(
9294
class SharedStateBaseInternal(State):
9395
"""The private base state for all shared states."""
9496

97+
_exit_stack: contextlib.AsyncExitStack | None = None
98+
_held_locks: dict[str, dict[type[BaseState], BaseState]] = {}
99+
95100
def __getstate__(self):
96101
"""Override redis serialization to remove temporary fields.
97102
98103
Returns:
99104
The state dictionary without temporary fields.
100105
"""
101106
s = super().__getstate__()
102-
# Don't want to persist the cached substates
103107
s.pop("_previous_dirty_vars", None)
108+
s.pop("_exit_stack", None)
109+
s.pop("_held_locks", None)
104110
return s
105111

106112
@_override_base_method
@@ -124,6 +130,8 @@ def _mark_dirty(self):
124130
state to be considered dirty either.
125131
"""
126132
self.dirty_vars.discard("_previous_dirty_vars")
133+
self.dirty_vars.discard("_exit_stack")
134+
self.dirty_vars.discard("_held_locks")
127135
# Only mark dirty if there are still dirty vars, or any substate is dirty
128136
if self.dirty_vars or any(
129137
substate.dirty_vars for substate in self.substates.values()
@@ -144,7 +152,7 @@ def _rehydrate(self):
144152
State.set_is_hydrated(True),
145153
]
146154

147-
async def _link_to(self, token: str):
155+
async def _link_to(self, token: str) -> Self:
148156
"""Link this shared state to a token.
149157
150158
After linking, subsequent access to this shared state will affect the
@@ -155,15 +163,16 @@ async def _link_to(self, token: str):
155163
token: The token to link to.
156164
157165
Returns:
158-
The events to rehydrate the state after linking (these should be returned/yielded).
159-
"""
160-
from reflex.istate.manager import get_state_manager
166+
The newly linked state.
161167
168+
Raises:
169+
ReflexRuntimeError: If linking fails or token is invalid.
170+
"""
162171
if not token:
163172
msg = "Cannot link shared state to empty token."
164173
raise ReflexRuntimeError(msg)
165174
if self._linked_to == token:
166-
return None # already linked to this token
175+
return self # already linked to this token
167176
if self._linked_to and self._linked_to != token:
168177
# Disassociate from previous linked token since unlink will not be called.
169178
self._linked_from.discard(self.router.session.client_token)
@@ -174,21 +183,10 @@ async def _link_to(self, token: str):
174183

175184
# Associate substate with the given link token.
176185
state_name = self.get_full_name()
186+
if self._reflex_internal_links is None:
187+
self._reflex_internal_links = {}
177188
self._reflex_internal_links[state_name] = token
178-
179-
# Get the newly linked state and update pointers/delta for subsequent events.
180-
async with get_state_manager().modify_state(
181-
_substate_key(token, type(self))
182-
) as linked_root_state:
183-
linked_state = await linked_root_state.get_state(type(self))
184-
linked_state._linked_from.add(self.router.session.client_token)
185-
linked_state._linked_to = token
186-
async with _patch_state(
187-
original_state=self,
188-
linked_state=linked_state,
189-
full_delta=True,
190-
):
191-
return self._rehydrate()
189+
return await self._internal_patch_linked_state(token, full_delta=True)
192190

193191
async def _unlink(self):
194192
"""Unlink this shared state from its linked token.
@@ -199,7 +197,10 @@ async def _unlink(self):
199197
from reflex.istate.manager import get_state_manager
200198

201199
state_name = self.get_full_name()
202-
if state_name not in self._reflex_internal_links:
200+
if (
201+
not self._reflex_internal_links
202+
or state_name not in self._reflex_internal_links
203+
):
203204
msg = f"State {state_name} is not linked and cannot be unlinked."
204205
raise ReflexRuntimeError(msg)
205206

@@ -219,6 +220,68 @@ async def _unlink(self):
219220
):
220221
return self._rehydrate()
221222

223+
async def _internal_patch_linked_state(
224+
self, token: str, full_delta: bool = False
225+
) -> Self:
226+
"""Load and replace this state with the linked state for a given token.
227+
228+
Must be called inside a `_modify_linked_states` context, to ensure locks are
229+
released after the event is done processing.
230+
231+
Args:
232+
token: The token of the linked state.
233+
full_delta: If True, mark all Vars in linked_state dirty and resolve
234+
delta to update cached computed vars
235+
236+
Returns:
237+
The state that was linked into the tree.
238+
"""
239+
from reflex.istate.manager import get_state_manager
240+
241+
if self._exit_stack is None:
242+
msg = "Cannot link shared state outside of _modify_linked_states context."
243+
raise ReflexRuntimeError(msg)
244+
245+
# Get the newly linked state and update pointers/delta for subsequent events.
246+
if token not in self._held_locks:
247+
linked_root_state = await self._exit_stack.enter_async_context(
248+
get_state_manager().modify_state(_substate_key(token, type(self)))
249+
)
250+
self._held_locks.setdefault(token, {})
251+
else:
252+
linked_root_state = await get_state_manager().get_state(
253+
_substate_key(token, type(self))
254+
)
255+
linked_state = await linked_root_state.get_state(type(self))
256+
# Avoid unnecessary dirtiness of shared state when there are no changes.
257+
if type(self) not in self._held_locks[token]:
258+
self._held_locks[token][type(self)] = linked_state
259+
if self.router.session.client_token not in linked_state._linked_from:
260+
linked_state._linked_from.add(self.router.session.client_token)
261+
if linked_state._linked_to != token:
262+
linked_state._linked_to = token
263+
await self._exit_stack.enter_async_context(
264+
_patch_state(
265+
original_state=self,
266+
linked_state=linked_state,
267+
full_delta=full_delta,
268+
)
269+
)
270+
return linked_state
271+
272+
def _held_locks_linked_states(self) -> list["SharedState"]:
273+
"""Get all linked states currently held by this state.
274+
275+
Returns:
276+
The list of linked states currently held.
277+
"""
278+
return [
279+
linked_state
280+
for linked_state_cls_to_instance in self._held_locks.values()
281+
for linked_state in linked_state_cls_to_instance.values()
282+
if isinstance(linked_state, SharedState)
283+
]
284+
222285
@contextlib.asynccontextmanager
223286
async def _modify_linked_states(
224287
self, previous_dirty_vars: dict[str, set[str]] | None = None
@@ -236,67 +299,54 @@ async def _modify_linked_states(
236299
Yields:
237300
None.
238301
"""
239-
from reflex.istate.manager import get_state_manager
240-
241-
exit_stack = contextlib.AsyncExitStack()
242-
held_locks: set[str] = set()
243-
linked_states: list[BaseState] = []
302+
if self._exit_stack is not None:
303+
msg = "Cannot nest _modify_linked_states contexts."
304+
raise ReflexRuntimeError(msg)
305+
if self._reflex_internal_links is None:
306+
msg = "No linked states to modify."
307+
raise ReflexRuntimeError(msg)
308+
self._exit_stack = contextlib.AsyncExitStack()
309+
self._held_locks = {}
244310
current_dirty_vars: dict[str, set[str]] = {}
245311
affected_tokens: set[str] = set()
246312
# Go through all linked states and patch them in if they are present in the tree
247313
for linked_state_name, linked_token in self._reflex_internal_links.items():
248-
linked_state_cls = self.get_root_state().get_class_substate(
249-
linked_state_name
314+
linked_state_cls: type[SharedState] = (
315+
self.get_root_state().get_class_substate( # pyright: ignore[reportAssignmentType]
316+
linked_state_name
317+
)
250318
)
251319
# TODO: Avoid always fetched linked states, it should be based on
252320
# whether the state is accessed, however then `get_state` would need
253321
# to know how to fetch in a linked state.
254322
original_state = await self.get_state(linked_state_cls)
255-
if linked_token not in held_locks:
256-
linked_root_state = await exit_stack.enter_async_context(
257-
get_state_manager().modify_state(
258-
_substate_key(linked_token, linked_state_cls)
259-
)
260-
)
261-
held_locks.add(linked_token)
262-
else:
263-
linked_root_state = await get_state_manager().get_state(
264-
_substate_key(linked_token, linked_state_cls)
265-
)
266-
linked_state = await linked_root_state.get_state(linked_state_cls)
267-
linked_states.append(linked_state)
268-
linked_state._linked_to = linked_token
269-
linked_state._linked_from.add(self.router.session.client_token)
270-
await exit_stack.enter_async_context(
271-
_patch_state(original_state, linked_state)
323+
linked_state = await original_state._internal_patch_linked_state(
324+
linked_token
272325
)
273326
if (
274327
previous_dirty_vars
275328
and (dv := previous_dirty_vars.get(linked_state_name)) is not None
276329
):
277330
linked_state.dirty_vars.update(dv)
278331
linked_state._mark_dirty()
279-
async with exit_stack:
332+
async with self._exit_stack:
280333
yield None
281334
# Collect dirty vars and other affected clients that need to be updated.
282-
for linked_state in linked_states:
283-
if (
284-
linked_state_previous_dirty_vars := getattr(
285-
linked_state, "_previous_dirty_vars", None
286-
)
287-
) is not None:
335+
for linked_state in self._held_locks_linked_states():
336+
if linked_state._previous_dirty_vars is not None:
288337
current_dirty_vars[linked_state.get_full_name()] = set(
289-
linked_state_previous_dirty_vars
338+
linked_state._previous_dirty_vars
290339
)
291340
if (
292341
linked_state._get_was_touched()
293-
or linked_state_previous_dirty_vars is not None
342+
or linked_state._previous_dirty_vars is not None
294343
):
295344
affected_tokens.update(
296345
token
297346
for token in linked_state._linked_from
298347
if token != self.router.session.client_token
299348
)
349+
self._exit_stack = None
300350

301351
# Only propagate dirty vars when we are not already propagating from another state.
302352
if previous_dirty_vars is None:
@@ -324,3 +374,6 @@ def __init_subclass__(cls, **kwargs):
324374
kwargs["mixin"] = False
325375
cls._mixin = False
326376
super().__init_subclass__(**kwargs)
377+
root_state = cls.get_root_state()
378+
if root_state.backend_vars["_reflex_internal_links"] is None:
379+
root_state.backend_vars["_reflex_internal_links"] = {}

reflex/state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2465,7 +2465,7 @@ class State(BaseState):
24652465
# The hydrated bool.
24662466
is_hydrated: bool = False
24672467
# Maps the state full_name to an arbitrary token it is linked to for shared state.
2468-
_reflex_internal_links: dict[str, str] = {}
2468+
_reflex_internal_links: dict[str, str] | None = None
24692469

24702470
@event
24712471
def set_is_hydrated(self, value: bool) -> None:

tests/integration/test_linked_state.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
def LinkedStateApp():
1818
"""Test that linked state works as expected."""
19+
import uuid
1920
from typing import Any
2021

2122
import reflex as rx
@@ -36,23 +37,31 @@ def set_who(self, who: str) -> None:
3637

3738
@rx.event
3839
async def link_to(self, token: str):
39-
return await self._link_to(token)
40+
await self._link_to(token)
41+
42+
@rx.event
43+
async def link_to_and_increment(self):
44+
linked_state = await self._link_to(f"arbitrary-token-{uuid.uuid4()}")
45+
linked_state.counter += 1
4046

4147
@rx.event
4248
async def unlink(self):
4349
return await self._unlink()
4450

4551
@rx.event
4652
async def on_load_link_default(self):
47-
return await self._link_to(self.room or "default")
53+
linked_state = await self._link_to(self.room or "default")
54+
if self.room:
55+
assert linked_state._linked_to == self.room
56+
else:
57+
assert linked_state._linked_to == "default"
4858

4959
@rx.event
5060
async def handle_submit(self, form_data: dict[str, Any]):
5161
if "who" in form_data:
5262
self.set_who(form_data["who"])
5363
if "token" in form_data:
54-
return await self.link_to(form_data["token"])
55-
return None
64+
await self.link_to(form_data["token"])
5665

5766
class PrivateState(rx.State):
5867
@rx.var
@@ -126,6 +135,11 @@ def index() -> rx.Component:
126135
on_click=PrivateState.bump_counter_yield,
127136
id="yield-button",
128137
),
138+
rx.button(
139+
"Link to arbitrary token and Increment n_changes",
140+
on_click=SharedState.link_to_and_increment,
141+
id="link-increment-button",
142+
),
129143
)
130144

131145
app = rx.App()
@@ -355,3 +369,20 @@ def test_linked_state(
355369
== "Hello, world!"
356370
)
357371
assert linked_state.poll_for_content(counter_button_1, exp_not_equal="48") == "0"
372+
counter_button_1.click()
373+
assert linked_state.poll_for_content(counter_button_1, exp_not_equal="0") == "1"
374+
counter_button_1.click()
375+
assert linked_state.poll_for_content(counter_button_1, exp_not_equal="1") == "2"
376+
counter_button_1.click()
377+
assert linked_state.poll_for_content(counter_button_1, exp_not_equal="2") == "3"
378+
# Ensure other tabs are unaffected
379+
assert n_changes_2.text == "2"
380+
assert greeting_2.text == "Hello, Diana!"
381+
assert counter_button_2.text == "48"
382+
assert n_changes_3.text == "2"
383+
assert greeting_3.text == "Hello, Diana!"
384+
assert counter_button_3.text == "48"
385+
386+
# Link to a new state and increment the counter in the same event
387+
tab1.find_element(By.ID, "link-increment-button").click()
388+
assert linked_state.poll_for_content(counter_button_1, exp_not_equal="3") == "1"

0 commit comments

Comments
 (0)