Skip to content

Commit 3d40cac

Browse files
committed
ENG-8258: API for linking and sharing states
It works by defining a substate of SharedState and then calling self._link_to(target_token) from some event handler. from that point on, whenever that user's state is loaded, the StateManager will patch in the linked shared states. whenever a linked state is modified, we explicitly load all of the other linked tokens, patch in the modified states, and send a delta to those clients You can call ._unlink to remove the link association, which causes the substate to be subsequently loaded from the client_token's tree as a private state It is intended to work transparently with computed vars, background events, and frontend rendering.
1 parent fad67cf commit 3d40cac

File tree

6 files changed

+327
-6
lines changed

6 files changed

+327
-6
lines changed

pyi_hashes.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"reflex/__init__.pyi": "b304ed6f7a2fa028a194cad81bd83112",
2+
"reflex/__init__.pyi": "0a3ae880e256b9fd3b960e12a2cb51a7",
33
"reflex/components/__init__.pyi": "ac05995852baa81062ba3d18fbc489fb",
44
"reflex/components/base/__init__.pyi": "16e47bf19e0d62835a605baa3d039c5a",
55
"reflex/components/base/app_wrap.pyi": "22e94feaa9fe675bcae51c412f5b67f1",

reflex/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@
336336
"State",
337337
"dynamic",
338338
],
339+
"istate.shared": ["SharedState"],
339340
"istate.wrappers": ["get_state"],
340341
"style": ["Style", "toggle_color_mode"],
341342
"utils.imports": ["ImportDict", "ImportVar"],

reflex/app.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1562,13 +1562,17 @@ def all_routes(_request: Request) -> Response:
15621562

15631563
@contextlib.asynccontextmanager
15641564
async def modify_state(
1565-
self, token: str, background: bool = False
1565+
self,
1566+
token: str,
1567+
background: bool = False,
1568+
previous_dirty_vars: set[str] | None = None,
15661569
) -> AsyncIterator[BaseState]:
15671570
"""Modify the state out of band.
15681571
15691572
Args:
15701573
token: The token to modify the state for.
15711574
background: Whether the modification is happening in a background task.
1575+
previous_dirty_vars: Vars that are considered dirty from a previous operation.
15721576
15731577
Yields:
15741578
The state to modify.
@@ -1581,7 +1585,9 @@ async def modify_state(
15811585
raise RuntimeError(msg)
15821586

15831587
# Get exclusive access to the state.
1584-
async with self.state_manager.modify_state(token) as state:
1588+
async with self.state_manager.modify_state_with_links(
1589+
token, previous_dirty_vars=previous_dirty_vars
1590+
) as state:
15851591
# No other event handler can modify the state while in this context.
15861592
yield state
15871593
delta = await state._get_resolved_delta()
@@ -1769,7 +1775,7 @@ async def process(
17691775
constants.RouteVar.CLIENT_IP: client_ip,
17701776
})
17711777
# Get the state for the session exclusively.
1772-
async with app.state_manager.modify_state(
1778+
async with app.state_manager.modify_state_with_links(
17731779
event.substate_token, event=event
17741780
) as state:
17751781
# When this is a brand new instance of the state, signal the
@@ -2003,7 +2009,9 @@ async def _ndjson_updates():
20032009
Each state update as JSON followed by a new line.
20042010
"""
20052011
# Process the event.
2006-
async with app.state_manager.modify_state(event.substate_token) as state:
2012+
async with app.state_manager.modify_state_with_links(
2013+
event.substate_token
2014+
) as state:
20072015
async for update in state._process(event):
20082016
# Postprocess the event.
20092017
update = await app._postprocess(state, event, update)

reflex/istate/manager/__init__.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,36 @@ async def modify_state(
114114
"""
115115
yield self.state()
116116

117+
@contextlib.asynccontextmanager
118+
async def modify_state_with_links(
119+
self,
120+
token: str,
121+
previous_dirty_vars: set[str] | None = None,
122+
**context: Unpack[StateModificationContext],
123+
) -> AsyncIterator[BaseState]:
124+
"""Modify the state for a token, including linked substates, while holding exclusive lock.
125+
126+
Args:
127+
token: The token to modify the state for.
128+
previous_dirty_vars: The previously dirty vars for linked states.
129+
context: The state modification context.
130+
131+
Yields:
132+
The state for the token with linked states patched in.
133+
"""
134+
from reflex.istate.shared import SharedStateBaseInternal
135+
136+
shared_state_name = SharedStateBaseInternal.get_name()
137+
138+
async with self.modify_state(token, **context) as root_state:
139+
if shared_state_name in root_state.substates:
140+
async with root_state.substates[
141+
shared_state_name
142+
]._modify_linked_states(previous_dirty_vars=previous_dirty_vars) as _:
143+
yield root_state
144+
else:
145+
yield root_state
146+
117147
async def close(self): # noqa: B027
118148
"""Close the state manager."""
119149

reflex/istate/shared.py

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
"""Base classes for shared / linked states."""
2+
3+
import contextlib
4+
from collections.abc import AsyncIterator
5+
6+
from reflex.event import Event, get_hydrate_event
7+
from reflex.state import BaseState, State, _override_base_method, _substate_key
8+
from reflex.utils.exceptions import ReflexRuntimeError
9+
10+
11+
class SharedStateBaseInternal(State):
12+
"""The private base state for all shared states."""
13+
14+
# Maps the state full_name to an arbitrary token it is linked to.
15+
_links: dict[str, str]
16+
# While _modify_linked_states is active, this holds the original substates for the client's tree.
17+
_original_substates: dict[str, tuple[BaseState, BaseState | None]]
18+
19+
@classmethod
20+
def _init_var_dependency_dicts(cls):
21+
super()._init_var_dependency_dicts()
22+
if (
23+
"_links" in cls.inherited_backend_vars
24+
or (parent_state_cls := cls.get_parent_state()) is None
25+
):
26+
return
27+
# Mark the internal state as always dirty so the state manager
28+
# automatically fetches this state containing the _links.
29+
parent_state_cls._always_dirty_substates.add(cls.get_name())
30+
31+
def __getstate__(self):
32+
"""Override redis serialization to remove temporary fields.
33+
34+
Returns:
35+
The state dictionary without temporary fields.
36+
"""
37+
s = super().__getstate__()
38+
# Don't want to persist the cached substates
39+
s.pop("_original_substates", None)
40+
s.pop("_previous_dirty_vars", None)
41+
return s
42+
43+
@_override_base_method
44+
def _clean(self):
45+
"""Override BaseState._clean to track the last set of dirty vars.
46+
47+
This is necessary for applying dirty vars from one event to other linked states.
48+
"""
49+
if hasattr(self, "_previous_dirty_vars"):
50+
self._previous_dirty_vars.clear()
51+
self._previous_dirty_vars.update(self.dirty_vars)
52+
super()._clean()
53+
54+
@_override_base_method
55+
def _mark_dirty(self):
56+
"""Override BaseState._mark_dirty to avoid marking certain vars as dirty.
57+
58+
Since these internal fields are not persisted to redis, they shouldn't cause the
59+
state to be considered dirty either.
60+
"""
61+
self.dirty_vars.discard("_original_substates")
62+
self.dirty_vars.discard("_previously_dirty_substates")
63+
if self.dirty_vars:
64+
super()._mark_dirty()
65+
66+
def _rehydrate(self):
67+
"""Get the events to rehydrate the state.
68+
69+
Returns:
70+
The events to rehydrate the state (these should be returned/yielded).
71+
"""
72+
return [
73+
Event(
74+
token=self.router.session.client_token,
75+
name=get_hydrate_event(self._get_root_state()),
76+
),
77+
State.set_is_hydrated(True),
78+
]
79+
80+
async def _link_to(self, token: str):
81+
"""Link this shared state to a token.
82+
83+
After linking, subsequent access to this shared state will affect the
84+
linked token's state, and cause changes to be propagated to all other
85+
clients linked to that token.
86+
87+
Args:
88+
token: The token to link to.
89+
90+
Returns:
91+
The events to rehydrate the state after linking (these should be returned/yielded).
92+
"""
93+
# TODO: Change StateManager to accept token + class instead of combining them in a string.
94+
if "_" in token:
95+
msg = f"Invalid token {token} for linking state {self.get_full_name()}, cannot use underscore (_) in the token name."
96+
raise ReflexRuntimeError(msg)
97+
state_name = self.get_full_name()
98+
self._links[state_name] = token
99+
async with self._modify_linked_states() as _:
100+
linked_state = await self.get_state(type(self))
101+
linked_state._linked_from.add(self.router.session.client_token)
102+
linked_state._linked_to = token
103+
linked_state.dirty_vars.update(self.base_vars)
104+
linked_state.dirty_vars.update(self.backend_vars)
105+
linked_state.dirty_vars.update(self.computed_vars)
106+
linked_state._mark_dirty()
107+
# Apply the updates into the existing state tree, then rehydrate.
108+
root_state = self._get_root_state()
109+
await root_state._get_resolved_delta()
110+
root_state._clean()
111+
return self._rehydrate()
112+
113+
async def _unlink(self):
114+
"""Unlink this shared state from its linked token.
115+
116+
Returns:
117+
The events to rehydrate the state after unlinking (these should be returned/yielded
118+
"""
119+
state_name = self.get_full_name()
120+
if state_name not in self._links:
121+
msg = f"State {state_name} is not linked and cannot be unlinked."
122+
raise ReflexRuntimeError(msg)
123+
self._links.pop(state_name)
124+
self._linked_from.discard(self.router.session.client_token)
125+
# Rehydrate after unlinking to restore original values.
126+
return self._rehydrate()
127+
128+
async def _restore_original_substates(self, *_exc_info) -> None:
129+
"""Restore the original substates that were linked."""
130+
root_state = self._get_root_state()
131+
for linked_state_name, (
132+
original_state,
133+
linked_parent_state,
134+
) in self._original_substates.items():
135+
linked_state_cls = root_state.get_class_substate(linked_state_name)
136+
linked_state = await root_state.get_state(linked_state_cls)
137+
if (parent_state := linked_state.parent_state) is not None:
138+
parent_state.substates[original_state.get_name()] = original_state
139+
linked_state.parent_state = linked_parent_state
140+
self._original_substates = {}
141+
142+
@contextlib.asynccontextmanager
143+
async def _modify_linked_states(
144+
self, previous_dirty_vars: dict[str, set[str]] | None = None
145+
) -> AsyncIterator[None]:
146+
"""Take lock, fetch all linked states, and patch them into the current state tree.
147+
148+
If previous_dirty_vars is NOT provided, then any dirty vars after
149+
exiting the context will be applied to all other clients linked to this
150+
state's linked token.
151+
152+
Args:
153+
previous_dirty_vars: When apply linked state changes to other
154+
tokens, provide mapping of state full_name to set of dirty vars.
155+
156+
Yields:
157+
None.
158+
"""
159+
from reflex.istate.manager import get_state_manager
160+
161+
exit_stack = contextlib.AsyncExitStack()
162+
held_locks: set[str] = set()
163+
linked_states: list[BaseState] = []
164+
current_dirty_vars: dict[str, set[str]] = {}
165+
affected_tokens: set[str] = set()
166+
# Go through all linked states and patch them in if they are present in the tree
167+
for linked_state_name, linked_token in self._links.items():
168+
linked_state_cls = self.get_root_state().get_class_substate(
169+
linked_state_name
170+
)
171+
# TODO: Avoid always fetched linked states, it should be based on
172+
# whether the state is accessed, however then `get_state` would need
173+
# to know how to fetch in a linked state.
174+
original_state = await self.get_state(linked_state_cls)
175+
if linked_token not in held_locks:
176+
linked_root_state = await exit_stack.enter_async_context(
177+
get_state_manager().modify_state(
178+
_substate_key(linked_token, linked_state_cls)
179+
)
180+
)
181+
held_locks.add(linked_token)
182+
else:
183+
linked_root_state = await get_state_manager().get_state(
184+
_substate_key(linked_token, linked_state_cls)
185+
)
186+
linked_state = await linked_root_state.get_state(linked_state_cls)
187+
self._original_substates[linked_state_name] = (
188+
original_state,
189+
linked_state.parent_state,
190+
)
191+
if (parent_state := original_state.parent_state) is not None:
192+
parent_state.substates[original_state.get_name()] = linked_state
193+
linked_state.parent_state = parent_state
194+
linked_states.append(linked_state)
195+
if (
196+
previous_dirty_vars
197+
and (dv := previous_dirty_vars.get(linked_state_name)) is not None
198+
):
199+
linked_state.dirty_vars.update(dv)
200+
linked_state._mark_dirty()
201+
# Make sure to restore the non-linked substates after exiting the context.
202+
if self._original_substates:
203+
exit_stack.push_async_exit(self._restore_original_substates)
204+
async with exit_stack:
205+
yield None
206+
# Collect dirty vars and other affected clients that need to be updated.
207+
for linked_state in linked_states:
208+
if hasattr(linked_state, "_previous_dirty_vars"):
209+
current_dirty_vars[linked_state.get_full_name()] = set(
210+
linked_state._previous_dirty_vars
211+
)
212+
if linked_state._get_was_touched():
213+
affected_tokens.update(
214+
token
215+
for token in linked_state._linked_from
216+
if token != self.router.session.client_token
217+
)
218+
219+
# Only propagate dirty vars when we are not already propagating from another state.
220+
if previous_dirty_vars is None:
221+
from reflex.utils.prerequisites import get_app
222+
223+
app = get_app().app
224+
225+
for affected_token in affected_tokens:
226+
# Don't send updates for disconnected clients.
227+
if (
228+
affected_token
229+
not in app.event_namespace._token_manager.token_to_socket
230+
):
231+
continue
232+
async with app.modify_state(
233+
_substate_key(affected_token, type(self)),
234+
previous_dirty_vars=current_dirty_vars,
235+
):
236+
pass
237+
238+
239+
class SharedState(SharedStateBaseInternal, mixin=True):
240+
"""Mixin for defining new shared states."""
241+
242+
_linked_from: set[str]
243+
_linked_to: str
244+
_previous_dirty_vars: set[str]
245+
246+
@classmethod
247+
def __init_subclass__(cls, **kwargs):
248+
"""Initialize subclass and set up shared state fields.
249+
250+
Args:
251+
**kwargs: The kwargs to pass to the init_subclass method.
252+
"""
253+
kwargs["mixin"] = False
254+
cls._mixin = False
255+
super().__init_subclass__(**kwargs)

0 commit comments

Comments
 (0)