Skip to content

Commit 96c0b45

Browse files
committed
ENG-7948: StateManagerDisk deferred write queue
* New env var: REFLEX_STATE_MANAGER_DISK_DEBOUNCE_SECONDS (default 2.0) * If the debounce is non-zero, then state manager will queue the disk write * Queued writes will be processed in order of set time after they exceed the debounce timeout * New StateManager.close method standardized in base class * Close app.state_manager when the server is going down * Flush all queued writes when the StateManagerDisk closes * Update test cases to always call `state_manager.close()`
1 parent 8a2337b commit 96c0b45

File tree

6 files changed

+180
-25
lines changed

6 files changed

+180
-25
lines changed

reflex/app_mixins/lifespan.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,13 @@ async def _run_lifespan_tasks(self, app: Starlette):
7171
await event_namespace._token_manager.disconnect_all()
7272
except Exception as e:
7373
console.error(f"Error during lifespan cleanup: {e}")
74+
# Flush any pending writes from the state manager.
75+
try:
76+
state_manager = self.state_manager # pyright: ignore[reportAttributeAccessIssue]
77+
except AttributeError:
78+
pass
79+
else:
80+
await state_manager.close()
7481

7582
def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
7683
"""Register a task to run during the lifespan of the app.

reflex/environment.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,26 @@ def interpret_int_env(value: str, field_name: str) -> int:
9595
raise EnvironmentVarValueError(msg) from ve
9696

9797

98+
def interpret_float_env(value: str, field_name: str) -> float:
99+
"""Interpret a float environment variable value.
100+
101+
Args:
102+
value: The environment variable value.
103+
field_name: The field name.
104+
105+
Returns:
106+
The interpreted value.
107+
108+
Raises:
109+
EnvironmentVarValueError: If the value is invalid.
110+
"""
111+
try:
112+
return float(value)
113+
except ValueError as ve:
114+
msg = f"Invalid float value: {value!r} for {field_name}"
115+
raise EnvironmentVarValueError(msg) from ve
116+
117+
98118
def interpret_existing_path_env(value: str, field_name: str) -> ExistingPath:
99119
"""Interpret a path environment variable value as an existing path.
100120
@@ -228,6 +248,8 @@ def interpret_env_var_value(
228248
return loglevel
229249
if field_type is int:
230250
return interpret_int_env(value, field_name)
251+
if field_type is float:
252+
return interpret_float_env(value, field_name)
231253
if field_type is Path:
232254
return interpret_path_env(value, field_name)
233255
if field_type is ExistingPath:
@@ -671,6 +693,9 @@ class EnvironmentVariables:
671693
# Whether to mount the compiled frontend app in the backend server in production.
672694
REFLEX_MOUNT_FRONTEND_COMPILED_APP: EnvVar[bool] = env_var(False, internal=True)
673695

696+
# How long to delay writing updated states to disk. (Higher values mean less writes, but more chance of lost data.)
697+
REFLEX_STATE_MANAGER_DISK_DEBOUNCE_SECONDS: EnvVar[float] = env_var(2.0)
698+
674699

675700
environment = EnvironmentVariables()
676701

reflex/istate/manager/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
9292
"""
9393
yield self.state()
9494

95+
async def close(self): # noqa: B027
96+
"""Close the state manager."""
97+
9598

9699
def _default_token_expiration() -> int:
97100
"""Get the default token expiration time.

reflex/istate/manager/disk.py

Lines changed: 131 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,27 @@
44
import contextlib
55
import dataclasses
66
import functools
7+
import time
78
from collections.abc import AsyncIterator
89
from hashlib import md5
910
from pathlib import Path
1011

1112
from typing_extensions import override
1213

14+
from reflex.environment import environment
1315
from reflex.istate.manager import StateManager, _default_token_expiration
1416
from reflex.state import BaseState, _split_substate_key, _substate_key
15-
from reflex.utils import path_ops, prerequisites
17+
from reflex.utils import console, path_ops, prerequisites
18+
from reflex.utils.misc import run_in_thread
19+
20+
21+
@dataclasses.dataclass(frozen=True)
22+
class QueueItem:
23+
"""An item in the write queue."""
24+
25+
token: str
26+
state: BaseState
27+
timestamp: float
1628

1729

1830
@dataclasses.dataclass
@@ -34,6 +46,22 @@ class StateManagerDisk(StateManager):
3446
# The token expiration time (s).
3547
token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)
3648

49+
# Last time a token was touched.
50+
_token_last_touched: dict[str, float] = dataclasses.field(
51+
default_factory=dict,
52+
init=False,
53+
)
54+
55+
# Pending writes
56+
_write_queue: dict[str, QueueItem] = dataclasses.field(
57+
default_factory=dict,
58+
init=False,
59+
)
60+
_write_queue_task: asyncio.Task | None = None
61+
_write_debounce_seconds: float = dataclasses.field(
62+
default=environment.REFLEX_STATE_MANAGER_DISK_DEBOUNCE_SECONDS.get()
63+
)
64+
3765
def __post_init__(self):
3866
"""Create a new state manager."""
3967
path_ops.mkdir(self.states_directory)
@@ -51,8 +79,6 @@ def states_directory(self) -> Path:
5179

5280
def _purge_expired_states(self):
5381
"""Purge expired states from the disk."""
54-
import time
55-
5682
for path in path_ops.ls(self.states_directory):
5783
# check path is a pickle file
5884
if path.suffix != ".pkl":
@@ -137,6 +163,7 @@ async def get_state(
137163
The state for the token.
138164
"""
139165
client_token = _split_substate_key(token)[0]
166+
self._token_last_touched[client_token] = time.time()
140167
root_state = self.states.get(client_token)
141168
if root_state is not None:
142169
# Retrieved state from memory.
@@ -170,11 +197,90 @@ async def set_state_for_substate(self, client_token: str, substate: BaseState):
170197
if pickle_state:
171198
if not self.states_directory.exists():
172199
self.states_directory.mkdir(parents=True, exist_ok=True)
173-
self.token_path(substate_token).write_bytes(pickle_state)
200+
await run_in_thread(
201+
lambda: self.token_path(substate_token).write_bytes(pickle_state),
202+
)
174203

175204
for substate_substate in substate.substates.values():
176205
await self.set_state_for_substate(client_token, substate_substate)
177206

207+
async def _process_write_queue_delay(self):
208+
"""Wait for the debounce period before processing the write queue again."""
209+
if self._write_queue:
210+
# There are still items in the queue, schedule another run.
211+
now = time.time()
212+
next_write_in = min(
213+
self._write_debounce_seconds - (now - item.timestamp)
214+
for item in self._write_queue.values()
215+
)
216+
await asyncio.sleep(next_write_in)
217+
elif self._write_debounce_seconds > 0:
218+
# No items left, wait a bit before checking again.
219+
await asyncio.sleep(self._write_debounce_seconds)
220+
else:
221+
# No debounce, wait a minute before processing expirations.
222+
await asyncio.sleep(60)
223+
224+
async def _process_write_queue(self):
225+
"""Long running task that checks for states to write to disk.
226+
227+
Raises:
228+
asyncio.CancelledError: When the task is cancelled.
229+
"""
230+
while True:
231+
try:
232+
now = time.time()
233+
# sort the _write_queue by oldest timestamp and exclude items younger than debounce time
234+
items_to_write = sorted(
235+
(
236+
item
237+
for item in self._write_queue.values()
238+
if now - item.timestamp >= self._write_debounce_seconds
239+
),
240+
key=lambda item: item.timestamp,
241+
)
242+
for item in items_to_write:
243+
token = item.token
244+
client_token, _ = _split_substate_key(token)
245+
await self.set_state_for_substate(
246+
client_token, self._write_queue.pop(token).state
247+
)
248+
# Check for expired states to purge.
249+
for token, last_touched in list(self._token_last_touched.items()):
250+
if now - last_touched > self.token_expiration:
251+
self._token_last_touched.pop(token)
252+
self.states.pop(token, None)
253+
await run_in_thread(self._purge_expired_states)
254+
await self._process_write_queue_delay()
255+
except asyncio.CancelledError: # noqa: PERF203
256+
n_outstanding_items = len(self._write_queue)
257+
# When the task is cancelled, write all remaining items to disk.
258+
console.debug(
259+
f"Closing StateManagerDisk, writing {n_outstanding_items} remaining items to disk"
260+
)
261+
for item in self._write_queue.values():
262+
token = item.token
263+
client_token, _ = _split_substate_key(token)
264+
await self.set_state_for_substate(
265+
client_token,
266+
item.state,
267+
)
268+
console.debug(f"Finished writing {n_outstanding_items} items to disk")
269+
raise
270+
except Exception as e:
271+
console.error(f"Error processing write queue: {e!r}")
272+
await self._process_write_queue_delay()
273+
274+
async def _schedule_process_write_queue(self):
275+
"""Schedule the write queue processing task if not already running."""
276+
if self._write_queue_task is None or self._write_queue_task.done():
277+
async with self._state_manager_lock:
278+
if self._write_queue_task is None or self._write_queue_task.done():
279+
self._write_queue_task = asyncio.create_task(
280+
self._process_write_queue()
281+
)
282+
await asyncio.sleep(0) # Yield to allow the task to start.
283+
178284
@override
179285
async def set_state(self, token: str, state: BaseState):
180286
"""Set the state for a token.
@@ -184,7 +290,19 @@ async def set_state(self, token: str, state: BaseState):
184290
state: The state to set.
185291
"""
186292
client_token, _ = _split_substate_key(token)
187-
await self.set_state_for_substate(client_token, state)
293+
if self._write_debounce_seconds > 0:
294+
# Deferred write to reduce disk IO overhead.
295+
if client_token not in self._write_queue:
296+
self._write_queue[client_token] = QueueItem(
297+
token=client_token,
298+
state=state,
299+
timestamp=time.time(),
300+
)
301+
else:
302+
# Immediate write to disk.
303+
await self.set_state_for_substate(client_token, state)
304+
# Ensure the processing task is scheduled to handle expirations and any deferred writes.
305+
await self._schedule_process_write_queue()
188306

189307
@override
190308
@contextlib.asynccontextmanager
@@ -208,3 +326,11 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
208326
state = await self.get_state(token)
209327
yield state
210328
await self.set_state(token, state)
329+
330+
async def close(self):
331+
"""Close the state manager, flushing any pending writes to disk."""
332+
if self._write_queue_task:
333+
self._write_queue_task.cancel()
334+
with contextlib.suppress(asyncio.CancelledError):
335+
await self._write_queue_task
336+
self._write_queue_task = None

tests/units/test_app.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -451,8 +451,7 @@ async def test_initialize_with_state(test_state: type[ATestState], token: str):
451451
assert isinstance(state, test_state)
452452
assert state.var == 0
453453

454-
if isinstance(app.state_manager, StateManagerRedis):
455-
await app.state_manager.close()
454+
await app.state_manager.close()
456455

457456

458457
@pytest.mark.asyncio
@@ -486,8 +485,7 @@ async def test_set_and_get_state(test_state):
486485
assert state1.var == 1
487486
assert state2.var == 2
488487

489-
if isinstance(app.state_manager, StateManagerRedis):
490-
await app.state_manager.close()
488+
await app.state_manager.close()
491489

492490

493491
@pytest.mark.asyncio
@@ -999,8 +997,7 @@ def getlist(key: str):
999997
"image2.jpg",
1000998
]
1001999

1002-
if isinstance(app.state_manager, StateManagerRedis):
1003-
await app.state_manager.close()
1000+
await app.state_manager.close()
10041001

10051002

10061003
@pytest.mark.asyncio
@@ -1046,8 +1043,7 @@ def getlist(key: str):
10461043
== f"`{state.get_full_name()}.handle_upload2` handler should have a parameter annotated as list[rx.UploadFile]"
10471044
)
10481045

1049-
if isinstance(app.state_manager, StateManagerRedis):
1050-
await app.state_manager.close()
1046+
await app.state_manager.close()
10511047

10521048

10531049
@pytest.mark.asyncio
@@ -1093,8 +1089,7 @@ def getlist(key: str):
10931089
== f"@rx.event(background=True) is not supported for upload handler `{state.get_full_name()}.bg_upload`."
10941090
)
10951091

1096-
if isinstance(app.state_manager, StateManagerRedis):
1097-
await app.state_manager.close()
1092+
await app.state_manager.close()
10981093

10991094

11001095
class DynamicState(BaseState):
@@ -1372,8 +1367,7 @@ def _dynamic_state_event(name, val, **kwargs):
13721367
assert state.loaded == len(exp_vals)
13731368
assert state.counter == len(exp_vals)
13741369

1375-
if isinstance(app.state_manager, StateManagerRedis):
1376-
await app.state_manager.close()
1370+
await app.state_manager.close()
13771371

13781372

13791373
@pytest.mark.asyncio
@@ -1412,8 +1406,7 @@ async def test_process_events(mocker: MockerFixture, token: str):
14121406
assert (await app.state_manager.get_state(event.substate_token)).value == 5
14131407
assert app._postprocess.call_count == 6 # pyright: ignore [reportAttributeAccessIssue]
14141408

1415-
if isinstance(app.state_manager, StateManagerRedis):
1416-
await app.state_manager.close()
1409+
await app.state_manager.close()
14171410

14181411

14191412
@pytest.mark.parametrize(

tests/units/test_state.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1685,8 +1685,7 @@ async def state_manager(request) -> AsyncGenerator[StateManager, None]:
16851685

16861686
yield state_manager
16871687

1688-
if isinstance(state_manager, StateManagerRedis):
1689-
await state_manager.close()
1688+
await state_manager.close()
16901689

16911690

16921691
@pytest.fixture
@@ -1737,6 +1736,8 @@ async def test_state_manager_modify_state(
17371736
if state_manager._states_locks:
17381737
assert sm2._states_locks != state_manager._states_locks
17391738

1739+
await sm2.close()
1740+
17401741

17411742
@pytest.mark.asyncio
17421743
async def test_state_manager_contend(
@@ -2960,8 +2961,7 @@ def index():
29602961
async for update in state._process(events[1]):
29612962
assert update.delta == exp_is_hydrated(state)
29622963

2963-
if isinstance(app.state_manager, StateManagerRedis):
2964-
await app.state_manager.close()
2964+
await app.state_manager.close()
29652965

29662966

29672967
@pytest.mark.asyncio
@@ -3016,8 +3016,7 @@ def index():
30163016
async for update in state._process(events[2]):
30173017
assert update.delta == exp_is_hydrated(state)
30183018

3019-
if isinstance(app.state_manager, StateManagerRedis):
3020-
await app.state_manager.close()
3019+
await app.state_manager.close()
30213020

30223021

30233022
@pytest.mark.asyncio
@@ -3658,13 +3657,15 @@ class Child(State):
36583657
c = await root.get_state(Child)
36593658
assert s._get_was_touched()
36603659
assert not c._get_was_touched()
3660+
await dsm.close()
36613661

36623662
dsm2 = StateManagerDisk(state=Root)
36633663
root = await dsm2.get_state(token)
36643664
s = await root.get_state(State)
36653665
assert s.num == 43
36663666
c = await root.get_state(Child)
36673667
assert c.foo == "bar"
3668+
await dsm2.close()
36683669

36693670

36703671
class Obj(Base):

0 commit comments

Comments
 (0)