Skip to content

Commit fcb937c

Browse files
authored
ENG-8540: avoid dataclasses.asdict in Lost+Found path (#6057)
* ENG-8540: avoid dataclasses.asdict in Lost+Found path Use the reflex serializers registry to serialize StateUpdate objects for Lost+Found usage. * Use pickle instead of JSON for private records * oopsie * Fix pickle test expectation for test_connection_banner
1 parent 63809a3 commit fcb937c

File tree

3 files changed

+59
-26
lines changed

3 files changed

+59
-26
lines changed

reflex/utils/token_manager.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
import asyncio
66
import dataclasses
7-
import json
7+
import pickle
88
import uuid
99
from abc import ABC, abstractmethod
1010
from collections.abc import AsyncIterator, Callable, Coroutine
1111
from types import MappingProxyType
12-
from typing import TYPE_CHECKING, Any, ClassVar
12+
from typing import TYPE_CHECKING, ClassVar
1313

1414
from reflex.istate.manager.redis import StateManagerRedis
1515
from reflex.state import BaseState, StateUpdate
@@ -42,7 +42,7 @@ class LostAndFoundRecord:
4242
"""Record for a StateUpdate for a token with its socket on another instance."""
4343

4444
token: str
45-
update: dict[str, Any]
45+
update: StateUpdate
4646

4747

4848
class TokenManager(ABC):
@@ -328,7 +328,7 @@ async def link_token_to_sid(self, token: str, sid: str) -> str | None:
328328
try:
329329
await self.redis.set(
330330
redis_key,
331-
json.dumps(dataclasses.asdict(socket_record)),
331+
pickle.dumps(socket_record),
332332
ex=self.token_expiration,
333333
)
334334
except Exception as e:
@@ -386,8 +386,8 @@ async def _subscribe_lost_and_found_updates(
386386
)
387387
async for message in pubsub.listen():
388388
if message["type"] == "pmessage":
389-
record = LostAndFoundRecord(**json.loads(message["data"].decode()))
390-
await emit_update(StateUpdate(**record.update), record.token)
389+
record = pickle.loads(message["data"])
390+
await emit_update(record.update, record.token)
391391

392392
def ensure_lost_and_found_task(
393393
self,
@@ -424,10 +424,9 @@ async def _get_token_owner(self, token: str, refresh: bool = False) -> str | Non
424424

425425
redis_key = self._get_redis_key(token)
426426
try:
427-
record_json = await self.redis.get(redis_key)
428-
if record_json:
429-
record_data = json.loads(record_json)
430-
socket_record = SocketRecord(**record_data)
427+
record_pkl = await self.redis.get(redis_key)
428+
if record_pkl:
429+
socket_record = pickle.loads(record_pkl)
431430
self.token_to_socket[token] = socket_record
432431
self.sid_to_token[socket_record.sid] = token
433432
return socket_record.instance_id
@@ -454,11 +453,11 @@ async def emit_lost_and_found(
454453
owner_instance_id = await self._get_token_owner(token)
455454
if owner_instance_id is None:
456455
return False
457-
record = LostAndFoundRecord(token=token, update=dataclasses.asdict(update))
456+
record = LostAndFoundRecord(token=token, update=update)
458457
try:
459458
await self.redis.publish(
460459
f"channel:{self._get_lost_and_found_key(owner_instance_id)}",
461-
json.dumps(dataclasses.asdict(record)),
460+
pickle.dumps(record),
462461
)
463462
except Exception as e:
464463
console.error(f"Redis error publishing lost and found delta: {e}")

tests/integration/test_connection_banner.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Test case for displaying the connection banner when the websocket drops."""
22

3+
import pickle
34
from collections.abc import Generator
45

56
import pytest
@@ -10,7 +11,7 @@
1011
from reflex.environment import environment
1112
from reflex.istate.manager.redis import StateManagerRedis
1213
from reflex.testing import AppHarness, WebDriver
13-
from reflex.utils.token_manager import RedisTokenManager
14+
from reflex.utils.token_manager import RedisTokenManager, SocketRecord
1415

1516
from .utils import SessionStorage
1617

@@ -166,11 +167,10 @@ async def test_connection_banner(connection_banner: AppHarness):
166167
sid_before = app_token_manager.token_to_sid[token]
167168
if isinstance(connection_banner.state_manager, StateManagerRedis):
168169
assert isinstance(app_token_manager, RedisTokenManager)
169-
assert (
170-
await connection_banner.state_manager.redis.get(
171-
app_token_manager._get_redis_key(token)
172-
)
173-
== f'{{"instance_id": "{app_token_manager.instance_id}", "sid": "{sid_before}"}}'.encode()
170+
assert await connection_banner.state_manager.redis.get(
171+
app_token_manager._get_redis_key(token)
172+
) == pickle.dumps(
173+
SocketRecord(instance_id=app_token_manager.instance_id, sid=sid_before)
174174
)
175175

176176
delay_button = driver.find_element(By.ID, "delay")
@@ -226,11 +226,10 @@ async def test_connection_banner(connection_banner: AppHarness):
226226
assert sid_before != sid_after
227227
if isinstance(connection_banner.state_manager, StateManagerRedis):
228228
assert isinstance(app_token_manager, RedisTokenManager)
229-
assert (
230-
await connection_banner.state_manager.redis.get(
231-
app_token_manager._get_redis_key(token)
232-
)
233-
== f'{{"instance_id": "{app_token_manager.instance_id}", "sid": "{sid_after}"}}'.encode()
229+
assert await connection_banner.state_manager.redis.get(
230+
app_token_manager._get_redis_key(token)
231+
) == pickle.dumps(
232+
SocketRecord(instance_id=app_token_manager.instance_id, sid=sid_after)
234233
)
235234

236235
# Count should have incremented after coming back up

tests/units/utils/test_token_manager.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Unit tests for TokenManager implementations."""
22

33
import asyncio
4-
import json
4+
import pickle
55
import time
66
from collections.abc import Callable, Generator
77
from contextlib import asynccontextmanager
@@ -11,6 +11,7 @@
1111

1212
from reflex import config
1313
from reflex.app import EventNamespace
14+
from reflex.istate.data import RouterData
1415
from reflex.state import StateUpdate
1516
from reflex.utils.token_manager import (
1617
LocalTokenManager,
@@ -300,7 +301,7 @@ async def test_link_token_to_sid_normal_case(self, manager, mock_redis):
300301
)
301302
mock_redis.set.assert_called_once_with(
302303
f"token_manager_socket_record_{token}",
303-
json.dumps({"instance_id": manager.instance_id, "sid": sid}),
304+
pickle.dumps(SocketRecord(instance_id=manager.instance_id, sid=sid)),
304305
ex=3600,
305306
)
306307
assert manager.token_to_socket[token].sid == sid
@@ -347,7 +348,7 @@ async def test_link_token_to_sid_duplicate_detected(self, manager, mock_redis):
347348
)
348349
mock_redis.set.assert_called_once_with(
349350
f"token_manager_socket_record_{result}",
350-
json.dumps({"instance_id": manager.instance_id, "sid": sid}),
351+
pickle.dumps(SocketRecord(instance_id=manager.instance_id, sid=sid)),
351352
ex=3600,
352353
)
353354
assert manager.token_to_sid[result] == sid
@@ -670,3 +671,37 @@ async def test_redis_token_manager_lost_and_found(
670671
emit2_mock.assert_not_called()
671672
emit1_mock.assert_called_once()
672673
emit1_mock.reset_mock()
674+
675+
676+
@pytest.mark.usefixtures("redis_url")
677+
@pytest.mark.asyncio
678+
async def test_redis_token_manager_lost_and_found_router_data(
679+
event_namespace_factory: Callable[[], EventNamespace],
680+
):
681+
"""Updates emitted for lost and found tokens should serialize properly.
682+
683+
Args:
684+
event_namespace_factory: Factory fixture for EventNamespace instances.
685+
"""
686+
event_namespace1 = event_namespace_factory()
687+
emit1_mock: Mock = event_namespace1.emit # pyright: ignore[reportAssignmentType]
688+
event_namespace2 = event_namespace_factory()
689+
emit2_mock: Mock = event_namespace2.emit # pyright: ignore[reportAssignmentType]
690+
691+
await event_namespace1.on_connect(sid="sid1", environ=query_string_for("token1"))
692+
await event_namespace2.on_connect(sid="sid2", environ=query_string_for("token2"))
693+
694+
router = RouterData.from_router_data(
695+
{"headers": {"x-test": "value"}},
696+
)
697+
698+
await event_namespace2.emit_update(
699+
StateUpdate(delta={"state": {"router": router}}), token="token1"
700+
)
701+
await _wait_for_call_count_positive(emit1_mock)
702+
emit2_mock.assert_not_called()
703+
emit1_mock.assert_called_once()
704+
assert isinstance(emit1_mock.call_args[0][1], StateUpdate)
705+
assert isinstance(emit1_mock.call_args[0][1].delta["state"]["router"], RouterData)
706+
assert emit1_mock.call_args[0][1].delta["state"]["router"] == router
707+
emit1_mock.reset_mock()

0 commit comments

Comments
 (0)