Skip to content

Commit 6a4db93

Browse files
committed
generalize "forever" tasks to centralize exception handling/retry
1 parent 871c885 commit 6a4db93

File tree

4 files changed

+258
-68
lines changed

4 files changed

+258
-68
lines changed

reflex/istate/manager/redis.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
LockExpiredError,
3030
StateSchemaMismatchError,
3131
)
32+
from reflex.utils.tasks import ensure_task
3233

3334

3435
def _default_lock_expiration() -> int:
@@ -155,7 +156,7 @@ def __post_init__(self):
155156
raise InvalidLockWarningThresholdError(msg)
156157
with contextlib.suppress(RuntimeError):
157158
asyncio.get_running_loop() # Check if we're in an event loop.
158-
self._lock_task = asyncio.create_task(self._lock_updates_forever())
159+
self._ensure_lock_task()
159160

160161
def _get_required_state_classes(
161162
self,
@@ -586,7 +587,7 @@ async def _create_lease_break_task(
586587
Returns:
587588
The lease break task, or None when there is contention.
588589
"""
589-
await self._ensure_lock_task()
590+
self._ensure_lock_task()
590591

591592
client_token, _ = _split_substate_key(token)
592593

@@ -765,6 +766,9 @@ async def _subscribe_lock_updates(self, redis_db: int = 0):
765766
Args:
766767
redis_db: The logical database number to subscribe to.
767768
"""
769+
await self._enable_keyspace_notifications()
770+
redis_db = self.redis.get_connection_kwargs().get("db", 0)
771+
768772
lock_key_pattern = f"__keyspace@{redis_db}__:*_lock"
769773
lock_waiter_key_pattern = f"__keyspace@{redis_db}__:*_lock_waiters"
770774
handlers = {
@@ -776,27 +780,14 @@ async def _subscribe_lock_updates(self, redis_db: int = 0):
776780
async for _ in pubsub.listen():
777781
pass
778782

779-
async def _lock_updates_forever(self) -> None:
780-
"""Background task to monitor Redis keyspace notifications for lock updates."""
781-
await self._enable_keyspace_notifications()
782-
redis_db = self.redis.get_connection_kwargs().get("db", 0)
783-
while True:
784-
try:
785-
await self._subscribe_lock_updates(redis_db)
786-
except asyncio.CancelledError: # noqa: PERF203
787-
raise
788-
except Exception as e:
789-
if isinstance(e, RuntimeError) and str(e) == "no running event loop":
790-
# Happens when shutting down, break out of the loop.
791-
raise
792-
console.error(f"StateManagerRedis lock update task error: {e}")
793-
794-
async def _ensure_lock_task(self) -> None:
783+
def _ensure_lock_task(self) -> None:
    """Ensure the lock updates subscriber task is running.

    Delegates to the shared `ensure_task` helper, which stores the task on
    `self._lock_task` and restarts `_subscribe_lock_updates` on failure.
    """
    ensure_task(
        self,
        "_lock_task",
        self._subscribe_lock_updates,
        suppress_exceptions=[Exception],
    )
800791

801792
async def _enable_keyspace_notifications(self):
802793
"""Enable keyspace notifications for the redis server.
@@ -954,7 +945,7 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
954945
)
955946
return
956947
# Make sure lock waiter task is running.
957-
await self._ensure_lock_task()
948+
self._ensure_lock_task()
958949
async with (
959950
self._lock_waiter(lock_key) as lock_released_event,
960951
self._request_lock_release(lock_key, lock_id),

reflex/utils/tasks.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""Helpers for managing asyncio tasks."""
2+
3+
import asyncio
4+
import time
5+
from collections.abc import Callable, Coroutine
6+
from typing import Any
7+
8+
from reflex.utils import console
9+
10+
11+
async def _run_forever(
12+
coro_function: Callable[..., Coroutine],
13+
*args: Any,
14+
suppress_exceptions: list[type[BaseException]],
15+
exception_delay: float,
16+
exception_limit: int,
17+
exception_limit_window: float,
18+
**kwargs: Any,
19+
):
20+
"""Wrapper to continuously run a coroutine function, suppressing certain exceptions.
21+
22+
Args:
23+
coro_function: The coroutine function to run.
24+
*args: The arguments to pass to the coroutine function.
25+
suppress_exceptions: The exceptions to suppress.
26+
exception_delay: The delay between retries when an exception is suppressed.
27+
exception_limit: The maximum number of suppressed exceptions within the limit window before raising.
28+
exception_limit_window: The time window in seconds for counting suppressed exceptions.
29+
**kwargs: The keyword arguments to pass to the coroutine function.
30+
"""
31+
last_regular_loop_start = 0
32+
exception_count = 0
33+
34+
while True:
35+
# Reset the exception count when the limit window has elapsed since the last non-exception loop started.
36+
if last_regular_loop_start + exception_limit_window < time.monotonic():
37+
exception_count = 0
38+
if not exception_count:
39+
last_regular_loop_start = time.monotonic()
40+
try:
41+
await coro_function(*args, **kwargs)
42+
except (asyncio.CancelledError, RuntimeError):
43+
raise
44+
except Exception as e:
45+
if any(isinstance(e, ex) for ex in suppress_exceptions):
46+
exception_count += 1
47+
if exception_count >= exception_limit:
48+
console.error(
49+
f"{coro_function.__name__}: task exceeded exception limit {exception_limit} within {exception_limit_window}s: {e}"
50+
)
51+
raise
52+
console.error(f"{coro_function.__name__}: task error suppressed: {e}")
53+
await asyncio.sleep(exception_delay)
54+
continue
55+
raise
56+
57+
58+
def ensure_task(
59+
owner: Any,
60+
task_attribute: str,
61+
coro_function: Callable[..., Coroutine],
62+
*args: Any,
63+
suppress_exceptions: list[type[BaseException]] | None = None,
64+
exception_delay: float = 1.0,
65+
exception_limit: int = 5,
66+
exception_limit_window: float = 60.0,
67+
**kwargs: Any,
68+
) -> asyncio.Task:
69+
"""Ensure that a task is running for the given coroutine function.
70+
71+
Note: if the task is already running, args and kwargs are ignored.
72+
73+
Args:
74+
owner: The owner of the task.
75+
task_attribute: The attribute name to store/retrieve the task from the owner object.
76+
coro_function: The coroutine function to run as a task.
77+
suppress_exceptions: The exceptions to log and continue when running the coroutine.
78+
exception_delay: The delay between retries when an exception is suppressed.
79+
exception_limit: The maximum number of suppressed exceptions within the limit window before raising.
80+
exception_limit_window: The time window in seconds for counting suppressed exceptions.
81+
*args: The arguments to pass to the coroutine function.
82+
**kwargs: The keyword arguments to pass to the coroutine function.
83+
84+
Returns:
85+
The asyncio task running the coroutine function.
86+
"""
87+
if suppress_exceptions is None:
88+
suppress_exceptions = []
89+
if RuntimeError in suppress_exceptions:
90+
msg = "Cannot suppress RuntimeError exceptions which may be raised by asyncio machinery."
91+
raise RuntimeError(msg)
92+
93+
task = getattr(owner, task_attribute, None)
94+
if task is None or task.done():
95+
asyncio.get_running_loop() # Ensure we're in an event loop.
96+
task = asyncio.create_task(
97+
_run_forever(
98+
coro_function,
99+
*args,
100+
suppress_exceptions=suppress_exceptions,
101+
exception_delay=exception_delay,
102+
exception_limit=exception_limit,
103+
exception_limit_window=exception_limit_window,
104+
**kwargs,
105+
),
106+
name=f"reflex_ensure_task|{type(owner).__name__}.{task_attribute}={coro_function.__name__}|{time.time()}",
107+
)
108+
setattr(owner, task_attribute, task)
109+
return task

reflex/utils/token_manager.py

Lines changed: 19 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from reflex.istate.manager.redis import StateManagerRedis
1515
from reflex.state import BaseState, StateUpdate
1616
from reflex.utils import console, prerequisites
17+
from reflex.utils.tasks import ensure_task
1718

1819
if TYPE_CHECKING:
1920
from redis.asyncio import Redis
@@ -241,6 +242,11 @@ def _handle_socket_record_del(self, token: str) -> None:
241242

242243
async def _subscribe_socket_record_updates(self, redis_db: int) -> None:
243244
"""Subscribe to Redis keyspace notifications for socket record updates."""
245+
await StateManagerRedis(
246+
state=BaseState, redis=self.redis
247+
)._enable_keyspace_notifications()
248+
redis_db = self.redis.get_connection_kwargs().get("db", 0)
249+
244250
async with self.redis.pubsub() as pubsub:
245251
await pubsub.psubscribe(
246252
f"__keyspace@{redis_db}__:{self._get_redis_key('*')}"
@@ -260,29 +266,14 @@ async def _subscribe_socket_record_updates(self, redis_db: int) -> None:
260266
elif event == "set":
261267
await self._get_token_owner(token, refresh=True)
262268

263-
async def _socket_record_updates_forever(self) -> None:
264-
"""Background task to monitor Redis keyspace notifications for socket record updates."""
265-
await StateManagerRedis(
266-
state=BaseState, redis=self.redis
267-
)._enable_keyspace_notifications()
268-
redis_db = self.redis.get_connection_kwargs().get("db", 0)
269-
while True:
270-
try:
271-
await self._subscribe_socket_record_updates(redis_db)
272-
except asyncio.CancelledError: # noqa: PERF203
273-
raise
274-
except Exception as e:
275-
if isinstance(e, RuntimeError) and str(e) == "no running event loop":
276-
# Happens when shutting down, break out of the loop.
277-
raise
278-
console.error(f"RedisTokenManager socket record update task error: {e}")
279-
280269
def _ensure_socket_record_task(self) -> None:
    """Ensure the socket record updates subscriber task is running."""
    ensure_task(
        owner=self,
        task_attribute="_socket_record_task",
        coro_function=self._subscribe_socket_record_updates,
        suppress_exceptions=[Exception],
        # _subscribe_socket_record_updates still declares a required
        # `redis_db` parameter even though it now re-derives the actual db
        # from the connection kwargs itself. Without supplying it here, every
        # retry of the task dies immediately with a TypeError.
        redis_db=0,
    )
286277

287278
async def link_token_to_sid(self, token: str, sid: str) -> str | None:
288279
"""Link a token to a session ID with Redis-based duplicate detection.
@@ -389,26 +380,6 @@ async def _subscribe_lost_and_found_updates(
389380
record = LostAndFoundRecord(**json.loads(message["data"].decode()))
390381
await emit_update(StateUpdate(**record.update), record.token)
391382

392-
async def _lost_and_found_updates_forever(
393-
self,
394-
emit_update: Callable[[StateUpdate, str], Coroutine[None, None, None]],
395-
):
396-
"""Background task to monitor Redis lost and found deltas.
397-
398-
Args:
399-
emit_update: The function to emit state updates.
400-
"""
401-
while True:
402-
try:
403-
await self._subscribe_lost_and_found_updates(emit_update)
404-
except asyncio.CancelledError: # noqa: PERF203
405-
raise
406-
except Exception as e:
407-
if isinstance(e, RuntimeError) and str(e) == "no running event loop":
408-
# Happens when shutting down, break out of the loop.
409-
raise
410-
console.error(f"RedisTokenManager lost and found task error: {e}")
411-
412383
def ensure_lost_and_found_task(
413384
self,
414385
emit_update: Callable[[StateUpdate, str], Coroutine[None, None, None]],
@@ -418,10 +389,13 @@ def ensure_lost_and_found_task(
418389
Args:
419390
emit_update: The function to emit state updates.
420391
"""
421-
if self._lost_and_found_task is None or self._lost_and_found_task.done():
422-
self._lost_and_found_task = asyncio.create_task(
423-
self._lost_and_found_updates_forever(emit_update)
424-
)
392+
ensure_task(
393+
owner=self,
394+
task_attribute="_lost_and_found_task",
395+
coro_function=self._subscribe_lost_and_found_updates,
396+
suppress_exceptions=[Exception],
397+
emit_update=emit_update,
398+
)
425399

426400
async def _get_token_owner(self, token: str, refresh: bool = False) -> str | None:
427401
"""Get the instance ID of the owner of a token.

0 commit comments

Comments
 (0)