Skip to content

Commit 1b6b2c2

Browse files
authored
Only fetch required linked states (#6049)
Optimize the linked state loading path in redis mode by only linking states which are already cached and augmenting `get_state` to link in dynamically fetched states. Adds SharedStateBaseInternal as an _always_dirty_substates from the root state. This causes it to be always shallow-cached when loading any states from redis. Therefore it will be available for use with `_modify_linked_states` without explicitly getting it (and by consequence ALL SharedState instances). Only required SharedState classes will be fetched now.
1 parent 1e71398 commit 1b6b2c2

File tree

2 files changed

+40
-4
lines changed

2 files changed

+40
-4
lines changed

reflex/istate/shared.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,10 +338,11 @@ async def _modify_linked_states(
338338
linked_state_name
339339
)
340340
)
341-
# TODO: Avoid always fetched linked states, it should be based on
342-
# whether the state is accessed, however then `get_state` would need
343-
# to know how to fetch in a linked state.
344-
original_state = await self.get_state(linked_state_cls)
341+
try:
342+
original_state = self._get_state_from_cache(linked_state_cls)
343+
except ValueError:
344+
# This state wasn't required for processing the event.
345+
continue
345346
linked_state = await original_state._internal_patch_linked_state(
346347
linked_token
347348
)
@@ -400,3 +401,9 @@ def __init_subclass__(cls, **kwargs):
400401
root_state = cls.get_root_state()
401402
if root_state.backend_vars["_reflex_internal_links"] is None:
402403
root_state.backend_vars["_reflex_internal_links"] = {}
404+
if root_state is State:
405+
# Always fetch SharedStateBaseInternal to access
406+
# `_modify_linked_states` without having to use `.get_state()` which
407+
# pulls in all linked states and substates which may not actually be
408+
# accessed for this event.
409+
root_state._always_dirty_substates.add(SharedStateBaseInternal.get_name())

reflex/state.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2467,6 +2467,35 @@ class State(BaseState):
24672467
# Maps the state full_name to an arbitrary token it is linked to for shared state.
24682468
_reflex_internal_links: dict[str, str] | None = None
24692469

2470+
@_override_base_method
2471+
async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE:
2472+
"""Get a state instance from redis with linking support.
2473+
2474+
Args:
2475+
state_cls: The class of the state.
2476+
2477+
Returns:
2478+
The instance of state_cls associated with this state's client_token.
2479+
"""
2480+
state_instance = await super()._get_state_from_redis(state_cls)
2481+
if (
2482+
self._reflex_internal_links
2483+
and (
2484+
linked_token := self._reflex_internal_links.get(
2485+
state_cls.get_full_name()
2486+
)
2487+
)
2488+
is not None
2489+
and (
2490+
internal_patch_linked_state := getattr(
2491+
state_instance, "_internal_patch_linked_state", None
2492+
)
2493+
)
2494+
is not None
2495+
):
2496+
return await internal_patch_linked_state(linked_token)
2497+
return state_instance
2498+
24702499
@event
24712500
def set_is_hydrated(self, value: bool) -> None:
24722501
"""Set the hydrated state.

0 commit comments

Comments
 (0)