Skip to content

Commit 76630a6

Browse files
committed
RedisTokenManager: keep local dicts globally updated via pub/sub
1 parent e3b9787 commit 76630a6

File tree

3 files changed

+67
-1
lines changed

3 files changed

+67
-1
lines changed

reflex/istate/manager/redis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class StateManagerRedis(StateManager):
6767
# The keyspace subscription string when redis is waiting for lock to be released.
6868
_redis_notify_keyspace_events: str = dataclasses.field(
6969
default="K" # Enable keyspace notifications (target a particular key)
70+
"$" # For String commands (like setting keys)
7071
"g" # For generic commands (DEL, EXPIRE, etc)
7172
"x" # For expired events
7273
"e" # For evicted events (i.e. maxmemory exceeded)
@@ -76,7 +77,6 @@ class StateManagerRedis(StateManager):
7677
_redis_keyspace_lock_release_events: set[bytes] = dataclasses.field(
7778
default_factory=lambda: {
7879
b"del",
79-
b"expire",
8080
b"expired",
8181
b"evicted",
8282
}

reflex/utils/token_manager.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
import dataclasses
67
import json
78
import uuid
89
from abc import ABC, abstractmethod
910
from types import MappingProxyType
1011
from typing import TYPE_CHECKING
1112

13+
from reflex.istate.manager.redis import StateManagerRedis
14+
from reflex.state import BaseState
1215
from reflex.utils import console, prerequisites
1316

1417
if TYPE_CHECKING:
@@ -173,6 +176,7 @@ def __init__(self, redis: Redis):
173176

174177
config = get_config()
175178
self.token_expiration = config.redis_token_expiration
179+
self._update_task = None
176180

177181
def _get_redis_key(self, token: str) -> str:
178182
"""Get Redis key for token mapping.
@@ -185,6 +189,48 @@ def _get_redis_key(self, token: str) -> str:
185189
"""
186190
return f"token_manager_socket_record_{token}"
187191

192+
async def _socket_record_update_task(self) -> None:
193+
"""Background task to monitor Redis keyspace notifications for socket record updates."""
194+
await StateManagerRedis(
195+
state=BaseState, redis=self.redis
196+
)._enable_keyspace_notifications()
197+
redis_db = self.redis.get_connection_kwargs().get("db", 0)
198+
while True:
199+
try:
200+
await self._subscribe_socket_record_updates(redis_db)
201+
except asyncio.CancelledError: # noqa: PERF203
202+
break
203+
except Exception as e:
204+
console.error(f"RedisTokenManager socket record update task error: {e}")
205+
206+
async def _subscribe_socket_record_updates(self, redis_db: int) -> None:
207+
"""Subscribe to Redis keyspace notifications for socket record updates."""
208+
pubsub = self.redis.pubsub()
209+
await pubsub.psubscribe(
210+
f"__keyspace@{redis_db}__:token_manager_socket_record_*"
211+
)
212+
213+
async for message in pubsub.listen():
214+
if message["type"] == "pmessage":
215+
key = message["channel"].split(b":", 1)[1].decode()
216+
event = message["data"].decode()
217+
token = key.replace("token_manager_socket_record_", "")
218+
219+
if event in ("del", "expired", "evicted"):
220+
# Remove from local dicts if exists
221+
if (
222+
socket_record := self.token_to_socket.pop(token, None)
223+
) is not None:
224+
self.sid_to_token.pop(socket_record.sid, None)
225+
elif event == "set":
226+
# Fetch updated record from Redis
227+
record_json = await self.redis.get(key)
228+
if record_json:
229+
record_data = json.loads(record_json)
230+
socket_record = SocketRecord(**record_data)
231+
self.token_to_socket[token] = socket_record
232+
self.sid_to_token[socket_record.sid] = token
233+
188234
async def link_token_to_sid(self, token: str, sid: str) -> str | None:
189235
"""Link a token to a session ID with Redis-based duplicate detection.
190236
@@ -201,6 +247,10 @@ async def link_token_to_sid(self, token: str, sid: str) -> str | None:
201247
) is not None and sid == socket_record.sid:
202248
return None # Same token, same SID = reconnection, no Redis check needed
203249

250+
# Make sure the update subscriber is running
251+
if self._update_task is None or self._update_task.done():
252+
self._update_task = asyncio.create_task(self._socket_record_update_task())
253+
204254
# Check Redis for cross-worker duplicates
205255
redis_key = self._get_redis_key(token)
206256

tests/units/utils/test_token_manager.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Unit tests for TokenManager implementations."""
22

3+
import asyncio
34
import json
45
from unittest.mock import AsyncMock, Mock, patch
56

@@ -63,6 +64,7 @@ def test_create_redis_when_redis_available(
6364
"""
6465
mock_check_redis_used.return_value = True
6566
mock_redis_client = Mock()
67+
mock_redis_client.get_connection_kwargs.return_value = {"db": 0}
6668
mock_get_redis.return_value = mock_redis_client
6769

6870
manager = TokenManager.create()
@@ -191,6 +193,20 @@ def mock_redis(self):
191193
redis.exists = AsyncMock()
192194
redis.set = AsyncMock()
193195
redis.delete = AsyncMock()
196+
197+
# Non-async call
198+
redis.get_connection_kwargs = Mock(return_value={"db": 0})
199+
200+
# Mock out pubsub
201+
async def listen():
202+
await asyncio.sleep(1)
203+
if False:
204+
yield
205+
return
206+
207+
psubscribe = AsyncMock()
208+
psubscribe.listen = listen
209+
redis.pubsub = Mock(return_value=psubscribe)
194210
return redis
195211

196212
@pytest.fixture

0 commit comments

Comments
 (0)