22
33from __future__ import annotations
44
5+ import asyncio
56import dataclasses
67import json
78import uuid
89from abc import ABC , abstractmethod
910from types import MappingProxyType
1011from typing import TYPE_CHECKING
1112
13+ from reflex .istate .manager .redis import StateManagerRedis
14+ from reflex .state import BaseState
1215from reflex .utils import console , prerequisites
1316
1417if TYPE_CHECKING :
@@ -173,6 +176,7 @@ def __init__(self, redis: Redis):
173176
174177 config = get_config ()
175178 self .token_expiration = config .redis_token_expiration
179+ self ._update_task = None
176180
177181 def _get_redis_key (self , token : str ) -> str :
178182 """Get Redis key for token mapping.
@@ -185,6 +189,48 @@ def _get_redis_key(self, token: str) -> str:
185189 """
186190 return f"token_manager_socket_record_{ token } "
187191
192+ async def _socket_record_update_task (self ) -> None :
193+ """Background task to monitor Redis keyspace notifications for socket record updates."""
194+ await StateManagerRedis (
195+ state = BaseState , redis = self .redis
196+ )._enable_keyspace_notifications ()
197+ redis_db = self .redis .get_connection_kwargs ().get ("db" , 0 )
198+ while True :
199+ try :
200+ await self ._subscribe_socket_record_updates (redis_db )
201+ except asyncio .CancelledError : # noqa: PERF203
202+ break
203+ except Exception as e :
204+ console .error (f"RedisTokenManager socket record update task error: { e } " )
205+
206+ async def _subscribe_socket_record_updates (self , redis_db : int ) -> None :
207+ """Subscribe to Redis keyspace notifications for socket record updates."""
208+ pubsub = self .redis .pubsub ()
209+ await pubsub .psubscribe (
210+ f"__keyspace@{ redis_db } __:token_manager_socket_record_*"
211+ )
212+
213+ async for message in pubsub .listen ():
214+ if message ["type" ] == "pmessage" :
215+ key = message ["channel" ].split (b":" , 1 )[1 ].decode ()
216+ event = message ["data" ].decode ()
217+ token = key .replace ("token_manager_socket_record_" , "" )
218+
219+ if event in ("del" , "expired" , "evicted" ):
220+ # Remove from local dicts if exists
221+ if (
222+ socket_record := self .token_to_socket .pop (token , None )
223+ ) is not None :
224+ self .sid_to_token .pop (socket_record .sid , None )
225+ elif event == "set" :
226+ # Fetch updated record from Redis
227+ record_json = await self .redis .get (key )
228+ if record_json :
229+ record_data = json .loads (record_json )
230+ socket_record = SocketRecord (** record_data )
231+ self .token_to_socket [token ] = socket_record
232+ self .sid_to_token [socket_record .sid ] = token
233+
188234 async def link_token_to_sid (self , token : str , sid : str ) -> str | None :
189235 """Link a token to a session ID with Redis-based duplicate detection.
190236
@@ -201,6 +247,10 @@ async def link_token_to_sid(self, token: str, sid: str) -> str | None:
201247 ) is not None and sid == socket_record .sid :
202248 return None # Same token, same SID = reconnection, no Redis check needed
203249
250+ # Make sure the update subscriber is running
251+ if self ._update_task is None or self ._update_task .done ():
252+ self ._update_task = asyncio .create_task (self ._socket_record_update_task ())
253+
204254 # Check Redis for cross-worker duplicates
205255 redis_key = self ._get_redis_key (token )
206256
0 commit comments