Skip to content

Commit d17a70e

Browse files
committed
ENG-8540: avoid dataclasses.asdict in Lost+Found path
Use the reflex serializers registry to serialize StateUpdate objects for Lost+Found usage.
1 parent 1e71398 commit d17a70e

File tree

2 files changed

+48
-6
lines changed

2 files changed

+48
-6
lines changed

reflex/utils/token_manager.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
from abc import ABC, abstractmethod
1010
from collections.abc import AsyncIterator, Callable, Coroutine
1111
from types import MappingProxyType
12-
from typing import TYPE_CHECKING, Any, ClassVar
12+
from typing import TYPE_CHECKING, ClassVar
1313

1414
from reflex.istate.manager.redis import StateManagerRedis
1515
from reflex.state import BaseState, StateUpdate
1616
from reflex.utils import console, prerequisites
17+
from reflex.utils.format import json_dumps
1718
from reflex.utils.tasks import ensure_task
1819

1920
if TYPE_CHECKING:
@@ -42,7 +43,7 @@ class LostAndFoundRecord:
4243
"""Record for a StateUpdate for a token with its socket on another instance."""
4344

4445
token: str
45-
update: dict[str, Any]
46+
update: StateUpdate
4647

4748

4849
class TokenManager(ABC):
@@ -386,8 +387,12 @@ async def _subscribe_lost_and_found_updates(
386387
)
387388
async for message in pubsub.listen():
388389
if message["type"] == "pmessage":
389-
record = LostAndFoundRecord(**json.loads(message["data"].decode()))
390-
await emit_update(StateUpdate(**record.update), record.token)
390+
record_dict = json.loads(message["data"].decode())
391+
record = LostAndFoundRecord(
392+
token=record_dict["token"],
393+
update=StateUpdate(**record_dict["update"]),
394+
)
395+
await emit_update(record.update, record.token)
391396

392397
def ensure_lost_and_found_task(
393398
self,
@@ -454,11 +459,11 @@ async def emit_lost_and_found(
454459
owner_instance_id = await self._get_token_owner(token)
455460
if owner_instance_id is None:
456461
return False
457-
record = LostAndFoundRecord(token=token, update=dataclasses.asdict(update))
462+
record = LostAndFoundRecord(token=token, update=update)
458463
try:
459464
await self.redis.publish(
460465
f"channel:{self._get_lost_and_found_key(owner_instance_id)}",
461-
json.dumps(dataclasses.asdict(record)),
466+
json_dumps(record),
462467
)
463468
except Exception as e:
464469
console.error(f"Redis error publishing lost and found delta: {e}")

tests/units/utils/test_token_manager.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111

1212
from reflex import config
1313
from reflex.app import EventNamespace
14+
from reflex.istate.data import RouterData
1415
from reflex.state import StateUpdate
16+
from reflex.utils.format import json_dumps
1517
from reflex.utils.token_manager import (
1618
LocalTokenManager,
1719
RedisTokenManager,
@@ -670,3 +672,38 @@ async def test_redis_token_manager_lost_and_found(
670672
emit2_mock.assert_not_called()
671673
emit1_mock.assert_called_once()
672674
emit1_mock.reset_mock()
675+
676+
677+
@pytest.mark.usefixtures("redis_url")
678+
@pytest.mark.asyncio
679+
async def test_redis_token_manager_lost_and_found_router_data(
680+
event_namespace_factory: Callable[[], EventNamespace],
681+
):
682+
"""Updates emitted for lost and found tokens should serialize properly.
683+
684+
Args:
685+
event_namespace_factory: Factory fixture for EventNamespace instances.
686+
"""
687+
event_namespace1 = event_namespace_factory()
688+
emit1_mock: Mock = event_namespace1.emit # pyright: ignore[reportAssignmentType]
689+
event_namespace2 = event_namespace_factory()
690+
emit2_mock: Mock = event_namespace2.emit # pyright: ignore[reportAssignmentType]
691+
692+
await event_namespace1.on_connect(sid="sid1", environ=query_string_for("token1"))
693+
await event_namespace2.on_connect(sid="sid2", environ=query_string_for("token2"))
694+
695+
router = RouterData.from_router_data(
696+
{"headers": {"x-test": "value"}},
697+
)
698+
699+
await event_namespace2.emit_update(
700+
StateUpdate(delta={"state": {"router": router}}), token="token1"
701+
)
702+
await _wait_for_call_count_positive(emit1_mock)
703+
emit2_mock.assert_not_called()
704+
emit1_mock.assert_called_once()
705+
assert isinstance(emit1_mock.call_args[0][1], StateUpdate)
706+
assert emit1_mock.call_args[0][1].delta["state"]["router"] == json.loads(
707+
json_dumps(router)
708+
)
709+
emit1_mock.reset_mock()

0 commit comments

Comments
 (0)