44import contextlib
55import dataclasses
66import functools
7+ import time
78from collections .abc import AsyncIterator
89from hashlib import md5
910from pathlib import Path
1011
1112from typing_extensions import override
1213
14+ from reflex .environment import environment
1315from reflex .istate .manager import StateManager , _default_token_expiration
1416from reflex .state import BaseState , _split_substate_key , _substate_key
15- from reflex .utils import path_ops , prerequisites
17+ from reflex .utils import console , path_ops , prerequisites
18+ from reflex .utils .misc import run_in_thread
19+
20+
@dataclasses.dataclass(frozen=True)
class QueueItem:
    """A single pending entry in the disk state manager's write queue.

    Instances are immutable (the dataclass is frozen).
    """

    # Key under which the entry is stored in the write queue.
    token: str
    # State object that will be written to disk for this token.
    state: BaseState
    # Wall-clock time (``time.time()``) at which the entry was queued.
    timestamp: float
1628
1729
1830@dataclasses .dataclass
@@ -34,6 +46,22 @@ class StateManagerDisk(StateManager):
3446 # The token expiration time (s).
3547 token_expiration : int = dataclasses .field (default_factory = _default_token_expiration )
3648
49+ # Last time a token was touched.
50+ _token_last_touched : dict [str , float ] = dataclasses .field (
51+ default_factory = dict ,
52+ init = False ,
53+ )
54+
55+ # Pending writes
56+ _write_queue : dict [str , QueueItem ] = dataclasses .field (
57+ default_factory = dict ,
58+ init = False ,
59+ )
60+ _write_queue_task : asyncio .Task | None = None
61+ _write_debounce_seconds : float = dataclasses .field (
62+ default = environment .REFLEX_STATE_MANAGER_DISK_DEBOUNCE_SECONDS .get ()
63+ )
64+
3765 def __post_init__ (self ):
3866 """Create a new state manager."""
3967 path_ops .mkdir (self .states_directory )
@@ -51,8 +79,6 @@ def states_directory(self) -> Path:
5179
5280 def _purge_expired_states (self ):
5381 """Purge expired states from the disk."""
54- import time
55-
5682 for path in path_ops .ls (self .states_directory ):
5783 # check path is a pickle file
5884 if path .suffix != ".pkl" :
@@ -137,6 +163,7 @@ async def get_state(
137163 The state for the token.
138164 """
139165 client_token = _split_substate_key (token )[0 ]
166+ self ._token_last_touched [client_token ] = time .time ()
140167 root_state = self .states .get (client_token )
141168 if root_state is not None :
142169 # Retrieved state from memory.
@@ -170,11 +197,109 @@ async def set_state_for_substate(self, client_token: str, substate: BaseState):
170197 if pickle_state :
171198 if not self .states_directory .exists ():
172199 self .states_directory .mkdir (parents = True , exist_ok = True )
173- self .token_path (substate_token ).write_bytes (pickle_state )
200+ await run_in_thread (
201+ lambda : self .token_path (substate_token ).write_bytes (pickle_state ),
202+ )
174203
175204 for substate_substate in substate .substates .values ():
176205 await self .set_state_for_substate (client_token , substate_substate )
177206
async def _process_write_queue_delay(self):
    """Sleep until the write queue processor should run its next pass.

    The sleep duration depends on the current situation:

    * Items are still queued: sleep until the first of them reaches the
      debounce age (never a negative amount).
    * The queue is empty and debouncing is enabled: sleep for one debounce
      period.
    * Debouncing is disabled: sleep until the least recently touched token
      would reach ``token_expiration`` (a full expiration period when no
      token has been touched yet).
    """
    current_time = time.time()
    debounce = self._write_debounce_seconds

    if self._write_queue:
        # Time left before each queued item becomes old enough to write.
        time_until_due = [
            debounce - (current_time - queued.timestamp)
            for queued in self._write_queue.values()
        ]
        await asyncio.sleep(max(0, min(time_until_due)))
        return

    if debounce > 0:
        # Nothing queued right now; check again after one debounce period.
        await asyncio.sleep(debounce)
        return

    # Debounce is disabled, so the only remaining job is token expiration.
    earliest_touch = min(self._token_last_touched.values(), default=current_time)
    idle_for = current_time - earliest_touch
    await asyncio.sleep(self.token_expiration - idle_for)
230+
async def _process_write_queue(self):
    """Long running task that writes queued states to disk and expires tokens.

    Each pass:

    1. Writes every queued item whose age is at least the debounce period,
       oldest first.
    2. Drops the in-memory state of tokens that have not been touched for
       longer than ``token_expiration``.
    3. Purges expired states from disk in a worker thread.
    4. Sleeps until the next pass is due.

    Errors raised during a pass are logged and the loop keeps running,
    except for the executor-shutdown error, which ends the task.

    Raises:
        asyncio.CancelledError: When the task is cancelled. Remaining queued
            items are flushed to disk before the error propagates.
    """
    # The cancellation handler wraps the WHOLE loop, including the delay in
    # the error path. An exception raised inside an ``except`` clause is not
    # seen by sibling ``except`` clauses of the same ``try``, so handling
    # cancellation at the inner level would skip the flush when the task is
    # cancelled while sleeping after an error.
    try:
        while True:
            try:
                now = time.time()
                # Oldest first; skip items younger than the debounce time.
                items_to_write = sorted(
                    (
                        item
                        for item in self._write_queue.values()
                        if now - item.timestamp >= self._write_debounce_seconds
                    ),
                    key=lambda item: item.timestamp,
                )
                for item in items_to_write:
                    token = item.token
                    client_token, _ = _split_substate_key(token)
                    # Pop before writing so that a set_state() call arriving
                    # during the write queues a fresh entry for the token.
                    queued = self._write_queue.pop(token)
                    try:
                        await self.set_state_for_substate(client_token, queued.state)
                    except asyncio.CancelledError:
                        # The write was interrupted part way through: put the
                        # item back (unless a newer entry already exists) so
                        # the final flush writes it completely.
                        self._write_queue.setdefault(token, queued)
                        raise
                # Check for expired states to purge.
                for token, last_touched in list(self._token_last_touched.items()):
                    if now - last_touched > self.token_expiration:
                        self._token_last_touched.pop(token)
                        self.states.pop(token, None)
                await run_in_thread(self._purge_expired_states)
                await self._process_write_queue_delay()
            except Exception as e:
                console.error(f"Error processing write queue: {e!r}")
                if e.args == ("cannot schedule new futures after shutdown",):
                    # Event loop is shutdown, nothing else we can really do...
                    return
                await self._process_write_queue_delay()
    except asyncio.CancelledError:
        await self._flush_write_queue()
        raise
271+
async def _flush_write_queue(self):
    """Write every item still waiting in the write queue to disk.

    The queue is snapshotted and emptied up front, then each snapshotted
    item is written in queue order, ignoring the debounce period.
    """
    # Take ownership of the pending items before doing any awaiting.
    pending = [*self._write_queue.values()]
    self._write_queue.clear()
    n_outstanding_items = len(pending)

    console.debug(
        f"StateManagerDisk._flush_write_queue: writing {n_outstanding_items} remaining items to disk"
    )
    for queued in pending:
        client_token, _ = _split_substate_key(queued.token)
        await self.set_state_for_substate(client_token, queued.state)
    console.debug(
        f"StateManagerDisk._flush_write_queue: Finished writing {n_outstanding_items} items"
    )
291+
async def _schedule_process_write_queue(self):
    """Start the write queue processing task unless one is already running.

    A task counts as "not running" when none has been created yet or when
    the previous one is done.
    """

    def needs_start() -> bool:
        task = self._write_queue_task
        return task is None or task.done()

    # Cheap check first so the common case does not touch the lock.
    if not needs_start():
        return

    async with self._state_manager_lock:
        # Check again now that the lock is held: another caller may have
        # started the task while this one was waiting.
        if not needs_start():
            return
        self._write_queue_task = asyncio.create_task(
            self._process_write_queue(),
            name="StateManagerDisk|WriteQueueProcessor",
        )
        # Yield to allow the task to start.
        await asyncio.sleep(0)
302+
178303 @override
179304 async def set_state (self , token : str , state : BaseState ):
180305 """Set the state for a token.
@@ -184,7 +309,19 @@ async def set_state(self, token: str, state: BaseState):
184309 state: The state to set.
185310 """
186311 client_token , _ = _split_substate_key (token )
187- await self .set_state_for_substate (client_token , state )
312+ if self ._write_debounce_seconds > 0 :
313+ # Deferred write to reduce disk IO overhead.
314+ if client_token not in self ._write_queue :
315+ self ._write_queue [client_token ] = QueueItem (
316+ token = client_token ,
317+ state = state ,
318+ timestamp = time .time (),
319+ )
320+ else :
321+ # Immediate write to disk.
322+ await self .set_state_for_substate (client_token , state )
323+ # Ensure the processing task is scheduled to handle expirations and any deferred writes.
324+ await self ._schedule_process_write_queue ()
188325
189326 @override
190327 @contextlib .asynccontextmanager
@@ -208,3 +345,12 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
208345 state = await self .get_state (token )
209346 yield state
210347 await self .set_state (token , state )
348+
async def close(self):
    """Close the state manager, flushing any pending writes to disk.

    Cancels the write queue processing task (if one exists) and waits for
    it to finish; the task flushes its remaining items when cancelled. The
    ``asyncio.CancelledError`` produced by that cancellation is suppressed.

    Raises:
        Exception: Any non-cancellation error the write queue task ended
            with is propagated to the caller. The task reference is cleared
            regardless, so a later ``close()`` does not re-raise it.
    """
    async with self._state_manager_lock:
        if self._write_queue_task is not None:
            self._write_queue_task.cancel()
            try:
                with contextlib.suppress(asyncio.CancelledError):
                    await self._write_queue_task
            finally:
                # Always drop the reference, even when the task failed, so
                # the manager is left in a clean "no task" state.
                self._write_queue_task = None
0 commit comments