|
7 | 7 | import json |
8 | 8 | import uuid |
9 | 9 | from abc import ABC, abstractmethod |
10 | | -from collections.abc import Callable, Coroutine |
| 10 | +from collections.abc import AsyncIterator, Callable, Coroutine |
11 | 11 | from types import MappingProxyType |
12 | | -from typing import TYPE_CHECKING, Any |
| 12 | +from typing import TYPE_CHECKING, Any, ClassVar |
13 | 13 |
|
14 | 14 | from reflex.istate.manager.redis import StateManagerRedis |
15 | 15 | from reflex.state import BaseState, StateUpdate |
@@ -67,6 +67,15 @@ def token_to_sid(self) -> MappingProxyType[str, str]: |
67 | 67 | token: sr.sid for token, sr in self.token_to_socket.items() |
68 | 68 | }) |
69 | 69 |
|
| 70 | + async def enumerate_tokens(self) -> AsyncIterator[str]: |
| 71 | + """Iterate over all tokens in the system. |
| 72 | +
|
| 73 | + Yields: |
| 74 | + All client tokens known to the TokenManager. |
| 75 | + """ |
| 76 | + for token in self.token_to_socket: |
| 77 | + yield token |
| 78 | + |
70 | 79 | @abstractmethod |
71 | 80 | async def link_token_to_sid(self, token: str, sid: str) -> str | None: |
72 | 81 | """Link a token to a session ID. |
@@ -169,6 +178,8 @@ class RedisTokenManager(LocalTokenManager): |
169 | 178 | for cross-worker duplicate detection. |
170 | 179 | """ |
171 | 180 |
|
| 181 | + _token_socket_record_prefix: ClassVar[str] = "token_manager_socket_record_" |
| 182 | + |
172 | 183 | def __init__(self, redis: Redis): |
173 | 184 | """Initialize the Redis token manager. |
174 | 185 |
|
@@ -199,7 +210,23 @@ def _get_redis_key(self, token: str) -> str: |
199 | 210 | Returns: |
200 | 211 | Redis key following Reflex conventions: token_manager_socket_record_{token} |
201 | 212 | """ |
202 | | - return f"token_manager_socket_record_{token}" |
| 213 | + return f"{self._token_socket_record_prefix}{token}" |
| 214 | + |
| 215 | + async def enumerate_tokens(self) -> AsyncIterator[str]: |
| 216 | + """Iterate over all tokens in the system. |
| 217 | +
|
| 218 | + Yields: |
| 219 | + All client tokens known to the RedisTokenManager. |
| 220 | + """ |
| 221 | + cursor = 0 |
| 222 | + while scan_result := await self.redis.scan( |
| 223 | + cursor=cursor, match=self._get_redis_key("*") |
| 224 | + ): |
| 225 | + cursor = int(scan_result[0]) |
| 226 | + for key in scan_result[1]: |
| 227 | + yield key.decode().replace(self._token_socket_record_prefix, "") |
| 228 | + if not cursor: |
| 229 | + break |
203 | 230 |
|
204 | 231 | def _handle_socket_record_del(self, token: str) -> None: |
205 | 232 | """Handle deletion of a socket record from Redis. |
@@ -230,12 +257,12 @@ async def _subscribe_socket_record_updates(self, redis_db: int) -> None: |
230 | 257 | """Subscribe to Redis keyspace notifications for socket record updates.""" |
231 | 258 | async with self.redis.pubsub() as pubsub: |
232 | 259 | await pubsub.psubscribe( |
233 | | - f"__keyspace@{redis_db}__:token_manager_socket_record_*" |
| 260 | + f"__keyspace@{redis_db}__:{self._get_redis_key('*')}" |
234 | 261 | ) |
235 | 262 | async for message in pubsub.listen(): |
236 | 263 | if message["type"] == "pmessage": |
237 | 264 | key = message["channel"].split(b":", 1)[1].decode() |
238 | | - token = key.replace("token_manager_socket_record_", "") |
| 265 | + token = key.replace(self._token_socket_record_prefix, "") |
239 | 266 |
|
240 | 267 | if token not in self.token_to_socket: |
241 | 268 | # We don't know about this token, skip |
|
0 commit comments