Skip to content

Commit 83c8fcf

Browse files
committed
Implement real redis-backed test cases for lost+found
1 parent d8f040e commit 83c8fcf

File tree

2 files changed

+213
-1
lines changed

2 files changed

+213
-1
lines changed

reflex/app.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2077,11 +2077,14 @@ async def on_connect(self, sid: str, environ: dict):
20772077
f"Frontend version {subprotocol} for session {sid} does not match the backend version {constants.Reflex.VERSION}."
20782078
)
20792079

2080-
def on_disconnect(self, sid: str):
2080+
def on_disconnect(self, sid: str) -> asyncio.Task | None:
20812081
"""Event for when the websocket disconnects.
20822082
20832083
Args:
20842084
sid: The Socket.IO session id.
2085+
2086+
Returns:
2087+
An asyncio Task for cleaning up the token, or None.
20852088
"""
20862089
# Get token before cleaning up
20872090
disconnect_token = self.sid_to_token.get(sid)
@@ -2096,6 +2099,8 @@ def on_disconnect(self, sid: str):
20962099
lambda t: t.exception()
20972100
and console.error(f"Token cleanup error: {t.exception()}")
20982101
)
2102+
return task
2103+
return None
20992104

21002105
async def emit_update(self, update: StateUpdate, token: str) -> None:
21012106
"""Emit an update to the client.

tests/units/utils/test_token_manager.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22

33
import asyncio
44
import json
5+
from collections.abc import Callable, Generator
56
from contextlib import asynccontextmanager
67
from unittest.mock import AsyncMock, Mock, patch
78

89
import pytest
910

11+
from reflex import config
12+
from reflex.app import EventNamespace
13+
from reflex.state import StateUpdate
1014
from reflex.utils.token_manager import (
1115
LocalTokenManager,
1216
RedisTokenManager,
@@ -179,6 +183,41 @@ async def test_disconnect_nonexistent_token(self, manager):
179183
assert len(manager.token_to_sid) == 0
180184
assert len(manager.sid_to_token) == 0
181185

186+
async def test_enumerate_tokens(self, manager):
187+
"""Test enumerate_tokens yields all linked tokens.
188+
189+
Args:
190+
manager: LocalTokenManager fixture instance.
191+
"""
192+
tokens_sids = [("token1", "sid1"), ("token2", "sid2"), ("token3", "sid3")]
193+
194+
for token, sid in tokens_sids:
195+
await manager.link_token_to_sid(token, sid)
196+
197+
found_tokens = set()
198+
async for token in manager.enumerate_tokens():
199+
found_tokens.add(token)
200+
201+
expected_tokens = {token for token, _ in tokens_sids}
202+
assert found_tokens == expected_tokens
203+
204+
# Disconnect a token and ensure it's removed.
205+
await manager.disconnect_token("token2", "sid2")
206+
expected_tokens.remove("token2")
207+
208+
found_tokens = set()
209+
async for token in manager.enumerate_tokens():
210+
found_tokens.add(token)
211+
212+
assert found_tokens == expected_tokens
213+
214+
# Disconnect all tokens, none should remain
215+
await manager.disconnect_all()
216+
found_tokens = set()
217+
async for token in manager.enumerate_tokens():
218+
found_tokens.add(token)
219+
assert not found_tokens
220+
182221

183222
class TestRedisTokenManager:
184223
"""Tests for RedisTokenManager."""
@@ -445,3 +484,171 @@ def test_inheritance_from_local_manager(self, manager):
445484
assert isinstance(manager, LocalTokenManager)
446485
assert hasattr(manager, "token_to_sid")
447486
assert hasattr(manager, "sid_to_token")
487+
488+
489+
@pytest.fixture
490+
def redis_url():
491+
"""Returns the Redis URL from the environment."""
492+
redis_url = config.get_config().redis_url
493+
if redis_url is None:
494+
pytest.skip("Redis URL not configured")
495+
return redis_url
496+
497+
498+
def query_string_for(token: str) -> dict[str, str]:
499+
"""Generate query string for given token.
500+
501+
Args:
502+
token: The token to generate query string for.
503+
504+
Returns:
505+
The generated query string.
506+
"""
507+
return {"QUERY_STRING": f"token={token}"}
508+
509+
510+
@pytest.fixture
511+
def event_namespace_factory() -> Generator[Callable[[], EventNamespace], None, None]:
512+
"""Yields the EventNamespace factory function."""
513+
namespace = config.get_config().get_event_namespace()
514+
created_objs = []
515+
516+
def new_event_namespace() -> EventNamespace:
517+
state = Mock()
518+
state.router_data = {}
519+
520+
mock_app = Mock()
521+
mock_app.modify_state = Mock(
522+
return_value=AsyncMock(__aenter__=AsyncMock(return_value=state))
523+
)
524+
525+
event_namespace = EventNamespace(namespace=namespace, app=mock_app)
526+
event_namespace.emit = AsyncMock()
527+
created_objs.append(event_namespace)
528+
return event_namespace
529+
530+
yield new_event_namespace
531+
532+
for obj in created_objs:
533+
asyncio.run(obj._token_manager.disconnect_all())
534+
535+
536+
@pytest.mark.usefixtures("redis_url")
537+
@pytest.mark.asyncio
538+
async def test_redis_token_manager_enumerate_tokens(
539+
event_namespace_factory: Callable[[], EventNamespace],
540+
):
541+
"""Integration test for RedisTokenManager enumerate_tokens interface.
542+
543+
Should support enumerating tokens across separate instances of the
544+
RedisTokenManager.
545+
546+
Args:
547+
event_namespace_factory: Factory fixture for EventNamespace instances.
548+
"""
549+
event_namespace1 = event_namespace_factory()
550+
event_namespace2 = event_namespace_factory()
551+
552+
await event_namespace1.on_connect(sid="sid1", environ=query_string_for("token1"))
553+
await event_namespace2.on_connect(sid="sid2", environ=query_string_for("token2"))
554+
555+
found_tokens = set()
556+
async for token in event_namespace1._token_manager.enumerate_tokens():
557+
found_tokens.add(token)
558+
559+
assert "token1" in found_tokens
560+
assert "token2" in found_tokens
561+
assert len(found_tokens) == 2
562+
563+
await event_namespace1._token_manager.disconnect_all()
564+
565+
found_tokens = set()
566+
async for token in event_namespace1._token_manager.enumerate_tokens():
567+
found_tokens.add(token)
568+
assert "token2" in found_tokens
569+
assert len(found_tokens) == 1
570+
571+
await event_namespace2._token_manager.disconnect_all()
572+
573+
found_tokens = set()
574+
async for token in event_namespace1._token_manager.enumerate_tokens():
575+
found_tokens.add(token)
576+
assert not found_tokens
577+
578+
579+
@pytest.mark.usefixtures("redis_url")
580+
@pytest.mark.asyncio
581+
async def test_redis_token_manager_get_token_owner(
582+
event_namespace_factory: Callable[[], EventNamespace],
583+
):
584+
"""Integration test for RedisTokenManager get_token_owner interface.
585+
586+
Should support retrieving the owner of a token across separate instances of the
587+
RedisTokenManager.
588+
589+
Args:
590+
event_namespace_factory: Factory fixture for EventNamespace instances.
591+
"""
592+
event_namespace1 = event_namespace_factory()
593+
event_namespace2 = event_namespace_factory()
594+
595+
await event_namespace1.on_connect(sid="sid1", environ=query_string_for("token1"))
596+
await event_namespace2.on_connect(sid="sid2", environ=query_string_for("token2"))
597+
598+
assert isinstance((manager1 := event_namespace1._token_manager), RedisTokenManager)
599+
assert isinstance((manager2 := event_namespace2._token_manager), RedisTokenManager)
600+
601+
assert await manager1._get_token_owner("token1") == manager1.instance_id
602+
assert await manager1._get_token_owner("token2") == manager2.instance_id
603+
assert await manager2._get_token_owner("token1") == manager1.instance_id
604+
assert await manager2._get_token_owner("token2") == manager2.instance_id
605+
606+
607+
@pytest.mark.usefixtures("redis_url")
608+
@pytest.mark.asyncio
609+
async def test_redis_token_manager_lost_and_found(
610+
event_namespace_factory: Callable[[], EventNamespace],
611+
):
612+
"""Updates emitted for lost and found tokens should be routed correctly via redis.
613+
614+
Args:
615+
event_namespace_factory: Factory fixture for EventNamespace instances.
616+
"""
617+
event_namespace1 = event_namespace_factory()
618+
emit1_mock: Mock = event_namespace1.emit # pyright: ignore[reportAssignmentType]
619+
event_namespace2 = event_namespace_factory()
620+
emit2_mock: Mock = event_namespace2.emit # pyright: ignore[reportAssignmentType]
621+
622+
await event_namespace1.on_connect(sid="sid1", environ=query_string_for("token1"))
623+
await event_namespace2.on_connect(sid="sid2", environ=query_string_for("token2"))
624+
625+
await event_namespace2.emit_update(StateUpdate(), token="token1")
626+
emit2_mock.assert_not_called()
627+
emit1_mock.assert_called_once()
628+
emit1_mock.reset_mock()
629+
630+
await event_namespace2.emit_update(StateUpdate(), token="token2")
631+
emit1_mock.assert_not_called()
632+
emit2_mock.assert_called_once()
633+
emit2_mock.reset_mock()
634+
635+
if task := event_namespace1.on_disconnect(sid="sid1"):
636+
await task
637+
await event_namespace2.emit_update(StateUpdate(), token="token1")
638+
# Update should be dropped on the floor.
639+
emit1_mock.assert_not_called()
640+
emit2_mock.assert_not_called()
641+
642+
await event_namespace2.on_connect(sid="sid1", environ=query_string_for("token1"))
643+
await event_namespace2.emit_update(StateUpdate(), token="token1")
644+
emit1_mock.assert_not_called()
645+
emit2_mock.assert_called_once()
646+
emit2_mock.reset_mock()
647+
648+
if task := event_namespace2.on_disconnect(sid="sid1"):
649+
await task
650+
await event_namespace1.on_connect(sid="sid1", environ=query_string_for("token1"))
651+
await event_namespace2.emit_update(StateUpdate(), token="token1")
652+
emit2_mock.assert_not_called()
653+
emit1_mock.assert_called_once()
654+
emit1_mock.reset_mock()

0 commit comments

Comments
 (0)