Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions reflex/app_mixins/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ async def _run_lifespan_tasks(self, app: Starlette):
await event_namespace._token_manager.disconnect_all()
except Exception as e:
console.error(f"Error during lifespan cleanup: {e}")
# Flush any pending writes from the state manager.
try:
state_manager = self.state_manager # pyright: ignore[reportAttributeAccessIssue]
except AttributeError:
pass
else:
await state_manager.close()

def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
"""Register a task to run during the lifespan of the app.
Expand Down
25 changes: 25 additions & 0 deletions reflex/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,26 @@ def interpret_int_env(value: str, field_name: str) -> int:
raise EnvironmentVarValueError(msg) from ve


def interpret_float_env(value: str, field_name: str) -> float:
    """Parse the raw text of an environment variable as a float.

    Args:
        value: The raw environment variable text.
        field_name: The name of the field, used in the error message.

    Returns:
        The parsed float.

    Raises:
        EnvironmentVarValueError: If ``value`` cannot be converted to a float.
    """
    try:
        parsed = float(value)
    except ValueError as ve:
        msg = f"Invalid float value: {value!r} for {field_name}"
        raise EnvironmentVarValueError(msg) from ve
    else:
        return parsed


def interpret_existing_path_env(value: str, field_name: str) -> ExistingPath:
"""Interpret a path environment variable value as an existing path.

Expand Down Expand Up @@ -228,6 +248,8 @@ def interpret_env_var_value(
return loglevel
if field_type is int:
return interpret_int_env(value, field_name)
if field_type is float:
return interpret_float_env(value, field_name)
if field_type is Path:
return interpret_path_env(value, field_name)
if field_type is ExistingPath:
Expand Down Expand Up @@ -671,6 +693,9 @@ class EnvironmentVariables:
# Whether to mount the compiled frontend app in the backend server in production.
REFLEX_MOUNT_FRONTEND_COMPILED_APP: EnvVar[bool] = env_var(False, internal=True)

# How long to delay writing updated states to disk. (Higher values mean fewer writes, but more chance of lost data.)
REFLEX_STATE_MANAGER_DISK_DEBOUNCE_SECONDS: EnvVar[float] = env_var(2.0)


environment = EnvironmentVariables()

Expand Down
3 changes: 3 additions & 0 deletions reflex/istate/manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""
yield self.state()

async def close(self):  # noqa: B027
    """Close the state manager.

    This implementation is intentionally a no-op (hence the B027 suppression):
    it has no body beyond this docstring, so awaiting it performs no work.
    """


def _default_token_expiration() -> int:
"""Get the default token expiration time.
Expand Down
156 changes: 151 additions & 5 deletions reflex/istate/manager/disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,27 @@
import contextlib
import dataclasses
import functools
import time
from collections.abc import AsyncIterator
from hashlib import md5
from pathlib import Path

from typing_extensions import override

from reflex.environment import environment
from reflex.istate.manager import StateManager, _default_token_expiration
from reflex.state import BaseState, _split_substate_key, _substate_key
from reflex.utils import path_ops, prerequisites
from reflex.utils import console, path_ops, prerequisites
from reflex.utils.misc import run_in_thread


@dataclasses.dataclass(frozen=True)
class QueueItem:
    """An item in the write queue.

    Instances are immutable (the dataclass is frozen).
    """

    # Token identifying the queued state; the write queue is keyed by this value.
    token: str
    # The state that will be written to disk.
    state: BaseState
    # time.time() value recorded when the item was queued; compared against the
    # debounce period to decide when the item is due to be written.
    timestamp: float


@dataclasses.dataclass
Expand All @@ -34,6 +46,22 @@ class StateManagerDisk(StateManager):
# The token expiration time (s).
token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)

# Last time a token was touched (time.time() value, keyed by client token).
_token_last_touched: dict[str, float] = dataclasses.field(
    default_factory=dict,
    init=False,
)

# Pending writes, keyed by token.
_write_queue: dict[str, QueueItem] = dataclasses.field(
    default_factory=dict,
    init=False,
)
# Background task that processes the write queue (None until first scheduled).
_write_queue_task: asyncio.Task | None = None
# Seconds to wait before a queued state is written to disk. Resolved with a
# default_factory so the environment variable is read when each instance is
# created rather than once when the class body is evaluated at import time.
_write_debounce_seconds: float = dataclasses.field(
    default_factory=environment.REFLEX_STATE_MANAGER_DISK_DEBOUNCE_SECONDS.get
)

def __post_init__(self):
"""Create a new state manager."""
path_ops.mkdir(self.states_directory)
Expand All @@ -51,8 +79,6 @@ def states_directory(self) -> Path:

def _purge_expired_states(self):
"""Purge expired states from the disk."""
import time

for path in path_ops.ls(self.states_directory):
# check path is a pickle file
if path.suffix != ".pkl":
Expand Down Expand Up @@ -137,6 +163,7 @@ async def get_state(
The state for the token.
"""
client_token = _split_substate_key(token)[0]
self._token_last_touched[client_token] = time.time()
root_state = self.states.get(client_token)
if root_state is not None:
# Retrieved state from memory.
Expand Down Expand Up @@ -170,11 +197,109 @@ async def set_state_for_substate(self, client_token: str, substate: BaseState):
if pickle_state:
if not self.states_directory.exists():
self.states_directory.mkdir(parents=True, exist_ok=True)
self.token_path(substate_token).write_bytes(pickle_state)
await run_in_thread(
lambda: self.token_path(substate_token).write_bytes(pickle_state),
)

for substate_substate in substate.substates.values():
await self.set_state_for_substate(client_token, substate_substate)

async def _process_write_queue_delay(self):
"""Wait for the debounce period before processing the write queue again."""
now = time.time()
if self._write_queue:
# There are still items in the queue, schedule another run.
next_write_in = max(
0,
min(
self._write_debounce_seconds - (now - item.timestamp)
for item in self._write_queue.values()
),
)
await asyncio.sleep(next_write_in)
elif self._write_debounce_seconds > 0:
# No items left, wait a bit before checking again.
await asyncio.sleep(self._write_debounce_seconds)
else:
# Debounce is disabled, so sleep until the next token expiration.
oldest_token_last_touch = min(
self._token_last_touched.values(), default=now
)
next_expiration_in = self.token_expiration - (now - oldest_token_last_touch)
await asyncio.sleep(next_expiration_in)

async def _process_write_queue(self):
    """Long running task that checks for states to write to disk.

    Each iteration writes every queued item whose debounce period has elapsed
    (oldest first), forgets in-memory state for tokens not touched within
    ``token_expiration`` seconds, purges expired states from disk, and then
    sleeps until the next iteration is due.

    Cancellation at any point (during a write, the purge, or either delay)
    flushes the remaining queue to disk before the cancellation propagates.

    Raises:
        asyncio.CancelledError: When the task is cancelled.
    """
    try:
        while True:
            try:
                now = time.time()
                # Sort the queue by oldest timestamp and exclude items that are
                # still inside the debounce window.
                items_to_write = sorted(
                    (
                        item
                        for item in self._write_queue.values()
                        if now - item.timestamp >= self._write_debounce_seconds
                    ),
                    key=lambda item: item.timestamp,
                )
                for item in items_to_write:
                    token = item.token
                    client_token, _ = _split_substate_key(token)
                    # Pop before writing so a set_state during the write can
                    # queue a fresh item for the same token.
                    queued = self._write_queue.pop(token)
                    try:
                        await self.set_state_for_substate(client_token, queued.state)
                    except asyncio.CancelledError:  # noqa: PERF203
                        # The write was interrupted: put the item back (unless a
                        # newer one was queued) so the flush below still writes it.
                        self._write_queue.setdefault(token, queued)
                        raise
                # Check for expired states to purge.
                for token, last_touched in list(self._token_last_touched.items()):
                    if now - last_touched > self.token_expiration:
                        self._token_last_touched.pop(token)
                        self.states.pop(token, None)
                await run_in_thread(self._purge_expired_states)
                await self._process_write_queue_delay()
            except Exception as e:  # noqa: PERF203
                console.error(f"Error processing write queue: {e!r}")
                if e.args == ("cannot schedule new futures after shutdown",):
                    # Event loop is shutdown, nothing else we can really do...
                    return
                await self._process_write_queue_delay()
    except asyncio.CancelledError:
        # Covers cancellation anywhere in the loop, including the delay that
        # follows an error, so pending writes are never dropped on shutdown.
        await self._flush_write_queue()
        raise

async def _flush_write_queue(self):
    """Flush any remaining items in the write queue to disk.

    The queue is emptied up front and every item is then written in turn. A
    failure to write one item is logged and does not stop the remaining items
    from being written, since they are no longer in the queue and would
    otherwise be lost.
    """
    outstanding_items = list(self._write_queue.values())
    n_outstanding_items = len(outstanding_items)
    self._write_queue.clear()
    # When the task is cancelled, write all remaining items to disk.
    console.debug(
        f"StateManagerDisk._flush_write_queue: writing {n_outstanding_items} remaining items to disk"
    )
    for item in outstanding_items:
        client_token, _ = _split_substate_key(item.token)
        try:
            await self.set_state_for_substate(client_token, item.state)
        except Exception as e:  # noqa: PERF203
            # Keep going: the other items still deserve a chance to be saved.
            console.error(f"Error flushing state to disk: {e!r}")
    console.debug(
        f"StateManagerDisk._flush_write_queue: Finished writing {n_outstanding_items} items"
    )

async def _schedule_process_write_queue(self):
"""Schedule the write queue processing task if not already running."""
if self._write_queue_task is None or self._write_queue_task.done():
async with self._state_manager_lock:
if self._write_queue_task is None or self._write_queue_task.done():
self._write_queue_task = asyncio.create_task(
self._process_write_queue(),
name="StateManagerDisk|WriteQueueProcessor",
)
await asyncio.sleep(0) # Yield to allow the task to start.

@override
async def set_state(self, token: str, state: BaseState):
"""Set the state for a token.
Expand All @@ -184,7 +309,19 @@ async def set_state(self, token: str, state: BaseState):
state: The state to set.
"""
client_token, _ = _split_substate_key(token)
await self.set_state_for_substate(client_token, state)
if self._write_debounce_seconds > 0:
# Deferred write to reduce disk IO overhead.
if client_token not in self._write_queue:
self._write_queue[client_token] = QueueItem(
token=client_token,
state=state,
timestamp=time.time(),
)
else:
# Immediate write to disk.
await self.set_state_for_substate(client_token, state)
# Ensure the processing task is scheduled to handle expirations and any deferred writes.
await self._schedule_process_write_queue()

@override
@contextlib.asynccontextmanager
Expand All @@ -208,3 +345,12 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
state = await self.get_state(token)
yield state
await self.set_state(token, state)

async def close(self):
    """Close the state manager, flushing any pending writes to disk.

    Cancels the write queue task (if there is one) and waits for it to finish;
    the task writes its outstanding items when it is cancelled.
    """
    async with self._state_manager_lock:
        task = self._write_queue_task
        if not task:
            return
        task.cancel()
        # The cancelled task re-raises CancelledError once it has finished.
        with contextlib.suppress(asyncio.CancelledError):
            await task
        self._write_queue_task = None
Loading
Loading