Skip to content

Commit e71eece

Browse files
committed
Implement lost+found for StateUpdate without websocket
When an update is emitted for a token whose websocket is connected to another instance of the app, publish it to that instance's lost+found channel, where the owning instance listens for updates to forward to its own clients.
1 parent 76630a6 commit e71eece

File tree

4 files changed

+231
-48
lines changed

4 files changed

+231
-48
lines changed

reflex/app.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@
120120
)
121121
from reflex.utils.imports import ImportVar
122122
from reflex.utils.misc import run_in_thread
123-
from reflex.utils.token_manager import TokenManager
123+
from reflex.utils.token_manager import RedisTokenManager, TokenManager
124124
from reflex.utils.types import ASGIApp, Message, Receive, Scope, Send
125125

126126
if TYPE_CHECKING:
@@ -2061,6 +2061,9 @@ async def on_connect(self, sid: str, environ: dict):
20612061
sid: The Socket.IO session id.
20622062
environ: The request information, including HTTP headers.
20632063
"""
2064+
if isinstance(self._token_manager, RedisTokenManager):
2065+
# Make sure this instance is watching for updates from other instances.
2066+
self._token_manager.ensure_lost_and_found_task(self.emit_update)
20642067
query_params = urllib.parse.parse_qs(environ.get("QUERY_STRING", ""))
20652068
token_list = query_params.get("token", [])
20662069
if token_list:
@@ -2102,16 +2105,30 @@ async def emit_update(self, update: StateUpdate, token: str) -> None:
21022105
token: The client token (tab) associated with the event.
21032106
"""
21042107
client_token, _ = _split_substate_key(token)
2105-
sid = self.token_to_sid.get(client_token)
2106-
if sid is None:
2107-
# If the sid is None, we are not connected to a client. Prevent sending
2108-
# updates to all clients.
2109-
console.warn(f"Attempting to send delta to disconnected client {token!r}")
2108+
socket_record = self._token_manager.token_to_socket.get(client_token)
2109+
if (
2110+
socket_record is None
2111+
or socket_record.instance_id != self._token_manager.instance_id
2112+
):
2113+
if isinstance(self._token_manager, RedisTokenManager):
2114+
# The socket belongs to another instance of the app, send it to the lost and found.
2115+
if not await self._token_manager.emit_lost_and_found(
2116+
client_token, update
2117+
):
2118+
console.warn(
2119+
f"Failed to send delta to lost and found for client {token!r}"
2120+
)
2121+
else:
2122+
# If the socket record is None, we are not connected to a client. Prevent sending
2123+
# updates to all clients.
2124+
console.warn(
2125+
f"Attempting to send delta to disconnected client {token!r}"
2126+
)
21102127
return
21112128
# Creating a task prevents the update from being blocked behind other coroutines.
21122129
await asyncio.create_task(
2113-
self.emit(str(constants.SocketEvent.EVENT), update, to=sid),
2114-
name=f"reflex_emit_event|{token}|{sid}|{time.time()}",
2130+
self.emit(str(constants.SocketEvent.EVENT), update, to=socket_record.sid),
2131+
name=f"reflex_emit_event|{token}|{socket_record.sid}|{time.time()}",
21152132
)
21162133

21172134
async def on_event(self, sid: str, data: Any):

reflex/utils/token_manager.py

Lines changed: 194 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
import json
88
import uuid
99
from abc import ABC, abstractmethod
10+
from collections.abc import Callable, Coroutine
1011
from types import MappingProxyType
11-
from typing import TYPE_CHECKING
12+
from typing import TYPE_CHECKING, Any
1213

1314
from reflex.istate.manager.redis import StateManagerRedis
14-
from reflex.state import BaseState
15+
from reflex.state import BaseState, StateUpdate
1516
from reflex.utils import console, prerequisites
1617

1718
if TYPE_CHECKING:
@@ -35,6 +36,14 @@ class SocketRecord:
3536
sid: str
3637

3738

39+
@dataclasses.dataclass(frozen=True, kw_only=True)
40+
class LostAndFoundRecord:
41+
"""Record for a StateUpdate for a token with its socket on another instance."""
42+
43+
token: str
44+
update: dict[str, Any]
45+
46+
3847
class TokenManager(ABC):
3948
"""Abstract base class for managing client token to session ID mappings."""
4049

@@ -176,7 +185,10 @@ def __init__(self, redis: Redis):
176185

177186
config = get_config()
178187
self.token_expiration = config.redis_token_expiration
179-
self._update_task = None
188+
189+
# Pub/sub tasks for handling sockets owned by other instances.
190+
self._socket_record_task: asyncio.Task | None = None
191+
self._lost_and_found_task: asyncio.Task | None = None
180192

181193
def _get_redis_key(self, token: str) -> str:
182194
"""Get Redis key for token mapping.
@@ -189,7 +201,53 @@ def _get_redis_key(self, token: str) -> str:
189201
"""
190202
return f"token_manager_socket_record_{token}"
191203

192-
async def _socket_record_update_task(self) -> None:
204+
def _handle_socket_record_del(self, token: str) -> None:
    """Drop the local copy of a socket record that was removed from Redis.

    The token is always removed from ``token_to_socket``. The reverse
    ``sid_to_token`` entry is only removed when the record belonged to a
    different instance than this one.

    Args:
        token: The client token whose record was deleted.
    """
    removed = self.token_to_socket.pop(token, None)
    if removed is None:
        # Nothing was tracked locally for this token.
        return
    if removed.instance_id == self.instance_id:
        # Records owned by this instance keep their sid mapping.
        return
    self.sid_to_token.pop(removed.sid, None)
214+
215+
async def _handle_socket_record_set(self, token: str) -> None:
    """Refresh the local copy of a socket record that was written to Redis.

    Reads the record stored under the token's Redis key and, when a value is
    present, stores it in both ``token_to_socket`` and ``sid_to_token``.

    Args:
        token: The client token whose record was set/updated.
    """
    raw_record = await self.redis.get(self._get_redis_key(token))
    if not raw_record:
        # The key holds no value, so there is nothing to mirror locally.
        return
    refreshed = SocketRecord(**json.loads(raw_record))
    self.token_to_socket[token] = refreshed
    self.sid_to_token[refreshed.sid] = token
228+
229+
async def _subscribe_socket_record_updates(self, redis_db: int) -> None:
    """Subscribe to Redis keyspace notifications for socket record updates.

    Listens on the keyspace pattern for socket record keys in ``redis_db`` and
    dispatches each notification for a locally known token: ``del``,
    ``expired`` and ``evicted`` go to ``_handle_socket_record_del``; ``set``
    goes to ``_handle_socket_record_set``. Notifications for tokens that are
    not in ``token_to_socket`` and any other event names are ignored.

    Args:
        redis_db: The Redis database number used in the keyspace channel name.
    """
    key_prefix = "token_manager_socket_record_"
    async with self.redis.pubsub() as pubsub:
        await pubsub.psubscribe(f"__keyspace@{redis_db}__:{key_prefix}*")
        async for message in pubsub.listen():
            if message["type"] != "pmessage":
                continue

            # The channel looks like b"__keyspace@<db>__:<key>"; split once so
            # colons inside the key itself are preserved.
            key = message["channel"].split(b":", 1)[1].decode()
            # Strip only the leading prefix: a token may itself contain the
            # prefix text, and removing every occurrence would corrupt it.
            token = key.removeprefix(key_prefix)

            if token not in self.token_to_socket:
                # We don't know about this token, skip
                continue

            event = message["data"].decode()
            if event in ("del", "expired", "evicted"):
                self._handle_socket_record_del(token)
            elif event == "set":
                await self._handle_socket_record_set(token)
249+
250+
async def _socket_record_updates_forever(self) -> None:
193251
"""Background task to monitor Redis keyspace notifications for socket record updates."""
194252
await StateManagerRedis(
195253
state=BaseState, redis=self.redis
@@ -203,33 +261,12 @@ async def _socket_record_update_task(self) -> None:
203261
except Exception as e:
204262
console.error(f"RedisTokenManager socket record update task error: {e}")
205263

206-
async def _subscribe_socket_record_updates(self, redis_db: int) -> None:
207-
"""Subscribe to Redis keyspace notifications for socket record updates."""
208-
pubsub = self.redis.pubsub()
209-
await pubsub.psubscribe(
210-
f"__keyspace@{redis_db}__:token_manager_socket_record_*"
211-
)
212-
213-
async for message in pubsub.listen():
214-
if message["type"] == "pmessage":
215-
key = message["channel"].split(b":", 1)[1].decode()
216-
event = message["data"].decode()
217-
token = key.replace("token_manager_socket_record_", "")
218-
219-
if event in ("del", "expired", "evicted"):
220-
# Remove from local dicts if exists
221-
if (
222-
socket_record := self.token_to_socket.pop(token, None)
223-
) is not None:
224-
self.sid_to_token.pop(socket_record.sid, None)
225-
elif event == "set":
226-
# Fetch updated record from Redis
227-
record_json = await self.redis.get(key)
228-
if record_json:
229-
record_data = json.loads(record_json)
230-
socket_record = SocketRecord(**record_data)
231-
self.token_to_socket[token] = socket_record
232-
self.sid_to_token[socket_record.sid] = token
264+
def _ensure_socket_record_task(self) -> None:
    """Start the socket record subscriber task unless one is still running.

    A new task is created when no task has been stored yet or when the
    stored task has already finished.
    """
    current = self._socket_record_task
    if current is not None and not current.done():
        # A subscriber is already active; leave it alone.
        return
    self._socket_record_task = asyncio.create_task(
        self._socket_record_updates_forever()
    )
233270

234271
async def link_token_to_sid(self, token: str, sid: str) -> str | None:
235272
"""Link a token to a session ID with Redis-based duplicate detection.
@@ -248,8 +285,7 @@ async def link_token_to_sid(self, token: str, sid: str) -> str | None:
248285
return None # Same token, same SID = reconnection, no Redis check needed
249286

250287
# Make sure the update subscriber is running
251-
if self._update_task is None or self._update_task.done():
252-
self._update_task = asyncio.create_task(self._socket_record_update_task())
288+
self._ensure_socket_record_task()
253289

254290
# Check Redis for cross-worker duplicates
255291
redis_key = self._get_redis_key(token)
@@ -293,8 +329,10 @@ async def disconnect_token(self, token: str, sid: str) -> None:
293329
"""
294330
# Only clean up if we own it locally (fast ownership check)
295331
if (
296-
socket_record := self.token_to_socket.get(token)
297-
) is not None and socket_record.sid == sid:
332+
(socket_record := self.token_to_socket.get(token)) is not None
333+
and socket_record.sid == sid
334+
and socket_record.instance_id == self.instance_id
335+
):
298336
# Clean up Redis
299337
redis_key = self._get_redis_key(token)
300338
try:
@@ -304,3 +342,124 @@ async def disconnect_token(self, token: str, sid: str) -> None:
304342

305343
# Clean up local dicts (always do this)
306344
await super().disconnect_token(token, sid)
345+
346+
@staticmethod
def _get_lost_and_found_key(instance_id: str) -> str:
    """Build the Redis key used for an instance's lost and found deltas.

    Args:
        instance_id: The instance ID.

    Returns:
        The Redis key for lost and found deltas.
    """
    key_prefix = "token_manager_lost_and_found_"
    return f"{key_prefix}{instance_id}"
357+
358+
async def _subscribe_lost_and_found_updates(
    self,
    emit_update: Callable[[StateUpdate, str], Coroutine[None, None, None]],
) -> None:
    """Subscribe to Redis channel notifications for lost and found deltas.

    Listens on this instance's lost and found channel, rebuilds each
    published ``LostAndFoundRecord`` and passes the contained update and
    token to ``emit_update``. A failure while decoding or emitting one
    message is logged and the subscription keeps running.

    Args:
        emit_update: The function to emit state updates.
    """
    channel = f"channel:{self._get_lost_and_found_key(self.instance_id)}"
    async with self.redis.pubsub() as pubsub:
        # NOTE(review): psubscribe treats the channel name as a glob pattern;
        # this presumes instance IDs contain no glob metacharacters.
        await pubsub.psubscribe(channel)
        async for message in pubsub.listen():
            if message["type"] != "pmessage":
                continue
            try:
                # json.loads accepts both bytes and str payloads.
                record = LostAndFoundRecord(**json.loads(message["data"]))
                await emit_update(StateUpdate(**record.update), record.token)
            except Exception as e:
                # Isolate per-message failures: tearing down the subscription
                # would drop any messages published while resubscribing.
                console.error(
                    f"RedisTokenManager lost and found message error: {e}"
                )
375+
376+
async def _lost_and_found_updates_forever(
    self,
    emit_update: Callable[[StateUpdate, str], Coroutine[None, None, None]],
    *,
    retry_delay: float = 1.0,
) -> None:
    """Background task to monitor Redis lost and found deltas.

    Runs the lost and found subscriber in a loop, logging any error and
    resubscribing after ``retry_delay`` seconds. The loop ends quietly when
    the task is cancelled.

    Args:
        emit_update: The function to emit state updates.
        retry_delay: Seconds to wait before resubscribing after the
            subscriber stops or fails.
    """
    try:
        while True:
            try:
                await self._subscribe_lost_and_found_updates(emit_update)
            except Exception as e:  # noqa: PERF203
                # CancelledError derives from BaseException, so it is not
                # swallowed here and still reaches the outer handler.
                console.error(f"RedisTokenManager lost and found task error: {e}")
            # Back off before resubscribing so a persistent Redis failure
            # (or a stream that ends immediately) cannot become a hot loop.
            await asyncio.sleep(retry_delay)
    except asyncio.CancelledError:
        # Cancellation is the normal way to stop this task.
        return
392+
393+
def ensure_lost_and_found_task(
    self,
    emit_update: Callable[[StateUpdate, str], Coroutine[None, None, None]],
) -> None:
    """Start the lost and found subscriber task unless one is still running.

    A new task is created when no task has been stored yet or when the
    stored task has already finished.

    Args:
        emit_update: The function to emit state updates.
    """
    current = self._lost_and_found_task
    if current is not None and not current.done():
        # A subscriber is already active; leave it alone.
        return
    self._lost_and_found_task = asyncio.create_task(
        self._lost_and_found_updates_forever(emit_update)
    )
406+
407+
async def _get_token_owner(self, token: str, refresh: bool = False) -> str | None:
408+
"""Get the instance ID of the owner of a token.
409+
410+
Args:
411+
token: The client token.
412+
refresh: Whether to fetch the latest record from Redis.
413+
414+
Returns:
415+
The instance ID of the owner, or None if not found.
416+
"""
417+
if (
418+
not refresh
419+
and (socket_record := self.token_to_socket.get(token)) is not None
420+
):
421+
return socket_record.instance_id
422+
423+
redis_key = self._get_redis_key(token)
424+
try:
425+
record_json = await self.redis.get(redis_key)
426+
if record_json:
427+
record_data = json.loads(record_json)
428+
socket_record = SocketRecord(**record_data)
429+
self.token_to_socket[token] = socket_record
430+
self.sid_to_token[socket_record.sid] = token
431+
return socket_record.instance_id
432+
console.error(f"Redis token owner not found for token {token}")
433+
except Exception as e:
434+
console.error(f"Redis error getting token owner: {e}")
435+
return None
436+
437+
async def emit_lost_and_found(
    self,
    token: str,
    update: StateUpdate,
) -> bool:
    """Publish a state update to the lost and found channel of its owner.

    The owning instance is resolved via ``_get_token_owner``; the update is
    wrapped in a ``LostAndFoundRecord``, serialized to JSON and published on
    that instance's channel.

    Args:
        token: The client token.
        update: The state update.

    Returns:
        True if the delta was published, False otherwise.
    """
    # See where this update belongs
    owner = await self._get_token_owner(token)
    if owner is None:
        return False

    record = LostAndFoundRecord(token=token, update=dataclasses.asdict(update))
    try:
        channel = f"channel:{self._get_lost_and_found_key(owner)}"
        payload = json.dumps(dataclasses.asdict(record))
        await self.redis.publish(channel, payload)
    except Exception as e:
        console.error(f"Redis error publishing lost and found delta: {e}")
        return False
    return True

tests/units/test_state.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2017,8 +2017,9 @@ async def test_state_proxy(
20172017
namespace = mock_app.event_namespace
20182018
assert namespace is not None
20192019
namespace.sid_to_token[router_data.session.session_id] = token
2020+
namespace._token_manager.instance_id = "mock"
20202021
namespace._token_manager.token_to_socket[token] = SocketRecord(
2021-
instance_id="", sid=router_data.session.session_id
2022+
instance_id="mock", sid=router_data.session.session_id
20222023
)
20232024
if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
20242025
mock_app.state_manager.states[parent_state.router.session.client_token] = (
@@ -2230,8 +2231,9 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
22302231
namespace = mock_app.event_namespace
22312232
assert namespace is not None
22322233
namespace.sid_to_token[sid] = token
2234+
namespace._token_manager.instance_id = "mock"
22332235
namespace._token_manager.token_to_socket[token] = SocketRecord(
2234-
instance_id="", sid=sid
2236+
instance_id="mock", sid=sid
22352237
)
22362238
mock_app.state_manager.state = mock_app._state = BackgroundTaskState
22372239
async for update in rx.app.process(

tests/units/utils/test_token_manager.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import json
5+
from contextlib import asynccontextmanager
56
from unittest.mock import AsyncMock, Mock, patch
67

78
import pytest
@@ -204,9 +205,13 @@ async def listen():
204205
yield
205206
return
206207

207-
psubscribe = AsyncMock()
208-
psubscribe.listen = listen
209-
redis.pubsub = Mock(return_value=psubscribe)
208+
@asynccontextmanager
209+
async def pubsub(): # noqa: RUF029
210+
pubsub_mock = AsyncMock()
211+
pubsub_mock.listen = listen
212+
yield pubsub_mock
213+
214+
redis.pubsub = pubsub
210215
return redis
211216

212217
@pytest.fixture

0 commit comments

Comments
 (0)