Skip to content

Commit b9fe004

Browse files
committed
AI CR feedback
1 parent 0a13453 commit b9fe004

File tree

1 file changed

+63
-44
lines changed

1 file changed

+63
-44
lines changed

reflex/istate/shared.py

Lines changed: 63 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,27 @@
88
from reflex.constants import ROUTER_DATA
99
from reflex.event import Event, get_hydrate_event
1010
from reflex.state import BaseState, State, _override_base_method, _substate_key
11+
from reflex.utils import console
1112
from reflex.utils.exceptions import ReflexRuntimeError
1213

1314
UPDATE_OTHER_CLIENT_TASKS: set[asyncio.Task] = set()
1415
LINKED_STATE = TypeVar("LINKED_STATE", bound="SharedStateBaseInternal")
1516

1617

18+
def _log_update_client_errors(task: asyncio.Task):
19+
"""Log errors from updating other clients.
20+
21+
Args:
22+
task: The asyncio task to check for errors.
23+
"""
24+
try:
25+
task.result()
26+
except Exception as e:
27+
console.warn(f"Error updating linked client: {e}")
28+
finally:
29+
UPDATE_OTHER_CLIENT_TASKS.discard(task)
30+
31+
1732
def _do_update_other_tokens(
1833
affected_tokens: set[str],
1934
previous_dirty_vars: dict[str, set[str]],
@@ -47,10 +62,10 @@ async def _update_client(token: str):
4762
# Don't send updates for disconnected clients.
4863
if affected_token not in app.event_namespace._token_manager.token_to_socket:
4964
continue
50-
# TODO: remove disconnected client's after some time.
65+
# TODO: remove disconnected clients after some time.
5166
t = asyncio.create_task(_update_client(affected_token))
5267
UPDATE_OTHER_CLIENT_TASKS.add(t)
53-
t.add_done_callback(UPDATE_OTHER_CLIENT_TASKS.discard)
68+
t.add_done_callback(_log_update_client_errors)
5469
tasks.append(t)
5570
return tasks
5671

@@ -99,7 +114,7 @@ class SharedStateBaseInternal(State):
99114
"""The private base state for all shared states."""
100115

101116
_exit_stack: contextlib.AsyncExitStack | None = None
102-
_held_locks: dict[str, dict[type[BaseState], BaseState]] = {}
117+
_held_locks: dict[str, dict[type[BaseState], BaseState]] | None = None
103118

104119
def __getstate__(self):
105120
"""Override redis serialization to remove temporary fields.
@@ -164,7 +179,7 @@ async def _link_to(self, token: str) -> Self:
164179
clients linked to that token.
165180
166181
Args:
167-
token: The token to link to.
182+
token: The token to link to (Cannot contain underscore characters).
168183
169184
Returns:
170185
The newly linked state.
@@ -196,7 +211,7 @@ async def _unlink(self):
196211
"""Unlink this shared state from its linked token.
197212
198213
Returns:
199-
The events to rehydrate the state after unlinking (these should be returned/yielded
214+
The events to rehydrate the state after unlinking (these should be returned/yielded).
200215
"""
201216
from reflex.istate.manager import get_state_manager
202217

@@ -242,7 +257,7 @@ async def _internal_patch_linked_state(
242257
"""
243258
from reflex.istate.manager import get_state_manager
244259

245-
if self._exit_stack is None:
260+
if self._exit_stack is None or self._held_locks is None:
246261
msg = "Cannot link shared state outside of _modify_linked_states context."
247262
raise ReflexRuntimeError(msg)
248263

@@ -279,6 +294,8 @@ def _held_locks_linked_states(self) -> list["SharedState"]:
279294
Returns:
280295
The list of linked states currently held.
281296
"""
297+
if self._held_locks is None:
298+
return []
282299
return [
283300
linked_state
284301
for linked_state_cls_to_instance in self._held_locks.values()
@@ -313,44 +330,46 @@ async def _modify_linked_states(
313330
self._held_locks = {}
314331
current_dirty_vars: dict[str, set[str]] = {}
315332
affected_tokens: set[str] = set()
316-
# Go through all linked states and patch them in if they are present in the tree
317-
for linked_state_name, linked_token in self._reflex_internal_links.items():
318-
linked_state_cls: type[SharedState] = (
319-
self.get_root_state().get_class_substate( # pyright: ignore[reportAssignmentType]
320-
linked_state_name
321-
)
322-
)
323-
# TODO: Avoid always fetched linked states, it should be based on
324-
# whether the state is accessed, however then `get_state` would need
325-
# to know how to fetch in a linked state.
326-
original_state = await self.get_state(linked_state_cls)
327-
linked_state = await original_state._internal_patch_linked_state(
328-
linked_token
329-
)
330-
if (
331-
previous_dirty_vars
332-
and (dv := previous_dirty_vars.get(linked_state_name)) is not None
333-
):
334-
linked_state.dirty_vars.update(dv)
335-
linked_state._mark_dirty()
336-
async with self._exit_stack:
337-
yield None
338-
# Collect dirty vars and other affected clients that need to be updated.
339-
for linked_state in self._held_locks_linked_states():
340-
if linked_state._previous_dirty_vars is not None:
341-
current_dirty_vars[linked_state.get_full_name()] = set(
342-
linked_state._previous_dirty_vars
333+
try:
334+
# Go through all linked states and patch them in if they are present in the tree
335+
for linked_state_name, linked_token in self._reflex_internal_links.items():
336+
linked_state_cls: type[SharedState] = (
337+
self.get_root_state().get_class_substate( # pyright: ignore[reportAssignmentType]
338+
linked_state_name
343339
)
340+
)
341+
# TODO: Avoid always fetched linked states, it should be based on
342+
# whether the state is accessed, however then `get_state` would need
343+
# to know how to fetch in a linked state.
344+
original_state = await self.get_state(linked_state_cls)
345+
linked_state = await original_state._internal_patch_linked_state(
346+
linked_token
347+
)
344348
if (
345-
linked_state._get_was_touched()
346-
or linked_state._previous_dirty_vars is not None
349+
previous_dirty_vars
350+
and (dv := previous_dirty_vars.get(linked_state_name)) is not None
347351
):
348-
affected_tokens.update(
349-
token
350-
for token in linked_state._linked_from
351-
if token != self.router.session.client_token
352-
)
353-
self._exit_stack = None
352+
linked_state.dirty_vars.update(dv)
353+
linked_state._mark_dirty()
354+
async with self._exit_stack:
355+
yield None
356+
# Collect dirty vars and other affected clients that need to be updated.
357+
for linked_state in self._held_locks_linked_states():
358+
if linked_state._previous_dirty_vars is not None:
359+
current_dirty_vars[linked_state.get_full_name()] = set(
360+
linked_state._previous_dirty_vars
361+
)
362+
if (
363+
linked_state._get_was_touched()
364+
or linked_state._previous_dirty_vars is not None
365+
):
366+
affected_tokens.update(
367+
token
368+
for token in linked_state._linked_from
369+
if token != self.router.session.client_token
370+
)
371+
finally:
372+
self._exit_stack = None
354373

355374
# Only propagate dirty vars when we are not already propagating from another state.
356375
if previous_dirty_vars is None:
@@ -364,9 +383,9 @@ async def _modify_linked_states(
364383
class SharedState(SharedStateBaseInternal, mixin=True):
365384
"""Mixin for defining new shared states."""
366385

367-
_linked_from: set[str]
368-
_linked_to: str
369-
_previous_dirty_vars: set[str]
386+
_linked_from: set[str] = set()
387+
_linked_to: str = ""
388+
_previous_dirty_vars: set[str] = set()
370389

371390
@classmethod
372391
def __init_subclass__(cls, **kwargs):

0 commit comments

Comments
 (0)