Skip to content

Commit a05f14f

Browse files
committed
Merge remote-tracking branch 'origin/main' into hide-__getattribute__-from-type-checking
2 parents 2a7dc77 + fcb937c commit a05f14f

File tree

10 files changed

+244
-35
lines changed

10 files changed

+244
-35
lines changed

pyi_hashes.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
"reflex/components/core/window_events.pyi": "af33ccec866b9540ee7fbec6dbfbd151",
2424
"reflex/components/datadisplay/__init__.pyi": "52755871369acbfd3a96b46b9a11d32e",
2525
"reflex/components/datadisplay/code.pyi": "b86769987ef4d1cbdddb461be88539fd",
26-
"reflex/components/datadisplay/dataeditor.pyi": "35391d4ba147cf20ce4ac7a782066d61",
26+
"reflex/components/datadisplay/dataeditor.pyi": "fb26f3e702fcb885539d1cf82a854be3",
2727
"reflex/components/datadisplay/shiki_code_block.pyi": "1d53e75b6be0d3385a342e7b3011babd",
2828
"reflex/components/el/__init__.pyi": "0adfd001a926a2a40aee94f6fa725ecc",
2929
"reflex/components/el/element.pyi": "c5974a92fbc310e42d0f6cfdd13472f4",

reflex/components/datadisplay/dataeditor.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from reflex.utils.serializers import serializer
1616
from reflex.vars import get_unique_variable_name
1717
from reflex.vars.base import Var
18+
from reflex.vars.function import FunctionStringVar
1819
from reflex.vars.sequence import ArrayVar
1920

2021

@@ -260,6 +261,12 @@ class DataEditor(NoSSRComponent):
260261
# Allow columns selections. ("none", "single", "multi")
261262
column_select: Var[Literal["none", "single", "multi"]]
262263

264+
# Allow range selections. ("none", "cell", "rect", "multi-cell", "multi-rect").
265+
range_select: Var[Literal["none", "cell", "rect", "multi-cell", "multi-rect"]]
266+
267+
# Allow row selections. ("none", "single", "multi").
268+
row_select: Var[Literal["none", "single", "multi"]]
269+
263270
# Prevent diagonal scrolling.
264271
prevent_diagonal_scrolling: Var[bool]
265272

@@ -275,6 +282,18 @@ class DataEditor(NoSSRComponent):
275282
# Initial scroll offset on the vertical axis.
276283
scroll_offset_y: Var[int]
277284

285+
# Controls which types of range selections can exist at the same time. ("exclusive", "mixed").
286+
range_selection_blending: Var[Literal["exclusive", "mixed"]]
287+
288+
# Controls which types of column selections can exist at the same time. ("exclusive", "mixed").
289+
column_selection_blending: Var[Literal["exclusive", "mixed"]]
290+
291+
# Controls which types of row selections can exist at the same time. ("exclusive", "mixed").
292+
row_selection_blending: Var[Literal["exclusive", "mixed"]]
293+
294+
# Controls how spans are handled in selections. ("default", "allowPartial").
295+
span_range_behavior: Var[Literal["default", "allowPartial"]]
296+
278297
# global theme
279298
theme: Var[DataEditorTheme | dict]
280299

@@ -326,6 +345,12 @@ class DataEditor(NoSSRComponent):
326345
# Fired when a row is appended.
327346
on_row_appended: EventHandler[no_args_event_spec]
328347

348+
# The current grid selection state (columns, rows, and current cell/range). Must be used when on_grid_selection_change is used otherwise updates will not be reflected in the grid.
349+
grid_selection: Var[GridSelection]
350+
351+
# Fired when the grid selection changes. Will pass the current selection, the selected columns and the selected rows.
352+
on_grid_selection_change: EventHandler[passthrough_event_spec(GridSelection)]
353+
329354
# Fired when the selection is cleared.
330355
on_selection_cleared: EventHandler[no_args_event_spec]
331356

@@ -342,12 +367,61 @@ def add_imports(self) -> ImportDict:
342367
return {}
343368
return {
344369
"": f"{format.format_library_name(self.library)}/dist/index.css",
345-
self.library: "GridCellKind",
370+
self.library: ["GridCellKind", "CompactSelection"],
346371
"$/utils/helpers/dataeditor.js": ImportVar(
347372
tag="formatDataEditorCells", is_default=False, install=False
348373
),
349374
}
350375

376+
def add_custom_code(self) -> list[str]:
377+
"""Add custom code for reconstructing GridSelection with CompactSelection objects.
378+
379+
Note: When using on_grid_selection_change, Glide Data Grid will not update its internal selection state automatically. Instead,
380+
the grid_selection prop must be updated with a GridSelection object that has CompactSelection objects for the columns and rows properties.
381+
This function provides the necessary JavaScript code to reconstruct the GridSelection object from a dict representation.
382+
383+
Returns:
384+
JavaScript code to reconstruct GridSelection.
385+
"""
386+
return [
387+
"""
388+
function reconstructGridSelection(selection) {
389+
if (!selection || typeof selection !== 'object') {
390+
return undefined;
391+
}
392+
393+
const reconstructCompactSelection = (data) => {
394+
if (!data || !data.items || !Array.isArray(data.items)) {
395+
return CompactSelection.empty();
396+
}
397+
398+
const items = data.items;
399+
if (items.length === 0) {
400+
return CompactSelection.empty();
401+
}
402+
403+
let result = CompactSelection.empty();
404+
405+
// Items are stored as [start, end) ranges in CompactSelection internal format
406+
for (const item of items) {
407+
if (Array.isArray(item) && item.length === 2) {
408+
const [start, end] = item;
409+
result = result.add([start, end]);
410+
}
411+
}
412+
413+
return result;
414+
};
415+
416+
return {
417+
current: selection.current || undefined,
418+
columns: reconstructCompactSelection(selection.columns),
419+
rows: reconstructCompactSelection(selection.rows)
420+
};
421+
}
422+
"""
423+
]
424+
351425
def add_hooks(self) -> list[str]:
352426
"""Get the hooks to render.
353427
@@ -429,6 +503,15 @@ def create(cls, *children, **props) -> Component:
429503
console.warn(
430504
"get_cell_content is not user configurable, the provided value will be discarded"
431505
)
506+
507+
# Apply the reconstruction function to grid_selection if it's a Var
508+
if (grid_selection := props.get("grid_selection")) is not None and isinstance(
509+
grid_selection, Var
510+
):
511+
props["grid_selection"] = FunctionStringVar.create(
512+
"reconstructGridSelection"
513+
).call(grid_selection)
514+
432515
grid = super().create(*children, **props)
433516
return Div.create(
434517
grid,

reflex/istate/proxy.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -508,8 +508,13 @@ def _wrap_recursive(self, value: Any) -> Any:
508508
# When called from dataclasses internal code, return the unwrapped value
509509
if self._is_called_from_dataclasses_internal():
510510
return value
511-
# Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
512-
if is_mutable_type(type(value)) and not isinstance(value, MutableProxy):
511+
# If we already have a proxy, make sure the state reference is up to date and return it.
512+
if isinstance(value, MutableProxy):
513+
if value._self_state is not self._self_state:
514+
value._self_state = self._self_state
515+
return value
516+
# Recursively wrap mutable types.
517+
if is_mutable_type(type(value)):
513518
base_cls = globals()[self.__base_proxy__]
514519
return base_cls(
515520
wrapped=value,

reflex/istate/shared.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -348,10 +348,11 @@ async def _modify_linked_states(
348348
linked_state_name
349349
)
350350
)
351-
# TODO: Avoid always fetched linked states, it should be based on
352-
# whether the state is accessed, however then `get_state` would need
353-
# to know how to fetch in a linked state.
354-
original_state = await self.get_state(linked_state_cls)
351+
try:
352+
original_state = self._get_state_from_cache(linked_state_cls)
353+
except ValueError:
354+
# This state wasn't required for processing the event.
355+
continue
355356
linked_state = await original_state._internal_patch_linked_state(
356357
linked_token
357358
)
@@ -410,3 +411,9 @@ def __init_subclass__(cls, **kwargs):
410411
root_state = cls.get_root_state()
411412
if root_state.backend_vars["_reflex_internal_links"] is None:
412413
root_state.backend_vars["_reflex_internal_links"] = {}
414+
if root_state is State:
415+
# Always fetch SharedStateBaseInternal to access
416+
# `_modify_linked_states` without having to use `.get_state()` which
417+
# pulls in all linked states and substates which may not actually be
418+
# accessed for this event.
419+
root_state._always_dirty_substates.add(SharedStateBaseInternal.get_name())

reflex/state.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2470,6 +2470,35 @@ class State(BaseState):
24702470
# Maps the state full_name to an arbitrary token it is linked to for shared state.
24712471
_reflex_internal_links: dict[str, str] | None = None
24722472

2473+
@_override_base_method
2474+
async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE:
2475+
"""Get a state instance from redis with linking support.
2476+
2477+
Args:
2478+
state_cls: The class of the state.
2479+
2480+
Returns:
2481+
The instance of state_cls associated with this state's client_token.
2482+
"""
2483+
state_instance = await super()._get_state_from_redis(state_cls)
2484+
if (
2485+
self._reflex_internal_links
2486+
and (
2487+
linked_token := self._reflex_internal_links.get(
2488+
state_cls.get_full_name()
2489+
)
2490+
)
2491+
is not None
2492+
and (
2493+
internal_patch_linked_state := getattr(
2494+
state_instance, "_internal_patch_linked_state", None
2495+
)
2496+
)
2497+
is not None
2498+
):
2499+
return await internal_patch_linked_state(linked_token)
2500+
return state_instance
2501+
24732502
@event
24742503
def set_is_hydrated(self, value: bool) -> None:
24752504
"""Set the hydrated state.

reflex/utils/token_manager.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
import asyncio
66
import dataclasses
7-
import json
7+
import pickle
88
import uuid
99
from abc import ABC, abstractmethod
1010
from collections.abc import AsyncIterator, Callable, Coroutine
1111
from types import MappingProxyType
12-
from typing import TYPE_CHECKING, Any, ClassVar
12+
from typing import TYPE_CHECKING, ClassVar
1313

1414
from reflex.istate.manager.redis import StateManagerRedis
1515
from reflex.state import BaseState, StateUpdate
@@ -42,7 +42,7 @@ class LostAndFoundRecord:
4242
"""Record for a StateUpdate for a token with its socket on another instance."""
4343

4444
token: str
45-
update: dict[str, Any]
45+
update: StateUpdate
4646

4747

4848
class TokenManager(ABC):
@@ -328,7 +328,7 @@ async def link_token_to_sid(self, token: str, sid: str) -> str | None:
328328
try:
329329
await self.redis.set(
330330
redis_key,
331-
json.dumps(dataclasses.asdict(socket_record)),
331+
pickle.dumps(socket_record),
332332
ex=self.token_expiration,
333333
)
334334
except Exception as e:
@@ -386,8 +386,8 @@ async def _subscribe_lost_and_found_updates(
386386
)
387387
async for message in pubsub.listen():
388388
if message["type"] == "pmessage":
389-
record = LostAndFoundRecord(**json.loads(message["data"].decode()))
390-
await emit_update(StateUpdate(**record.update), record.token)
389+
record = pickle.loads(message["data"])
390+
await emit_update(record.update, record.token)
391391

392392
def ensure_lost_and_found_task(
393393
self,
@@ -424,10 +424,9 @@ async def _get_token_owner(self, token: str, refresh: bool = False) -> str | Non
424424

425425
redis_key = self._get_redis_key(token)
426426
try:
427-
record_json = await self.redis.get(redis_key)
428-
if record_json:
429-
record_data = json.loads(record_json)
430-
socket_record = SocketRecord(**record_data)
427+
record_pkl = await self.redis.get(redis_key)
428+
if record_pkl:
429+
socket_record = pickle.loads(record_pkl)
431430
self.token_to_socket[token] = socket_record
432431
self.sid_to_token[socket_record.sid] = token
433432
return socket_record.instance_id
@@ -454,11 +453,11 @@ async def emit_lost_and_found(
454453
owner_instance_id = await self._get_token_owner(token)
455454
if owner_instance_id is None:
456455
return False
457-
record = LostAndFoundRecord(token=token, update=dataclasses.asdict(update))
456+
record = LostAndFoundRecord(token=token, update=update)
458457
try:
459458
await self.redis.publish(
460459
f"channel:{self._get_lost_and_found_key(owner_instance_id)}",
461-
json.dumps(dataclasses.asdict(record)),
460+
pickle.dumps(record),
462461
)
463462
except Exception as e:
464463
console.error(f"Redis error publishing lost and found delta: {e}")

reflex/vars/dep_tracking.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,11 @@ def load_attr_or_method(self, instruction: dis.Instruction) -> None:
179179
if not self.top_of_stack:
180180
return
181181
target_obj = self.get_tracked_local(self.top_of_stack)
182-
target_state = assert_base_state(target_obj, local_name=self.top_of_stack)
182+
try:
183+
target_state = assert_base_state(target_obj, local_name=self.top_of_stack)
184+
except VarValueError:
185+
# If the target state is not a BaseState, we cannot track dependencies on it.
186+
return
183187
try:
184188
ref_obj = getattr(target_state, instruction.argval)
185189
except AttributeError:

tests/integration/test_connection_banner.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Test case for displaying the connection banner when the websocket drops."""
22

3+
import pickle
34
from collections.abc import Generator
45

56
import pytest
@@ -10,7 +11,7 @@
1011
from reflex.environment import environment
1112
from reflex.istate.manager.redis import StateManagerRedis
1213
from reflex.testing import AppHarness, WebDriver
13-
from reflex.utils.token_manager import RedisTokenManager
14+
from reflex.utils.token_manager import RedisTokenManager, SocketRecord
1415

1516
from .utils import SessionStorage
1617

@@ -166,11 +167,10 @@ async def test_connection_banner(connection_banner: AppHarness):
166167
sid_before = app_token_manager.token_to_sid[token]
167168
if isinstance(connection_banner.state_manager, StateManagerRedis):
168169
assert isinstance(app_token_manager, RedisTokenManager)
169-
assert (
170-
await connection_banner.state_manager.redis.get(
171-
app_token_manager._get_redis_key(token)
172-
)
173-
== f'{{"instance_id": "{app_token_manager.instance_id}", "sid": "{sid_before}"}}'.encode()
170+
assert await connection_banner.state_manager.redis.get(
171+
app_token_manager._get_redis_key(token)
172+
) == pickle.dumps(
173+
SocketRecord(instance_id=app_token_manager.instance_id, sid=sid_before)
174174
)
175175

176176
delay_button = driver.find_element(By.ID, "delay")
@@ -226,11 +226,10 @@ async def test_connection_banner(connection_banner: AppHarness):
226226
assert sid_before != sid_after
227227
if isinstance(connection_banner.state_manager, StateManagerRedis):
228228
assert isinstance(app_token_manager, RedisTokenManager)
229-
assert (
230-
await connection_banner.state_manager.redis.get(
231-
app_token_manager._get_redis_key(token)
232-
)
233-
== f'{{"instance_id": "{app_token_manager.instance_id}", "sid": "{sid_after}"}}'.encode()
229+
assert await connection_banner.state_manager.redis.get(
230+
app_token_manager._get_redis_key(token)
231+
) == pickle.dumps(
232+
SocketRecord(instance_id=app_token_manager.instance_id, sid=sid_after)
234233
)
235234

236235
# Count should have incremented after coming back up

0 commit comments

Comments
 (0)