 
 import asyncio
 import json
+from collections.abc import Callable, Generator
 from contextlib import asynccontextmanager
 from unittest.mock import AsyncMock, Mock, patch
 
 import pytest
 
+from reflex import config
+from reflex.app import EventNamespace
+from reflex.state import StateUpdate
 from reflex.utils.token_manager import (
     LocalTokenManager,
     RedisTokenManager,
@@ -179,6 +183,41 @@ async def test_disconnect_nonexistent_token(self, manager):
         assert len(manager.token_to_sid) == 0
         assert len(manager.sid_to_token) == 0
 
+    async def test_enumerate_tokens(self, manager):
+        """Test enumerate_tokens yields all linked tokens.
+
+        Args:
+            manager: LocalTokenManager fixture instance.
+        """
+        tokens_sids = [("token1", "sid1"), ("token2", "sid2"), ("token3", "sid3")]
+
+        for token, sid in tokens_sids:
+            await manager.link_token_to_sid(token, sid)
+
+        found_tokens = set()
+        async for token in manager.enumerate_tokens():
+            found_tokens.add(token)
+
+        expected_tokens = {token for token, _ in tokens_sids}
+        assert found_tokens == expected_tokens
+
+        # Disconnect a token and ensure it's removed.
+        await manager.disconnect_token("token2", "sid2")
+        expected_tokens.remove("token2")
+
+        found_tokens = set()
+        async for token in manager.enumerate_tokens():
+            found_tokens.add(token)
+
+        assert found_tokens == expected_tokens
+
+        # Disconnect all tokens, none should remain
+        await manager.disconnect_all()
+        found_tokens = set()
+        async for token in manager.enumerate_tokens():
+            found_tokens.add(token)
+        assert not found_tokens
+
 
 class TestRedisTokenManager:
     """Tests for RedisTokenManager."""
@@ -445,3 +484,171 @@ def test_inheritance_from_local_manager(self, manager):
         assert isinstance(manager, LocalTokenManager)
         assert hasattr(manager, "token_to_sid")
         assert hasattr(manager, "sid_to_token")
+
+
+@pytest.fixture
+def redis_url():
+    """Returns the Redis URL from the app config, skipping the test if it is not set."""
+    redis_url = config.get_config().redis_url
+    if redis_url is None:
+        pytest.skip("Redis URL not configured")
+    return redis_url
+
+
+def query_string_for(token: str) -> dict[str, str]:
+    """Build a WSGI-style environ dict carrying the given token in its query string.
+
+    Args:
+        token: The token to embed in the query string.
+
+    Returns:
+        An environ dict with QUERY_STRING set so on_connect can extract the token.
+    """
+    return {"QUERY_STRING": f"token={token}"}
+
+
+@pytest.fixture
+def event_namespace_factory() -> Generator[Callable[[], EventNamespace], None, None]:
+    """Yields a factory for EventNamespace instances backed by a mocked app.
+
+    Tokens linked by the created namespaces are disconnected on teardown.
+    """
+    namespace = config.get_config().get_event_namespace()
+    created_objs = []
+
+    def new_event_namespace() -> EventNamespace:
+        state = Mock()
+        state.router_data = {}
+
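+        # Fake app: modify_state returns an async context manager yielding the mock state.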
+        mock_app = Mock()
+        mock_app.modify_state = Mock(
+            return_value=AsyncMock(__aenter__=AsyncMock(return_value=state))
+        )
+
+        event_namespace = EventNamespace(namespace=namespace, app=mock_app)
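+        # Capture emissions with an AsyncMock so tests can assert where updates land.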
+        event_namespace.emit = AsyncMock()
+        created_objs.append(event_namespace)
+        return event_namespace
+
+    yield new_event_namespace
+
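+    # Teardown: drop all tokens linked by the created namespaces so runs don't leak state.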
+    for obj in created_objs:
+        asyncio.run(obj._token_manager.disconnect_all())
+
+
+@pytest.mark.usefixtures("redis_url")
+@pytest.mark.asyncio
+async def test_redis_token_manager_enumerate_tokens(
+    event_namespace_factory: Callable[[], EventNamespace],
+):
+    """Integration test for RedisTokenManager enumerate_tokens interface.
+
+    Should support enumerating tokens across separate instances of the
+    RedisTokenManager.
+
+    Args:
+        event_namespace_factory: Factory fixture for EventNamespace instances.
+    """
+    event_namespace1 = event_namespace_factory()
+    event_namespace2 = event_namespace_factory()
+
+    await event_namespace1.on_connect(sid="sid1", environ=query_string_for("token1"))
+    await event_namespace2.on_connect(sid="sid2", environ=query_string_for("token2"))
+
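+    # Tokens linked via either instance should be visible when enumerating from one of them.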
+    found_tokens = set()
+    async for token in event_namespace1._token_manager.enumerate_tokens():
+        found_tokens.add(token)
+
+    assert "token1" in found_tokens
+    assert "token2" in found_tokens
+    assert len(found_tokens) == 2
+
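+    # disconnect_all only clears tokens owned by this instance; token2 belongs to the other manager.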
+    await event_namespace1._token_manager.disconnect_all()
+
+    found_tokens = set()
+    async for token in event_namespace1._token_manager.enumerate_tokens():
+        found_tokens.add(token)
+    assert "token2" in found_tokens
+    assert len(found_tokens) == 1
+
+    await event_namespace2._token_manager.disconnect_all()
+
+    found_tokens = set()
+    async for token in event_namespace1._token_manager.enumerate_tokens():
+        found_tokens.add(token)
+    assert not found_tokens
+
+
+@pytest.mark.usefixtures("redis_url")
+@pytest.mark.asyncio
+async def test_redis_token_manager_get_token_owner(
+    event_namespace_factory: Callable[[], EventNamespace],
+):
+    """Integration test for RedisTokenManager get_token_owner interface.
+
+    Should support retrieving the owner of a token across separate instances of the
+    RedisTokenManager.
+
+    Args:
+        event_namespace_factory: Factory fixture for EventNamespace instances.
+    """
+    event_namespace1 = event_namespace_factory()
+    event_namespace2 = event_namespace_factory()
+
+    await event_namespace1.on_connect(sid="sid1", environ=query_string_for("token1"))
+    await event_namespace2.on_connect(sid="sid2", environ=query_string_for("token2"))
+
+    assert isinstance((manager1 := event_namespace1._token_manager), RedisTokenManager)
+    assert isinstance((manager2 := event_namespace2._token_manager), RedisTokenManager)
+
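+    # Both managers should agree on the owning instance, regardless of which one is queried.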
+    assert await manager1._get_token_owner("token1") == manager1.instance_id
+    assert await manager1._get_token_owner("token2") == manager2.instance_id
+    assert await manager2._get_token_owner("token1") == manager1.instance_id
+    assert await manager2._get_token_owner("token2") == manager2.instance_id
+
+
+@pytest.mark.usefixtures("redis_url")
+@pytest.mark.asyncio
+async def test_redis_token_manager_lost_and_found(
+    event_namespace_factory: Callable[[], EventNamespace],
+):
+    """Updates emitted for lost and found tokens should be routed correctly via redis.
+
+    Args:
+        event_namespace_factory: Factory fixture for EventNamespace instances.
+    """
+    event_namespace1 = event_namespace_factory()
+    emit1_mock: Mock = event_namespace1.emit  # pyright: ignore[reportAssignmentType]
+    event_namespace2 = event_namespace_factory()
+    emit2_mock: Mock = event_namespace2.emit  # pyright: ignore[reportAssignmentType]
+
+    await event_namespace1.on_connect(sid="sid1", environ=query_string_for("token1"))
+    await event_namespace2.on_connect(sid="sid2", environ=query_string_for("token2"))
+
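+    # token1 is connected on namespace1, so an update emitted from namespace2 must surface there.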
+    await event_namespace2.emit_update(StateUpdate(), token="token1")
+    emit2_mock.assert_not_called()
+    emit1_mock.assert_called_once()
+    emit1_mock.reset_mock()
+
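+    # token2 is local to namespace2, so its update is emitted directly.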
+    await event_namespace2.emit_update(StateUpdate(), token="token2")
+    emit1_mock.assert_not_called()
+    emit2_mock.assert_called_once()
+    emit2_mock.reset_mock()
+
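+    # Once token1 disconnects, an update for it has no owner and should not be emitted anywhere.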
+    if task := event_namespace1.on_disconnect(sid="sid1"):
+        await task
+    await event_namespace2.emit_update(StateUpdate(), token="token1")
+    # Update should be dropped on the floor.
+    emit1_mock.assert_not_called()
+    emit2_mock.assert_not_called()
+
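+    # Reconnect token1 on namespace2; ownership moves, so the update now emits from namespace2.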
+    await event_namespace2.on_connect(sid="sid1", environ=query_string_for("token1"))
+    await event_namespace2.emit_update(StateUpdate(), token="token1")
+    emit1_mock.assert_not_called()
+    emit2_mock.assert_called_once()
+    emit2_mock.reset_mock()
+
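+    # Move token1 back to namespace1 and confirm routing follows the new owner.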
+    if task := event_namespace2.on_disconnect(sid="sid1"):
+        await task
+    await event_namespace1.on_connect(sid="sid1", environ=query_string_for("token1"))
+    await event_namespace2.emit_update(StateUpdate(), token="token1")
+    emit2_mock.assert_not_called()
+    emit1_mock.assert_called_once()
+    emit1_mock.reset_mock()
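
For orientation, here is a minimal sketch (not part of this diff) of how the interface exercised above could be consumed. It assumes only what the tests demonstrate: enumerate_tokens() on the namespace's _token_manager yields every connected token, and emit_update() routes an update to whichever instance owns the token, or drops it if the token has disconnected. The broadcast helper itself is hypothetical.

from reflex.app import EventNamespace
from reflex.state import StateUpdate


async def broadcast(event_namespace: EventNamespace) -> None:
    """Push an empty StateUpdate to every token the manager currently knows about."""
    async for token in event_namespace._token_manager.enumerate_tokens():
        # Routed to the owning instance (or dropped if the token has gone away),
        # as exercised by test_redis_token_manager_lost_and_found above.
        await event_namespace.emit_update(StateUpdate(), token=token)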