1414from reflex .istate .manager .redis import StateManagerRedis
1515from reflex .state import BaseState , StateUpdate
1616from reflex .utils import console , prerequisites
17+ from reflex .utils .tasks import ensure_task
1718
1819if TYPE_CHECKING :
1920 from redis .asyncio import Redis
@@ -241,6 +242,11 @@ def _handle_socket_record_del(self, token: str) -> None:
241242
242243 async def _subscribe_socket_record_updates (self , redis_db : int ) -> None :
243244 """Subscribe to Redis keyspace notifications for socket record updates."""
245+ await StateManagerRedis (
246+ state = BaseState , redis = self .redis
247+ )._enable_keyspace_notifications ()
248+ redis_db = self .redis .get_connection_kwargs ().get ("db" , 0 )
249+
244250 async with self .redis .pubsub () as pubsub :
245251 await pubsub .psubscribe (
246252 f"__keyspace@{ redis_db } __:{ self ._get_redis_key ('*' )} "
@@ -260,29 +266,14 @@ async def _subscribe_socket_record_updates(self, redis_db: int) -> None:
260266 elif event == "set" :
261267 await self ._get_token_owner (token , refresh = True )
262268
263- async def _socket_record_updates_forever (self ) -> None :
264- """Background task to monitor Redis keyspace notifications for socket record updates."""
265- await StateManagerRedis (
266- state = BaseState , redis = self .redis
267- )._enable_keyspace_notifications ()
268- redis_db = self .redis .get_connection_kwargs ().get ("db" , 0 )
269- while True :
270- try :
271- await self ._subscribe_socket_record_updates (redis_db )
272- except asyncio .CancelledError : # noqa: PERF203
273- raise
274- except Exception as e :
275- if isinstance (e , RuntimeError ) and str (e ) == "no running event loop" :
276- # Happens when shutting down, break out of the loop.
277- raise
278- console .error (f"RedisTokenManager socket record update task error: { e } " )
279-
280269 def _ensure_socket_record_task (self ) -> None :
281270 """Ensure the socket record updates subscriber task is running."""
282- if self ._socket_record_task is None or self ._socket_record_task .done ():
283- self ._socket_record_task = asyncio .create_task (
284- self ._socket_record_updates_forever ()
285- )
271+ ensure_task (
272+ owner = self ,
273+ task_attribute = "_socket_record_task" ,
274+ coro_function = self ._subscribe_socket_record_updates ,
275+ suppress_exceptions = [Exception ],
276+ )
286277
287278 async def link_token_to_sid (self , token : str , sid : str ) -> str | None :
288279 """Link a token to a session ID with Redis-based duplicate detection.
@@ -389,26 +380,6 @@ async def _subscribe_lost_and_found_updates(
389380 record = LostAndFoundRecord (** json .loads (message ["data" ].decode ()))
390381 await emit_update (StateUpdate (** record .update ), record .token )
391382
392- async def _lost_and_found_updates_forever (
393- self ,
394- emit_update : Callable [[StateUpdate , str ], Coroutine [None , None , None ]],
395- ):
396- """Background task to monitor Redis lost and found deltas.
397-
398- Args:
399- emit_update: The function to emit state updates.
400- """
401- while True :
402- try :
403- await self ._subscribe_lost_and_found_updates (emit_update )
404- except asyncio .CancelledError : # noqa: PERF203
405- raise
406- except Exception as e :
407- if isinstance (e , RuntimeError ) and str (e ) == "no running event loop" :
408- # Happens when shutting down, break out of the loop.
409- raise
410- console .error (f"RedisTokenManager lost and found task error: { e } " )
411-
412383 def ensure_lost_and_found_task (
413384 self ,
414385 emit_update : Callable [[StateUpdate , str ], Coroutine [None , None , None ]],
@@ -418,10 +389,13 @@ def ensure_lost_and_found_task(
418389 Args:
419390 emit_update: The function to emit state updates.
420391 """
421- if self ._lost_and_found_task is None or self ._lost_and_found_task .done ():
422- self ._lost_and_found_task = asyncio .create_task (
423- self ._lost_and_found_updates_forever (emit_update )
424- )
392+ ensure_task (
393+ owner = self ,
394+ task_attribute = "_lost_and_found_task" ,
395+ coro_function = self ._subscribe_lost_and_found_updates ,
396+ suppress_exceptions = [Exception ],
397+ emit_update = emit_update ,
398+ )
425399
426400 async def _get_token_owner (self , token : str , refresh : bool = False ) -> str | None :
427401 """Get the instance ID of the owner of a token.
0 commit comments