Skip to content

Commit 096f6ac

Browse files
committed
Implement enumerate_tokens for TokenManager
Set the groundwork for being able to broadcast updates to all connected states.
1 parent e71eece commit 096f6ac

File tree

1 file changed

+32
-5
lines changed

1 file changed

+32
-5
lines changed

reflex/utils/token_manager.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
import json
88
import uuid
99
from abc import ABC, abstractmethod
10-
from collections.abc import Callable, Coroutine
10+
from collections.abc import AsyncIterator, Callable, Coroutine
1111
from types import MappingProxyType
12-
from typing import TYPE_CHECKING, Any
12+
from typing import TYPE_CHECKING, Any, ClassVar
1313

1414
from reflex.istate.manager.redis import StateManagerRedis
1515
from reflex.state import BaseState, StateUpdate
@@ -67,6 +67,15 @@ def token_to_sid(self) -> MappingProxyType[str, str]:
6767
token: sr.sid for token, sr in self.token_to_socket.items()
6868
})
6969

70+
async def enumerate_tokens(self) -> AsyncIterator[str]:
71+
"""Iterate over all tokens in the system.
72+
73+
Yields:
74+
All client tokens known to the TokenManager.
75+
"""
76+
for token in self.token_to_socket:
77+
yield token
78+
7079
@abstractmethod
7180
async def link_token_to_sid(self, token: str, sid: str) -> str | None:
7281
"""Link a token to a session ID.
@@ -169,6 +178,8 @@ class RedisTokenManager(LocalTokenManager):
169178
for cross-worker duplicate detection.
170179
"""
171180

181+
_token_socket_record_prefix: ClassVar[str] = "token_manager_socket_record_"
182+
172183
def __init__(self, redis: Redis):
173184
"""Initialize the Redis token manager.
174185
@@ -199,7 +210,23 @@ def _get_redis_key(self, token: str) -> str:
199210
Returns:
200211
Redis key following Reflex conventions: token_manager_socket_record_{token}
201212
"""
202-
return f"token_manager_socket_record_{token}"
213+
return f"{self._token_socket_record_prefix}{token}"
214+
215+
async def enumerate_tokens(self) -> AsyncIterator[str]:
216+
"""Iterate over all tokens in the system.
217+
218+
Yields:
219+
All client tokens known to the RedisTokenManager.
220+
"""
221+
cursor = 0
222+
while scan_result := await self.redis.scan(
223+
cursor=cursor, match=self._get_redis_key("*")
224+
):
225+
cursor = int(scan_result[0])
226+
for key in scan_result[1]:
227+
yield key.decode().replace(self._token_socket_record_prefix, "")
228+
if not cursor:
229+
break
203230

204231
def _handle_socket_record_del(self, token: str) -> None:
205232
"""Handle deletion of a socket record from Redis.
@@ -230,12 +257,12 @@ async def _subscribe_socket_record_updates(self, redis_db: int) -> None:
230257
"""Subscribe to Redis keyspace notifications for socket record updates."""
231258
async with self.redis.pubsub() as pubsub:
232259
await pubsub.psubscribe(
233-
f"__keyspace@{redis_db}__:token_manager_socket_record_*"
260+
f"__keyspace@{redis_db}__:{self._get_redis_key('*')}"
234261
)
235262
async for message in pubsub.listen():
236263
if message["type"] == "pmessage":
237264
key = message["channel"].split(b":", 1)[1].decode()
238-
token = key.replace("token_manager_socket_record_", "")
265+
token = key.replace(self._token_socket_record_prefix, "")
239266

240267
if token not in self.token_to_socket:
241268
# We don't know about this token, skip

0 commit comments

Comments
 (0)