22
33from __future__ import annotations
44
5+ import dataclasses
6+ import json
57import uuid
68from abc import ABC , abstractmethod
9+ from types import MappingProxyType
710from typing import TYPE_CHECKING
811
912from reflex .utils import console , prerequisites
@@ -21,16 +24,37 @@ def _get_new_token() -> str:
2124 return str (uuid .uuid4 ())
2225
2326
27+ @dataclasses .dataclass (frozen = True , kw_only = True )
28+ class SocketRecord :
29+ """Record for a connected socket client."""
30+
31+ instance_id : str
32+ sid : str
33+
34+
2435class TokenManager (ABC ):
2536 """Abstract base class for managing client token to session ID mappings."""
2637
2738 def __init__ (self ):
2839 """Initialize the token manager with local dictionaries."""
29- # Keep a mapping between socket ID and client token .
30- self .token_to_sid : dict [ str , str ] = {}
40+ # Each process has an instance_id to identify its own sockets .
41+ self .instance_id : str = _get_new_token ()
3142 # Keep a mapping between client token and socket ID.
43+ self .token_to_socket : dict [str , SocketRecord ] = {}
44+ # Keep a mapping between socket ID and client token.
3245 self .sid_to_token : dict [str , str ] = {}
3346
47+ @property
48+ def token_to_sid (self ) -> MappingProxyType [str , str ]:
49+ """Read-only compatibility property for token_to_socket mapping.
50+
51+ Returns:
52+ The token to session ID mapping.
53+ """
54+ return MappingProxyType ({
55+ token : sr .sid for token , sr in self .token_to_socket .items ()
56+ })
57+
3458 @abstractmethod
3559 async def link_token_to_sid (self , token : str , sid : str ) -> str | None :
3660 """Link a token to a session ID.
@@ -68,7 +92,9 @@ def create(cls) -> TokenManager:
6892
6993 async def disconnect_all (self ):
7094 """Disconnect all tracked tokens when the server is going down."""
71- token_sid_pairs : set [tuple [str , str ]] = set (self .token_to_sid .items ())
95+ token_sid_pairs : set [tuple [str , str ]] = {
96+ (token , sr .sid ) for token , sr in self .token_to_socket .items ()
97+ }
7298 token_sid_pairs .update (
7399 ((token , sid ) for sid , token in self .sid_to_token .items ())
74100 )
@@ -95,14 +121,20 @@ async def link_token_to_sid(self, token: str, sid: str) -> str | None:
95121 New token if duplicate detected and new token generated, None otherwise.
96122 """
97123 # Check if token is already mapped to a different SID (duplicate tab)
98- if token in self .token_to_sid and sid != self .token_to_sid .get (token ):
124+ if (
125+ socket_record := self .token_to_socket .get (token )
126+ ) is not None and sid != socket_record .sid :
99127 new_token = _get_new_token ()
100- self .token_to_sid [new_token ] = sid
128+ self .token_to_socket [new_token ] = SocketRecord (
129+ instance_id = self .instance_id , sid = sid
130+ )
101131 self .sid_to_token [sid ] = new_token
102132 return new_token
103133
104134 # Normal case - link token to SID
105- self .token_to_sid [token ] = sid
135+ self .token_to_socket [token ] = SocketRecord (
136+ instance_id = self .instance_id , sid = sid
137+ )
106138 self .sid_to_token [sid ] = token
107139 return None
108140
@@ -114,7 +146,7 @@ async def disconnect_token(self, token: str, sid: str) -> None:
114146 sid: The Socket.IO session ID.
115147 """
116148 # Clean up both mappings
117- self .token_to_sid .pop (token , None )
149+ self .token_to_socket .pop (token , None )
118150 self .sid_to_token .pop (sid , None )
119151
120152
@@ -149,9 +181,9 @@ def _get_redis_key(self, token: str) -> str:
149181 token: The client token.
150182
151183 Returns:
152- Redis key following Reflex conventions: {token}_sid
184+ Redis key following Reflex conventions: token_manager_socket_record_ {token}
153185 """
154- return f"{ token } _sid "
186+ return f"token_manager_socket_record_ { token } "
155187
156188 async def link_token_to_sid (self , token : str , sid : str ) -> str | None :
157189 """Link a token to a session ID with Redis-based duplicate detection.
@@ -164,7 +196,9 @@ async def link_token_to_sid(self, token: str, sid: str) -> str | None:
164196 New token if duplicate detected and new token generated, None otherwise.
165197 """
166198 # Fast local check first (handles reconnections)
167- if token in self .token_to_sid and self .token_to_sid [token ] == sid :
199+ if (
200+ socket_record := self .token_to_socket .get (token )
201+ ) is not None and sid == socket_record .sid :
168202 return None # Same token, same SID = reconnection, no Redis check needed
169203
170204 # Check Redis for cross-worker duplicates
@@ -176,34 +210,29 @@ async def link_token_to_sid(self, token: str, sid: str) -> str | None:
176210 console .error (f"Redis error checking token existence: { e } " )
177211 return await super ().link_token_to_sid (token , sid )
178212
213+ new_token = None
179214 if token_exists_in_redis :
180215 # Duplicate exists somewhere - generate new token
181- new_token = _get_new_token ()
182- new_redis_key = self ._get_redis_key (new_token )
216+ token = new_token = _get_new_token ()
217+ redis_key = self ._get_redis_key (new_token )
183218
184- try :
185- # Store in Redis
186- await self .redis .set (new_redis_key , "1" , ex = self .token_expiration )
187- except Exception as e :
188- console .error (f"Redis error storing new token: { e } " )
189- # Still update local dicts and continue
190-
191- # Store in local dicts (always do this)
192- self .token_to_sid [new_token ] = sid
193- self .sid_to_token [sid ] = new_token
194- return new_token
219+ # Store in local dicts
220+ socket_record = self .token_to_socket [token ] = SocketRecord (
221+ instance_id = self .instance_id , sid = sid
222+ )
223+ self .sid_to_token [sid ] = token
195224
196- # Normal case - store in both Redis and local dicts
225+ # Store in Redis if possible
197226 try :
198- await self .redis .set (redis_key , "1" , ex = self .token_expiration )
227+ await self .redis .set (
228+ redis_key ,
229+ json .dumps (dataclasses .asdict (socket_record )),
230+ ex = self .token_expiration ,
231+ )
199232 except Exception as e :
200233 console .error (f"Redis error storing token: { e } " )
201- # Continue with local storage
202-
203- # Store in local dicts (always do this)
204- self .token_to_sid [token ] = sid
205- self .sid_to_token [sid ] = token
206- return None
234+ # Return the new token if one was generated
235+ return new_token
207236
208237 async def disconnect_token (self , token : str , sid : str ) -> None :
209238 """Clean up token mapping when client disconnects.
@@ -213,7 +242,9 @@ async def disconnect_token(self, token: str, sid: str) -> None:
213242 sid: The Socket.IO session ID.
214243 """
215244 # Only clean up if we own it locally (fast ownership check)
216- if self .token_to_sid .get (token ) == sid :
245+ if (
246+ socket_record := self .token_to_socket .get (token )
247+ ) is not None and socket_record .sid == sid :
217248 # Clean up Redis
218249 redis_key = self ._get_redis_key (token )
219250 try :
0 commit comments