Skip to content

Commit e3b9787

Browse files
committed
Token manager tracks instance_id in token_to_socket
1 parent 6b043d6 commit e3b9787

File tree

4 files changed: +106 −46 lines changed

reflex/app.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2033,11 +2033,13 @@ def __init__(self, namespace: str, app: App):
20332033
self._token_manager = TokenManager.create()
20342034

20352035
@property
2036-
def token_to_sid(self) -> dict[str, str]:
2036+
def token_to_sid(self) -> Mapping[str, str]:
20372037
"""Get token to SID mapping for backward compatibility.
20382038
2039+
Note: this mapping is read-only.
2040+
20392041
Returns:
2040-
The token to SID mapping dict.
2042+
The token to SID mapping.
20412043
"""
20422044
# For backward compatibility, expose the underlying dict
20432045
return self._token_manager.token_to_sid

reflex/utils/token_manager.py

Lines changed: 63 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22

33
from __future__ import annotations
44

5+
import dataclasses
6+
import json
57
import uuid
68
from abc import ABC, abstractmethod
9+
from types import MappingProxyType
710
from typing import TYPE_CHECKING
811

912
from reflex.utils import console, prerequisites
@@ -21,16 +24,37 @@ def _get_new_token() -> str:
2124
return str(uuid.uuid4())
2225

2326

27+
@dataclasses.dataclass(frozen=True, kw_only=True)
28+
class SocketRecord:
29+
"""Record for a connected socket client."""
30+
31+
instance_id: str
32+
sid: str
33+
34+
2435
class TokenManager(ABC):
2536
"""Abstract base class for managing client token to session ID mappings."""
2637

2738
def __init__(self):
2839
"""Initialize the token manager with local dictionaries."""
29-
# Keep a mapping between socket ID and client token.
30-
self.token_to_sid: dict[str, str] = {}
40+
# Each process has an instance_id to identify its own sockets.
41+
self.instance_id: str = _get_new_token()
3142
# Keep a mapping between client token and socket ID.
43+
self.token_to_socket: dict[str, SocketRecord] = {}
44+
# Keep a mapping between socket ID and client token.
3245
self.sid_to_token: dict[str, str] = {}
3346

47+
@property
48+
def token_to_sid(self) -> MappingProxyType[str, str]:
49+
"""Read-only compatibility property for token_to_socket mapping.
50+
51+
Returns:
52+
The token to session ID mapping.
53+
"""
54+
return MappingProxyType({
55+
token: sr.sid for token, sr in self.token_to_socket.items()
56+
})
57+
3458
@abstractmethod
3559
async def link_token_to_sid(self, token: str, sid: str) -> str | None:
3660
"""Link a token to a session ID.
@@ -68,7 +92,9 @@ def create(cls) -> TokenManager:
6892

6993
async def disconnect_all(self):
7094
"""Disconnect all tracked tokens when the server is going down."""
71-
token_sid_pairs: set[tuple[str, str]] = set(self.token_to_sid.items())
95+
token_sid_pairs: set[tuple[str, str]] = {
96+
(token, sr.sid) for token, sr in self.token_to_socket.items()
97+
}
7298
token_sid_pairs.update(
7399
((token, sid) for sid, token in self.sid_to_token.items())
74100
)
@@ -95,14 +121,20 @@ async def link_token_to_sid(self, token: str, sid: str) -> str | None:
95121
New token if duplicate detected and new token generated, None otherwise.
96122
"""
97123
# Check if token is already mapped to a different SID (duplicate tab)
98-
if token in self.token_to_sid and sid != self.token_to_sid.get(token):
124+
if (
125+
socket_record := self.token_to_socket.get(token)
126+
) is not None and sid != socket_record.sid:
99127
new_token = _get_new_token()
100-
self.token_to_sid[new_token] = sid
128+
self.token_to_socket[new_token] = SocketRecord(
129+
instance_id=self.instance_id, sid=sid
130+
)
101131
self.sid_to_token[sid] = new_token
102132
return new_token
103133

104134
# Normal case - link token to SID
105-
self.token_to_sid[token] = sid
135+
self.token_to_socket[token] = SocketRecord(
136+
instance_id=self.instance_id, sid=sid
137+
)
106138
self.sid_to_token[sid] = token
107139
return None
108140

@@ -114,7 +146,7 @@ async def disconnect_token(self, token: str, sid: str) -> None:
114146
sid: The Socket.IO session ID.
115147
"""
116148
# Clean up both mappings
117-
self.token_to_sid.pop(token, None)
149+
self.token_to_socket.pop(token, None)
118150
self.sid_to_token.pop(sid, None)
119151

120152

@@ -149,9 +181,9 @@ def _get_redis_key(self, token: str) -> str:
149181
token: The client token.
150182
151183
Returns:
152-
Redis key following Reflex conventions: {token}_sid
184+
Redis key following Reflex conventions: token_manager_socket_record_{token}
153185
"""
154-
return f"{token}_sid"
186+
return f"token_manager_socket_record_{token}"
155187

156188
async def link_token_to_sid(self, token: str, sid: str) -> str | None:
157189
"""Link a token to a session ID with Redis-based duplicate detection.
@@ -164,7 +196,9 @@ async def link_token_to_sid(self, token: str, sid: str) -> str | None:
164196
New token if duplicate detected and new token generated, None otherwise.
165197
"""
166198
# Fast local check first (handles reconnections)
167-
if token in self.token_to_sid and self.token_to_sid[token] == sid:
199+
if (
200+
socket_record := self.token_to_socket.get(token)
201+
) is not None and sid == socket_record.sid:
168202
return None # Same token, same SID = reconnection, no Redis check needed
169203

170204
# Check Redis for cross-worker duplicates
@@ -176,34 +210,29 @@ async def link_token_to_sid(self, token: str, sid: str) -> str | None:
176210
console.error(f"Redis error checking token existence: {e}")
177211
return await super().link_token_to_sid(token, sid)
178212

213+
new_token = None
179214
if token_exists_in_redis:
180215
# Duplicate exists somewhere - generate new token
181-
new_token = _get_new_token()
182-
new_redis_key = self._get_redis_key(new_token)
216+
token = new_token = _get_new_token()
217+
redis_key = self._get_redis_key(new_token)
183218

184-
try:
185-
# Store in Redis
186-
await self.redis.set(new_redis_key, "1", ex=self.token_expiration)
187-
except Exception as e:
188-
console.error(f"Redis error storing new token: {e}")
189-
# Still update local dicts and continue
190-
191-
# Store in local dicts (always do this)
192-
self.token_to_sid[new_token] = sid
193-
self.sid_to_token[sid] = new_token
194-
return new_token
219+
# Store in local dicts
220+
socket_record = self.token_to_socket[token] = SocketRecord(
221+
instance_id=self.instance_id, sid=sid
222+
)
223+
self.sid_to_token[sid] = token
195224

196-
# Normal case - store in both Redis and local dicts
225+
# Store in Redis if possible
197226
try:
198-
await self.redis.set(redis_key, "1", ex=self.token_expiration)
227+
await self.redis.set(
228+
redis_key,
229+
json.dumps(dataclasses.asdict(socket_record)),
230+
ex=self.token_expiration,
231+
)
199232
except Exception as e:
200233
console.error(f"Redis error storing token: {e}")
201-
# Continue with local storage
202-
203-
# Store in local dicts (always do this)
204-
self.token_to_sid[token] = sid
205-
self.sid_to_token[sid] = token
206-
return None
234+
# Return the new token if one was generated
235+
return new_token
207236

208237
async def disconnect_token(self, token: str, sid: str) -> None:
209238
"""Clean up token mapping when client disconnects.
@@ -213,7 +242,9 @@ async def disconnect_token(self, token: str, sid: str) -> None:
213242
sid: The Socket.IO session ID.
214243
"""
215244
# Only clean up if we own it locally (fast ownership check)
216-
if self.token_to_sid.get(token) == sid:
245+
if (
246+
socket_record := self.token_to_socket.get(token)
247+
) is not None and socket_record.sid == sid:
217248
# Clean up Redis
218249
redis_key = self._get_redis_key(token)
219250
try:

tests/units/test_state.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
UnretrievableVarValueError,
5454
)
5555
from reflex.utils.format import json_dumps
56+
from reflex.utils.token_manager import SocketRecord
5657
from reflex.vars.base import Var, computed_var
5758

5859
from .states import GenState
@@ -2016,7 +2017,9 @@ async def test_state_proxy(
20162017
namespace = mock_app.event_namespace
20172018
assert namespace is not None
20182019
namespace.sid_to_token[router_data.session.session_id] = token
2019-
namespace.token_to_sid[token] = router_data.session.session_id
2020+
namespace._token_manager.token_to_socket[token] = SocketRecord(
2021+
instance_id="", sid=router_data.session.session_id
2022+
)
20202023
if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
20212024
mock_app.state_manager.states[parent_state.router.session.client_token] = (
20222025
parent_state
@@ -2227,7 +2230,9 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
22272230
namespace = mock_app.event_namespace
22282231
assert namespace is not None
22292232
namespace.sid_to_token[sid] = token
2230-
namespace.token_to_sid[token] = sid
2233+
namespace._token_manager.token_to_socket[token] = SocketRecord(
2234+
instance_id="", sid=sid
2235+
)
22312236
mock_app.state_manager.state = mock_app._state = BackgroundTaskState
22322237
async for update in rx.app.process(
22332238
mock_app,

tests/units/utils/test_token_manager.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
"""Unit tests for TokenManager implementations."""
22

3+
import json
34
from unittest.mock import AsyncMock, Mock, patch
45

56
import pytest
67

78
from reflex.utils.token_manager import (
89
LocalTokenManager,
910
RedisTokenManager,
11+
SocketRecord,
1012
TokenManager,
1113
)
1214

@@ -215,7 +217,7 @@ def test_get_redis_key(self, manager):
215217
manager: RedisTokenManager fixture instance.
216218
"""
217219
token = "test_token_123"
218-
expected_key = f"{token}_sid"
220+
expected_key = f"token_manager_socket_record_{token}"
219221

220222
assert manager._get_redis_key(token) == expected_key
221223

@@ -232,9 +234,15 @@ async def test_link_token_to_sid_normal_case(self, manager, mock_redis):
232234
result = await manager.link_token_to_sid(token, sid)
233235

234236
assert result is None
235-
mock_redis.exists.assert_called_once_with(f"{token}_sid")
236-
mock_redis.set.assert_called_once_with(f"{token}_sid", "1", ex=3600)
237-
assert manager.token_to_sid[token] == sid
237+
mock_redis.exists.assert_called_once_with(
238+
f"token_manager_socket_record_{token}"
239+
)
240+
mock_redis.set.assert_called_once_with(
241+
f"token_manager_socket_record_{token}",
242+
json.dumps({"instance_id": manager.instance_id, "sid": sid}),
243+
ex=3600,
244+
)
245+
assert manager.token_to_socket[token].sid == sid
238246
assert manager.sid_to_token[sid] == token
239247

240248
async def test_link_token_to_sid_reconnection_skips_redis(
@@ -247,7 +255,9 @@ async def test_link_token_to_sid_reconnection_skips_redis(
247255
mock_redis: Mock Redis client fixture.
248256
"""
249257
token, sid = "token1", "sid1"
250-
manager.token_to_sid[token] = sid
258+
manager.token_to_socket[token] = SocketRecord(
259+
instance_id=manager.instance_id, sid=sid
260+
)
251261

252262
result = await manager.link_token_to_sid(token, sid)
253263

@@ -271,8 +281,14 @@ async def test_link_token_to_sid_duplicate_detected(self, manager, mock_redis):
271281
assert result != token
272282
assert len(result) == 36 # UUID4 length
273283

274-
mock_redis.exists.assert_called_once_with(f"{token}_sid")
275-
mock_redis.set.assert_called_once_with(f"{result}_sid", "1", ex=3600)
284+
mock_redis.exists.assert_called_once_with(
285+
f"token_manager_socket_record_{token}"
286+
)
287+
mock_redis.set.assert_called_once_with(
288+
f"token_manager_socket_record_{result}",
289+
json.dumps({"instance_id": manager.instance_id, "sid": sid}),
290+
ex=3600,
291+
)
276292
assert manager.token_to_sid[result] == sid
277293
assert manager.sid_to_token[sid] == result
278294

@@ -323,12 +339,16 @@ async def test_disconnect_token_owned_locally(self, manager, mock_redis):
323339
mock_redis: Mock Redis client fixture.
324340
"""
325341
token, sid = "token1", "sid1"
326-
manager.token_to_sid[token] = sid
342+
manager.token_to_socket[token] = SocketRecord(
343+
instance_id=manager.instance_id, sid=sid
344+
)
327345
manager.sid_to_token[sid] = token
328346

329347
await manager.disconnect_token(token, sid)
330348

331-
mock_redis.delete.assert_called_once_with(f"{token}_sid")
349+
mock_redis.delete.assert_called_once_with(
350+
f"token_manager_socket_record_{token}"
351+
)
332352
assert token not in manager.token_to_sid
333353
assert sid not in manager.sid_to_token
334354

@@ -353,7 +373,9 @@ async def test_disconnect_token_redis_error(self, manager, mock_redis):
353373
mock_redis: Mock Redis client fixture.
354374
"""
355375
token, sid = "token1", "sid1"
356-
manager.token_to_sid[token] = sid
376+
manager.token_to_socket[token] = SocketRecord(
377+
instance_id=manager.instance_id, sid=sid
378+
)
357379
manager.sid_to_token[sid] = token
358380
mock_redis.delete.side_effect = Exception("Redis delete error")
359381

0 commit comments

Comments (0)