Skip to content

Commit cb0ddfa

Browse files
authored
split manager (#5852)
* split manager * sure * huh * typo * a few more ig * okie
1 parent c28bdba commit cb0ddfa

File tree

13 files changed

+452
-407
lines changed

13 files changed

+452
-407
lines changed

reflex/istate/manager/__init__.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
"""State manager for managing client states."""
2+
3+
import contextlib
4+
import dataclasses
5+
from abc import ABC, abstractmethod
6+
from collections.abc import AsyncIterator
7+
8+
from reflex import constants
9+
from reflex.config import get_config
10+
from reflex.state import BaseState
11+
from reflex.utils import console, prerequisites
12+
from reflex.utils.exceptions import InvalidStateManagerModeError
13+
14+
15+
@dataclasses.dataclass
16+
class StateManager(ABC):
17+
"""A class to manage many client states."""
18+
19+
# The state class to use.
20+
state: type[BaseState]
21+
22+
@classmethod
23+
def create(cls, state: type[BaseState]):
24+
"""Create a new state manager.
25+
26+
Args:
27+
state: The state class to use.
28+
29+
Raises:
30+
InvalidStateManagerModeError: If the state manager mode is invalid.
31+
32+
Returns:
33+
The state manager (either disk, memory or redis).
34+
"""
35+
config = get_config()
36+
if prerequisites.parse_redis_url() is not None:
37+
config.state_manager_mode = constants.StateManagerMode.REDIS
38+
if config.state_manager_mode == constants.StateManagerMode.MEMORY:
39+
from reflex.istate.manager.memory import StateManagerMemory
40+
41+
return StateManagerMemory(state=state)
42+
if config.state_manager_mode == constants.StateManagerMode.DISK:
43+
from reflex.istate.manager.disk import StateManagerDisk
44+
45+
return StateManagerDisk(state=state)
46+
if config.state_manager_mode == constants.StateManagerMode.REDIS:
47+
redis = prerequisites.get_redis()
48+
if redis is not None:
49+
from reflex.istate.manager.redis import StateManagerRedis
50+
51+
# make sure expiration values are obtained only from the config object on creation
52+
return StateManagerRedis(
53+
state=state,
54+
redis=redis,
55+
token_expiration=config.redis_token_expiration,
56+
lock_expiration=config.redis_lock_expiration,
57+
lock_warning_threshold=config.redis_lock_warning_threshold,
58+
)
59+
msg = f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
60+
raise InvalidStateManagerModeError(msg)
61+
62+
@abstractmethod
63+
async def get_state(self, token: str) -> BaseState:
64+
"""Get the state for a token.
65+
66+
Args:
67+
token: The token to get the state for.
68+
69+
Returns:
70+
The state for the token.
71+
"""
72+
73+
@abstractmethod
74+
async def set_state(self, token: str, state: BaseState):
75+
"""Set the state for a token.
76+
77+
Args:
78+
token: The token to set the state for.
79+
state: The state to set.
80+
"""
81+
82+
@abstractmethod
83+
@contextlib.asynccontextmanager
84+
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
85+
"""Modify the state for a token while holding exclusive lock.
86+
87+
Args:
88+
token: The token to modify the state for.
89+
90+
Yields:
91+
The state for the token.
92+
"""
93+
yield self.state()
94+
95+
96+
def _default_token_expiration() -> int:
97+
"""Get the default token expiration time.
98+
99+
Returns:
100+
The default token expiration time.
101+
"""
102+
return get_config().redis_token_expiration
103+
104+
105+
def reset_disk_state_manager():
106+
"""Reset the disk state manager."""
107+
console.debug("Resetting disk state manager.")
108+
states_directory = prerequisites.get_states_dir()
109+
if states_directory.exists():
110+
for path in states_directory.iterdir():
111+
path.unlink()
112+
113+
114+
def get_state_manager() -> StateManager:
115+
"""Get the state manager for the app that is currently running.
116+
117+
Returns:
118+
The state manager.
119+
"""
120+
return prerequisites.get_and_validate_app().app.state_manager

reflex/istate/manager/disk.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
"""A state manager that stores states on disk."""
2+
3+
import asyncio
4+
import contextlib
5+
import dataclasses
6+
import functools
7+
from collections.abc import AsyncIterator
8+
from hashlib import md5
9+
from pathlib import Path
10+
11+
from typing_extensions import override
12+
13+
from reflex.istate.manager import StateManager, _default_token_expiration
14+
from reflex.state import BaseState, _split_substate_key, _substate_key
15+
from reflex.utils import path_ops, prerequisites
16+
17+
18+
@dataclasses.dataclass
19+
class StateManagerDisk(StateManager):
20+
"""A state manager that stores states on disk."""
21+
22+
# The mapping of client ids to states.
23+
states: dict[str, BaseState] = dataclasses.field(default_factory=dict)
24+
25+
# The mutex ensures the dict of mutexes is updated exclusively
26+
_state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock())
27+
28+
# The dict of mutexes for each client
29+
_states_locks: dict[str, asyncio.Lock] = dataclasses.field(
30+
default_factory=dict,
31+
init=False,
32+
)
33+
34+
# The token expiration time (s).
35+
token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)
36+
37+
def __post_init__(self):
38+
"""Create a new state manager."""
39+
path_ops.mkdir(self.states_directory)
40+
41+
self._purge_expired_states()
42+
43+
@functools.cached_property
44+
def states_directory(self) -> Path:
45+
"""Get the states directory.
46+
47+
Returns:
48+
The states directory.
49+
"""
50+
return prerequisites.get_states_dir()
51+
52+
def _purge_expired_states(self):
53+
"""Purge expired states from the disk."""
54+
import time
55+
56+
for path in path_ops.ls(self.states_directory):
57+
# check path is a pickle file
58+
if path.suffix != ".pkl":
59+
continue
60+
61+
# load last edited field from file
62+
last_edited = path.stat().st_mtime
63+
64+
# check if the file is older than the token expiration time
65+
if time.time() - last_edited > self.token_expiration:
66+
# remove the file
67+
path.unlink()
68+
69+
def token_path(self, token: str) -> Path:
70+
"""Get the path for a token.
71+
72+
Args:
73+
token: The token to get the path for.
74+
75+
Returns:
76+
The path for the token.
77+
"""
78+
return (
79+
self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl"
80+
).absolute()
81+
82+
async def load_state(self, token: str) -> BaseState | None:
83+
"""Load a state object based on the provided token.
84+
85+
Args:
86+
token: The token used to identify the state object.
87+
88+
Returns:
89+
The loaded state object or None.
90+
"""
91+
token_path = self.token_path(token)
92+
93+
if token_path.exists():
94+
try:
95+
with token_path.open(mode="rb") as file:
96+
return BaseState._deserialize(fp=file)
97+
except Exception:
98+
pass
99+
return None
100+
101+
async def populate_substates(
102+
self, client_token: str, state: BaseState, root_state: BaseState
103+
):
104+
"""Populate the substates of a state object.
105+
106+
Args:
107+
client_token: The client token.
108+
state: The state object to populate.
109+
root_state: The root state object.
110+
"""
111+
for substate in state.get_substates():
112+
substate_token = _substate_key(client_token, substate)
113+
114+
fresh_instance = await root_state.get_state(substate)
115+
instance = await self.load_state(substate_token)
116+
if instance is not None:
117+
# Ensure all substates exist, even if they weren't serialized previously.
118+
instance.substates = fresh_instance.substates
119+
else:
120+
instance = fresh_instance
121+
state.substates[substate.get_name()] = instance
122+
instance.parent_state = state
123+
124+
await self.populate_substates(client_token, instance, root_state)
125+
126+
@override
127+
async def get_state(
128+
self,
129+
token: str,
130+
) -> BaseState:
131+
"""Get the state for a token.
132+
133+
Args:
134+
token: The token to get the state for.
135+
136+
Returns:
137+
The state for the token.
138+
"""
139+
client_token = _split_substate_key(token)[0]
140+
root_state = self.states.get(client_token)
141+
if root_state is not None:
142+
# Retrieved state from memory.
143+
return root_state
144+
145+
# Deserialize root state from disk.
146+
root_state = await self.load_state(_substate_key(client_token, self.state))
147+
# Create a new root state tree with all substates instantiated.
148+
fresh_root_state = self.state(_reflex_internal_init=True)
149+
if root_state is None:
150+
root_state = fresh_root_state
151+
else:
152+
# Ensure all substates exist, even if they were not serialized previously.
153+
root_state.substates = fresh_root_state.substates
154+
self.states[client_token] = root_state
155+
await self.populate_substates(client_token, root_state, root_state)
156+
return root_state
157+
158+
async def set_state_for_substate(self, client_token: str, substate: BaseState):
159+
"""Set the state for a substate.
160+
161+
Args:
162+
client_token: The client token.
163+
substate: The substate to set.
164+
"""
165+
substate_token = _substate_key(client_token, substate)
166+
167+
if substate._get_was_touched():
168+
substate._was_touched = False # Reset the touched flag after serializing.
169+
pickle_state = substate._serialize()
170+
if pickle_state:
171+
if not self.states_directory.exists():
172+
self.states_directory.mkdir(parents=True, exist_ok=True)
173+
self.token_path(substate_token).write_bytes(pickle_state)
174+
175+
for substate_substate in substate.substates.values():
176+
await self.set_state_for_substate(client_token, substate_substate)
177+
178+
@override
179+
async def set_state(self, token: str, state: BaseState):
180+
"""Set the state for a token.
181+
182+
Args:
183+
token: The token to set the state for.
184+
state: The state to set.
185+
"""
186+
client_token, _ = _split_substate_key(token)
187+
await self.set_state_for_substate(client_token, state)
188+
189+
@override
190+
@contextlib.asynccontextmanager
191+
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
192+
"""Modify the state for a token while holding exclusive lock.
193+
194+
Args:
195+
token: The token to modify the state for.
196+
197+
Yields:
198+
The state for the token.
199+
"""
200+
# Disk state manager ignores the substate suffix and always returns the top-level state.
201+
client_token, _ = _split_substate_key(token)
202+
if client_token not in self._states_locks:
203+
async with self._state_manager_lock:
204+
if client_token not in self._states_locks:
205+
self._states_locks[client_token] = asyncio.Lock()
206+
207+
async with self._states_locks[client_token]:
208+
state = await self.get_state(token)
209+
yield state
210+
await self.set_state(token, state)

0 commit comments

Comments
 (0)