Skip to content

Commit 0976ece

Browse files
authored
ENG-8089: Implement Lost+Found with RedisTokenManager (#5927)
* Token manager tracks instance_id in token_to_socket * RedisTokenManager: keep local dicts globally updated via pub/sub * Implement lost+found for StateUpdate without websocket When an update is emitted for a token, but the websocket for that token is on another instance of the app, post it to the lost+found channel where other instances are listening for updates to send to their clients. * Implement `enumerate_tokens` for TokenManager Set the groundwork for being able to broadcast updates to all connected states. * Consolidate on `_get_token_owner` * fix test_connection_banner.py: expect SocketRecord JSON * Implement real redis-backed test cases for lost+found * add some polling for the emit mocks since L+F doesn't happen immediately
1 parent c4254ed commit 0976ece

File tree

6 files changed

+614
-62
lines changed

6 files changed

+614
-62
lines changed

reflex/app.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@
120120
)
121121
from reflex.utils.imports import ImportVar
122122
from reflex.utils.misc import run_in_thread
123-
from reflex.utils.token_manager import TokenManager
123+
from reflex.utils.token_manager import RedisTokenManager, TokenManager
124124
from reflex.utils.types import ASGIApp, Message, Receive, Scope, Send
125125

126126
if TYPE_CHECKING:
@@ -2033,11 +2033,13 @@ def __init__(self, namespace: str, app: App):
20332033
self._token_manager = TokenManager.create()
20342034

20352035
@property
2036-
def token_to_sid(self) -> dict[str, str]:
2036+
def token_to_sid(self) -> Mapping[str, str]:
20372037
"""Get token to SID mapping for backward compatibility.
20382038
2039+
Note: this mapping is read-only.
2040+
20392041
Returns:
2040-
The token to SID mapping dict.
2042+
The token to SID mapping.
20412043
"""
20422044
# For backward compatibility, expose the underlying dict
20432045
return self._token_manager.token_to_sid
@@ -2059,6 +2061,9 @@ async def on_connect(self, sid: str, environ: dict):
20592061
sid: The Socket.IO session id.
20602062
environ: The request information, including HTTP headers.
20612063
"""
2064+
if isinstance(self._token_manager, RedisTokenManager):
2065+
# Make sure this instance is watching for updates from other instances.
2066+
self._token_manager.ensure_lost_and_found_task(self.emit_update)
20622067
query_params = urllib.parse.parse_qs(environ.get("QUERY_STRING", ""))
20632068
token_list = query_params.get("token", [])
20642069
if token_list:
@@ -2072,11 +2077,14 @@ async def on_connect(self, sid: str, environ: dict):
20722077
f"Frontend version {subprotocol} for session {sid} does not match the backend version {constants.Reflex.VERSION}."
20732078
)
20742079

2075-
def on_disconnect(self, sid: str):
2080+
def on_disconnect(self, sid: str) -> asyncio.Task | None:
20762081
"""Event for when the websocket disconnects.
20772082
20782083
Args:
20792084
sid: The Socket.IO session id.
2085+
2086+
Returns:
2087+
An asyncio Task for cleaning up the token, or None.
20802088
"""
20812089
# Get token before cleaning up
20822090
disconnect_token = self.sid_to_token.get(sid)
@@ -2091,6 +2099,8 @@ def on_disconnect(self, sid: str):
20912099
lambda t: t.exception()
20922100
and console.error(f"Token cleanup error: {t.exception()}")
20932101
)
2102+
return task
2103+
return None
20942104

20952105
async def emit_update(self, update: StateUpdate, token: str) -> None:
20962106
"""Emit an update to the client.
@@ -2100,16 +2110,30 @@ async def emit_update(self, update: StateUpdate, token: str) -> None:
21002110
token: The client token (tab) associated with the event.
21012111
"""
21022112
client_token, _ = _split_substate_key(token)
2103-
sid = self.token_to_sid.get(client_token)
2104-
if sid is None:
2105-
# If the sid is None, we are not connected to a client. Prevent sending
2106-
# updates to all clients.
2107-
console.warn(f"Attempting to send delta to disconnected client {token!r}")
2113+
socket_record = self._token_manager.token_to_socket.get(client_token)
2114+
if (
2115+
socket_record is None
2116+
or socket_record.instance_id != self._token_manager.instance_id
2117+
):
2118+
if isinstance(self._token_manager, RedisTokenManager):
2119+
# The socket belongs to another instance of the app, send it to the lost and found.
2120+
if not await self._token_manager.emit_lost_and_found(
2121+
client_token, update
2122+
):
2123+
console.warn(
2124+
f"Failed to send delta to lost and found for client {token!r}"
2125+
)
2126+
else:
2127+
# If the socket record is None, we are not connected to a client. Prevent sending
2128+
# updates to all clients.
2129+
console.warn(
2130+
f"Attempting to send delta to disconnected client {token!r}"
2131+
)
21082132
return
21092133
# Creating a task prevents the update from being blocked behind other coroutines.
21102134
await asyncio.create_task(
2111-
self.emit(str(constants.SocketEvent.EVENT), update, to=sid),
2112-
name=f"reflex_emit_event|{token}|{sid}|{time.time()}",
2135+
self.emit(str(constants.SocketEvent.EVENT), update, to=socket_record.sid),
2136+
name=f"reflex_emit_event|{token}|{socket_record.sid}|{time.time()}",
21132137
)
21142138

21152139
async def on_event(self, sid: str, data: Any):

reflex/istate/manager/redis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class StateManagerRedis(StateManager):
6767
# The keyspace subscription string when redis is waiting for lock to be released.
6868
_redis_notify_keyspace_events: str = dataclasses.field(
6969
default="K" # Enable keyspace notifications (target a particular key)
70+
"$" # For String commands (like setting keys)
7071
"g" # For generic commands (DEL, EXPIRE, etc)
7172
"x" # For expired events
7273
"e" # For evicted events (i.e. maxmemory exceeded)
@@ -76,7 +77,6 @@ class StateManagerRedis(StateManager):
7677
_redis_keyspace_lock_release_events: set[bytes] = dataclasses.field(
7778
default_factory=lambda: {
7879
b"del",
79-
b"expire",
8080
b"expired",
8181
b"evicted",
8282
}

0 commit comments

Comments
 (0)