77import json
88import uuid
99from abc import ABC , abstractmethod
10+ from collections .abc import Callable , Coroutine
1011from types import MappingProxyType
11- from typing import TYPE_CHECKING
12+ from typing import TYPE_CHECKING , Any
1213
1314from reflex .istate .manager .redis import StateManagerRedis
14- from reflex .state import BaseState
15+ from reflex .state import BaseState , StateUpdate
1516from reflex .utils import console , prerequisites
1617
1718if TYPE_CHECKING :
@@ -35,6 +36,14 @@ class SocketRecord:
3536 sid : str
3637
3738
39+ @dataclasses .dataclass (frozen = True , kw_only = True )
40+ class LostAndFoundRecord :
41+ """Record for a StateUpdate for a token with its socket on another instance."""
42+
43+ token : str
44+ update : dict [str , Any ]
45+
46+
3847class TokenManager (ABC ):
3948 """Abstract base class for managing client token to session ID mappings."""
4049
@@ -176,7 +185,10 @@ def __init__(self, redis: Redis):
176185
177186 config = get_config ()
178187 self .token_expiration = config .redis_token_expiration
179- self ._update_task = None
188+
189+ # Pub/sub tasks for handling sockets owned by other instances.
190+ self ._socket_record_task : asyncio .Task | None = None
191+ self ._lost_and_found_task : asyncio .Task | None = None
180192
181193 def _get_redis_key (self , token : str ) -> str :
182194 """Get Redis key for token mapping.
@@ -189,7 +201,53 @@ def _get_redis_key(self, token: str) -> str:
189201 """
190202 return f"token_manager_socket_record_{ token } "
191203
192- async def _socket_record_update_task (self ) -> None :
204+ def _handle_socket_record_del (self , token : str ) -> None :
205+ """Handle deletion of a socket record from Redis.
206+
207+ Args:
208+ token: The client token whose record was deleted.
209+ """
210+ if (
211+ socket_record := self .token_to_socket .pop (token , None )
212+ ) is not None and socket_record .instance_id != self .instance_id :
213+ self .sid_to_token .pop (socket_record .sid , None )
214+
215+ async def _handle_socket_record_set (self , token : str ) -> None :
216+ """Handle setting/updating of a socket record from Redis.
217+
218+ Args:
219+ token: The client token whose record was set/updated.
220+ """
221+ # Fetch updated record from Redis
222+ record_json = await self .redis .get (self ._get_redis_key (token ))
223+ if record_json :
224+ record_data = json .loads (record_json )
225+ socket_record = SocketRecord (** record_data )
226+ self .token_to_socket [token ] = socket_record
227+ self .sid_to_token [socket_record .sid ] = token
228+
229+ async def _subscribe_socket_record_updates (self , redis_db : int ) -> None :
230+ """Subscribe to Redis keyspace notifications for socket record updates."""
231+ async with self .redis .pubsub () as pubsub :
232+ await pubsub .psubscribe (
233+ f"__keyspace@{ redis_db } __:token_manager_socket_record_*"
234+ )
235+ async for message in pubsub .listen ():
236+ if message ["type" ] == "pmessage" :
237+ key = message ["channel" ].split (b":" , 1 )[1 ].decode ()
238+ token = key .replace ("token_manager_socket_record_" , "" )
239+
240+ if token not in self .token_to_socket :
241+ # We don't know about this token, skip
242+ continue
243+
244+ event = message ["data" ].decode ()
245+ if event in ("del" , "expired" , "evicted" ):
246+ self ._handle_socket_record_del (token )
247+ elif event == "set" :
248+ await self ._handle_socket_record_set (token )
249+
250+ async def _socket_record_updates_forever (self ) -> None :
193251 """Background task to monitor Redis keyspace notifications for socket record updates."""
194252 await StateManagerRedis (
195253 state = BaseState , redis = self .redis
@@ -203,33 +261,12 @@ async def _socket_record_update_task(self) -> None:
203261 except Exception as e :
204262 console .error (f"RedisTokenManager socket record update task error: { e } " )
205263
206- async def _subscribe_socket_record_updates (self , redis_db : int ) -> None :
207- """Subscribe to Redis keyspace notifications for socket record updates."""
208- pubsub = self .redis .pubsub ()
209- await pubsub .psubscribe (
210- f"__keyspace@{ redis_db } __:token_manager_socket_record_*"
211- )
212-
213- async for message in pubsub .listen ():
214- if message ["type" ] == "pmessage" :
215- key = message ["channel" ].split (b":" , 1 )[1 ].decode ()
216- event = message ["data" ].decode ()
217- token = key .replace ("token_manager_socket_record_" , "" )
218-
219- if event in ("del" , "expired" , "evicted" ):
220- # Remove from local dicts if exists
221- if (
222- socket_record := self .token_to_socket .pop (token , None )
223- ) is not None :
224- self .sid_to_token .pop (socket_record .sid , None )
225- elif event == "set" :
226- # Fetch updated record from Redis
227- record_json = await self .redis .get (key )
228- if record_json :
229- record_data = json .loads (record_json )
230- socket_record = SocketRecord (** record_data )
231- self .token_to_socket [token ] = socket_record
232- self .sid_to_token [socket_record .sid ] = token
264+ def _ensure_socket_record_task (self ) -> None :
265+ """Ensure the socket record updates subscriber task is running."""
266+ if self ._socket_record_task is None or self ._socket_record_task .done ():
267+ self ._socket_record_task = asyncio .create_task (
268+ self ._socket_record_updates_forever ()
269+ )
233270
234271 async def link_token_to_sid (self , token : str , sid : str ) -> str | None :
235272 """Link a token to a session ID with Redis-based duplicate detection.
@@ -248,8 +285,7 @@ async def link_token_to_sid(self, token: str, sid: str) -> str | None:
248285 return None # Same token, same SID = reconnection, no Redis check needed
249286
250287 # Make sure the update subscriber is running
251- if self ._update_task is None or self ._update_task .done ():
252- self ._update_task = asyncio .create_task (self ._socket_record_update_task ())
288+ self ._ensure_socket_record_task ()
253289
254290 # Check Redis for cross-worker duplicates
255291 redis_key = self ._get_redis_key (token )
@@ -293,8 +329,10 @@ async def disconnect_token(self, token: str, sid: str) -> None:
293329 """
294330 # Only clean up if we own it locally (fast ownership check)
295331 if (
296- socket_record := self .token_to_socket .get (token )
297- ) is not None and socket_record .sid == sid :
332+ (socket_record := self .token_to_socket .get (token )) is not None
333+ and socket_record .sid == sid
334+ and socket_record .instance_id == self .instance_id
335+ ):
298336 # Clean up Redis
299337 redis_key = self ._get_redis_key (token )
300338 try :
@@ -304,3 +342,124 @@ async def disconnect_token(self, token: str, sid: str) -> None:
304342
305343 # Clean up local dicts (always do this)
306344 await super ().disconnect_token (token , sid )
345+
346+ @staticmethod
347+ def _get_lost_and_found_key (instance_id : str ) -> str :
348+ """Get the Redis key for lost and found deltas for an instance.
349+
350+ Args:
351+ instance_id: The instance ID.
352+
353+ Returns:
354+ The Redis key for lost and found deltas.
355+ """
356+ return f"token_manager_lost_and_found_{ instance_id } "
357+
358+ async def _subscribe_lost_and_found_updates (
359+ self ,
360+ emit_update : Callable [[StateUpdate , str ], Coroutine [None , None , None ]],
361+ ) -> None :
362+ """Subscribe to Redis channel notifications for lost and found deltas.
363+
364+ Args:
365+ emit_update: The function to emit state updates.
366+ """
367+ async with self .redis .pubsub () as pubsub :
368+ await pubsub .psubscribe (
369+ f"channel:{ self ._get_lost_and_found_key (self .instance_id )} "
370+ )
371+ async for message in pubsub .listen ():
372+ if message ["type" ] == "pmessage" :
373+ record = LostAndFoundRecord (** json .loads (message ["data" ].decode ()))
374+ await emit_update (StateUpdate (** record .update ), record .token )
375+
376+ async def _lost_and_found_updates_forever (
377+ self ,
378+ emit_update : Callable [[StateUpdate , str ], Coroutine [None , None , None ]],
379+ ):
380+ """Background task to monitor Redis lost and found deltas.
381+
382+ Args:
383+ emit_update: The function to emit state updates.
384+ """
385+ while True :
386+ try :
387+ await self ._subscribe_lost_and_found_updates (emit_update )
388+ except asyncio .CancelledError : # noqa: PERF203
389+ break
390+ except Exception as e :
391+ console .error (f"RedisTokenManager lost and found task error: { e } " )
392+
393+ def ensure_lost_and_found_task (
394+ self ,
395+ emit_update : Callable [[StateUpdate , str ], Coroutine [None , None , None ]],
396+ ) -> None :
397+ """Ensure the lost and found subscriber task is running.
398+
399+ Args:
400+ emit_update: The function to emit state updates.
401+ """
402+ if self ._lost_and_found_task is None or self ._lost_and_found_task .done ():
403+ self ._lost_and_found_task = asyncio .create_task (
404+ self ._lost_and_found_updates_forever (emit_update )
405+ )
406+
407+ async def _get_token_owner (self , token : str , refresh : bool = False ) -> str | None :
408+ """Get the instance ID of the owner of a token.
409+
410+ Args:
411+ token: The client token.
412+ refresh: Whether to fetch the latest record from Redis.
413+
414+ Returns:
415+ The instance ID of the owner, or None if not found.
416+ """
417+ if (
418+ not refresh
419+ and (socket_record := self .token_to_socket .get (token )) is not None
420+ ):
421+ return socket_record .instance_id
422+
423+ redis_key = self ._get_redis_key (token )
424+ try :
425+ record_json = await self .redis .get (redis_key )
426+ if record_json :
427+ record_data = json .loads (record_json )
428+ socket_record = SocketRecord (** record_data )
429+ self .token_to_socket [token ] = socket_record
430+ self .sid_to_token [socket_record .sid ] = token
431+ return socket_record .instance_id
432+ console .error (f"Redis token owner not found for token { token } " )
433+ except Exception as e :
434+ console .error (f"Redis error getting token owner: { e } " )
435+ return None
436+
437+ async def emit_lost_and_found (
438+ self ,
439+ token : str ,
440+ update : StateUpdate ,
441+ ) -> bool :
442+ """Emit a lost and found delta to Redis.
443+
444+ Args:
445+ token: The client token.
446+ update: The state update.
447+
448+ Returns:
449+ True if the delta was published, False otherwise.
450+ """
451+ # See where this update belongs
452+ owner_instance_id = await self ._get_token_owner (token )
453+ if owner_instance_id is None :
454+ return False
455+ record = LostAndFoundRecord (token = token , update = dataclasses .asdict (update ))
456+ try :
457+ await self .redis .publish (
458+ f"channel:{ self ._get_lost_and_found_key (owner_instance_id )} " ,
459+ json .dumps (dataclasses .asdict (record )),
460+ )
461+ except Exception as e :
462+ console .error (f"Redis error publishing lost and found delta: { e } " )
463+ else :
464+ return True
465+ return False
0 commit comments