Skip to content

Commit 2cc6884

Browse files
authored
ENG-7948: StateManagerDisk deferred write queue (#5883)
* ENG-7948: StateManagerDisk deferred write queue * New env var: REFLEX_STATE_MANAGER_DISK_DEBOUNCE_SECONDS (default 2.0) * If the debounce is non-zero, then the state manager will queue the disk write * Queued writes will be processed in order of set time after they exceed the debounce timeout * New StateManager.close method standardized in base class * Close app.state_manager when the server is going down * Flush all queued writes when the StateManagerDisk closes * Update test cases to always call `state_manager.close()` * FB: never sleep less than zero seconds * AppHarness: call state_manager.close() for all state managers * Do not reschedule write queue after event loop is closed * Make AppHarness more compatible-er with the new StateManagerDisk * clear StateManagerDisk _write_queue on `.close()` * AppHarness.get_state makes sure to drain the backend's _write_queue * move _flush_write_queue to a separate function * when debounce is disabled, sleep the expiration task until the oldest state would expire; this conserves resources by pausing _process_write_queue until the oldest known token is due to expire. * simplify AppHarness song and dance for flushing backend's StateManagerDisk * Take _state_manager_lock when closing, to avoid interference with _schedule_process_write_queue
1 parent 1e2a8c9 commit 2cc6884

File tree

9 files changed

+261
-48
lines changed

9 files changed

+261
-48
lines changed

reflex/app_mixins/lifespan.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,13 @@ async def _run_lifespan_tasks(self, app: Starlette):
7171
await event_namespace._token_manager.disconnect_all()
7272
except Exception as e:
7373
console.error(f"Error during lifespan cleanup: {e}")
74+
# Flush any pending writes from the state manager.
75+
try:
76+
state_manager = self.state_manager # pyright: ignore[reportAttributeAccessIssue]
77+
except AttributeError:
78+
pass
79+
else:
80+
await state_manager.close()
7481

7582
def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
7683
"""Register a task to run during the lifespan of the app.

reflex/environment.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,26 @@ def interpret_int_env(value: str, field_name: str) -> int:
9595
raise EnvironmentVarValueError(msg) from ve
9696

9797

98+
def interpret_float_env(value: str, field_name: str) -> float:
    """Parse the raw text of an environment variable as a float.

    Args:
        value: The raw environment variable value.
        field_name: The name of the field being parsed (used in the error message).

    Returns:
        The value converted to a float.

    Raises:
        EnvironmentVarValueError: If the text cannot be converted to a float.
    """
    try:
        parsed = float(value)
    except ValueError as ve:
        msg = f"Invalid float value: {value!r} for {field_name}"
        raise EnvironmentVarValueError(msg) from ve
    return parsed
116+
117+
98118
def interpret_existing_path_env(value: str, field_name: str) -> ExistingPath:
99119
"""Interpret a path environment variable value as an existing path.
100120
@@ -228,6 +248,8 @@ def interpret_env_var_value(
228248
return loglevel
229249
if field_type is int:
230250
return interpret_int_env(value, field_name)
251+
if field_type is float:
252+
return interpret_float_env(value, field_name)
231253
if field_type is Path:
232254
return interpret_path_env(value, field_name)
233255
if field_type is ExistingPath:
@@ -671,6 +693,9 @@ class EnvironmentVariables:
671693
# Whether to mount the compiled frontend app in the backend server in production.
672694
REFLEX_MOUNT_FRONTEND_COMPILED_APP: EnvVar[bool] = env_var(False, internal=True)
673695

696+
# How long to delay writing updated states to disk. (Higher values mean fewer writes, but a greater chance of lost data.)
697+
REFLEX_STATE_MANAGER_DISK_DEBOUNCE_SECONDS: EnvVar[float] = env_var(2.0)
698+
674699

675700
environment = EnvironmentVariables()
676701

reflex/istate/manager/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
9292
"""
9393
yield self.state()
9494

95+
async def close(self):  # noqa: B027
    """Release any resources held by the state manager.

    This base implementation does nothing; it exists so that callers can
    always await ``close()`` regardless of the concrete manager type.
    """
97+
9598

9699
def _default_token_expiration() -> int:
97100
"""Get the default token expiration time.

reflex/istate/manager/disk.py

Lines changed: 151 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,27 @@
44
import contextlib
55
import dataclasses
66
import functools
7+
import time
78
from collections.abc import AsyncIterator
89
from hashlib import md5
910
from pathlib import Path
1011

1112
from typing_extensions import override
1213

14+
from reflex.environment import environment
1315
from reflex.istate.manager import StateManager, _default_token_expiration
1416
from reflex.state import BaseState, _split_substate_key, _substate_key
15-
from reflex.utils import path_ops, prerequisites
17+
from reflex.utils import console, path_ops, prerequisites
18+
from reflex.utils.misc import run_in_thread
19+
20+
21+
@dataclasses.dataclass(frozen=True)
class QueueItem:
    """A single pending entry in the deferred write queue."""

    # Token identifying the state that is waiting to be written.
    token: str
    # The state instance to persist.
    state: BaseState
    # Time (seconds, as returned by ``time.time()``) at which the write was queued.
    timestamp: float
1628

1729

1830
@dataclasses.dataclass
@@ -34,6 +46,22 @@ class StateManagerDisk(StateManager):
3446
# The token expiration time (s).
3547
token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)
3648

49+
# Last time a token was touched.
50+
_token_last_touched: dict[str, float] = dataclasses.field(
51+
default_factory=dict,
52+
init=False,
53+
)
54+
55+
# Pending writes
56+
_write_queue: dict[str, QueueItem] = dataclasses.field(
57+
default_factory=dict,
58+
init=False,
59+
)
60+
_write_queue_task: asyncio.Task | None = None
61+
_write_debounce_seconds: float = dataclasses.field(
62+
default=environment.REFLEX_STATE_MANAGER_DISK_DEBOUNCE_SECONDS.get()
63+
)
64+
3765
def __post_init__(self):
3866
"""Create a new state manager."""
3967
path_ops.mkdir(self.states_directory)
@@ -51,8 +79,6 @@ def states_directory(self) -> Path:
5179

5280
def _purge_expired_states(self):
5381
"""Purge expired states from the disk."""
54-
import time
55-
5682
for path in path_ops.ls(self.states_directory):
5783
# check path is a pickle file
5884
if path.suffix != ".pkl":
@@ -137,6 +163,7 @@ async def get_state(
137163
The state for the token.
138164
"""
139165
client_token = _split_substate_key(token)[0]
166+
self._token_last_touched[client_token] = time.time()
140167
root_state = self.states.get(client_token)
141168
if root_state is not None:
142169
# Retrieved state from memory.
@@ -170,11 +197,109 @@ async def set_state_for_substate(self, client_token: str, substate: BaseState):
170197
if pickle_state:
171198
if not self.states_directory.exists():
172199
self.states_directory.mkdir(parents=True, exist_ok=True)
173-
self.token_path(substate_token).write_bytes(pickle_state)
200+
await run_in_thread(
201+
lambda: self.token_path(substate_token).write_bytes(pickle_state),
202+
)
174203

175204
for substate_substate in substate.substates.values():
176205
await self.set_state_for_substate(client_token, substate_substate)
177206

207+
async def _process_write_queue_delay(self):
208+
"""Wait for the debounce period before processing the write queue again."""
209+
now = time.time()
210+
if self._write_queue:
211+
# There are still items in the queue, schedule another run.
212+
next_write_in = max(
213+
0,
214+
min(
215+
self._write_debounce_seconds - (now - item.timestamp)
216+
for item in self._write_queue.values()
217+
),
218+
)
219+
await asyncio.sleep(next_write_in)
220+
elif self._write_debounce_seconds > 0:
221+
# No items left, wait a bit before checking again.
222+
await asyncio.sleep(self._write_debounce_seconds)
223+
else:
224+
# Debounce is disabled, so sleep until the next token expiration.
225+
oldest_token_last_touch = min(
226+
self._token_last_touched.values(), default=now
227+
)
228+
next_expiration_in = self.token_expiration - (now - oldest_token_last_touch)
229+
await asyncio.sleep(next_expiration_in)
230+
231+
async def _process_write_queue(self):
    """Long running task that writes debounced states to disk and purges expired ones.

    Each iteration writes every queued state whose debounce period has
    elapsed (oldest first), forgets tokens that have not been touched within
    the expiration time, purges expired states from disk, and then sleeps
    until there is more work to do.

    Cancellation is handled around the whole loop, so the queue is flushed
    no matter where the task is suspended when it gets cancelled -- including
    the retry delay that follows a failed iteration.

    Raises:
        asyncio.CancelledError: When the task is cancelled.
    """
    try:
        while True:
            try:
                now = time.time()
                # Oldest first, excluding items younger than the debounce time.
                due_items = sorted(
                    (
                        item
                        for item in self._write_queue.values()
                        if now - item.timestamp >= self._write_debounce_seconds
                    ),
                    key=lambda item: item.timestamp,
                )
                for item in due_items:
                    # The entry may have been drained elsewhere while an
                    # earlier write in this batch was being awaited.
                    queued = self._write_queue.pop(item.token, None)
                    if queued is None:
                        continue
                    client_token, _ = _split_substate_key(queued.token)
                    await self.set_state_for_substate(client_token, queued.state)
                # Check for expired states to purge.
                for token, last_touched in list(self._token_last_touched.items()):
                    if now - last_touched > self.token_expiration:
                        self._token_last_touched.pop(token)
                        self.states.pop(token, None)
                await run_in_thread(self._purge_expired_states)
                await self._process_write_queue_delay()
            except Exception as e:  # noqa: PERF203
                console.error(f"Error processing write queue: {e!r}")
                if e.args == ("cannot schedule new futures after shutdown",):
                    # Event loop is shutdown, nothing else we can really do...
                    return
                await self._process_write_queue_delay()
    except asyncio.CancelledError:
        # Persist whatever is still queued before letting the cancellation propagate.
        await self._flush_write_queue()
        raise
271+
272+
async def _flush_write_queue(self):
    """Write every queued state to disk right away, regardless of its age."""
    # Snapshot the pending items and empty the queue before doing any writes.
    pending = [*self._write_queue.values()]
    n_outstanding_items = len(pending)
    self._write_queue.clear()
    console.debug(
        f"StateManagerDisk._flush_write_queue: writing {n_outstanding_items} remaining items to disk"
    )
    for queued in pending:
        client_token, _ = _split_substate_key(queued.token)
        await self.set_state_for_substate(client_token, queued.state)
    console.debug(
        f"StateManagerDisk._flush_write_queue: Finished writing {n_outstanding_items} items"
    )
291+
292+
async def _schedule_process_write_queue(self):
293+
"""Schedule the write queue processing task if not already running."""
294+
if self._write_queue_task is None or self._write_queue_task.done():
295+
async with self._state_manager_lock:
296+
if self._write_queue_task is None or self._write_queue_task.done():
297+
self._write_queue_task = asyncio.create_task(
298+
self._process_write_queue(),
299+
name="StateManagerDisk|WriteQueueProcessor",
300+
)
301+
await asyncio.sleep(0) # Yield to allow the task to start.
302+
178303
@override
179304
async def set_state(self, token: str, state: BaseState):
180305
"""Set the state for a token.
@@ -184,7 +309,19 @@ async def set_state(self, token: str, state: BaseState):
184309
state: The state to set.
185310
"""
186311
client_token, _ = _split_substate_key(token)
187-
await self.set_state_for_substate(client_token, state)
312+
if self._write_debounce_seconds > 0:
313+
# Deferred write to reduce disk IO overhead.
314+
if client_token not in self._write_queue:
315+
self._write_queue[client_token] = QueueItem(
316+
token=client_token,
317+
state=state,
318+
timestamp=time.time(),
319+
)
320+
else:
321+
# Immediate write to disk.
322+
await self.set_state_for_substate(client_token, state)
323+
# Ensure the processing task is scheduled to handle expirations and any deferred writes.
324+
await self._schedule_process_write_queue()
188325

189326
@override
190327
@contextlib.asynccontextmanager
@@ -208,3 +345,12 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
208345
state = await self.get_state(token)
209346
yield state
210347
await self.set_state(token, state)
348+
349+
async def close(self):
    """Close the state manager, flushing any pending writes to disk.

    The background write-queue task (if any) is cancelled and awaited; its
    cancellation handler flushes the queue. Anything still queued afterwards
    -- because the task was never started or had already finished -- is
    flushed here so that no pending write is dropped.
    """
    # Hold the lock so a concurrent _schedule_process_write_queue cannot
    # start a new task while the old one is being torn down.
    async with self._state_manager_lock:
        task = self._write_queue_task
        if task:
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task
            self._write_queue_task = None
        if self._write_queue:
            # The task did not get to flush these items; write them now.
            await self._flush_write_queue()

0 commit comments

Comments
 (0)