Skip to content

Commit 321545c

Browse files
adhami3310masenf
andauthored
allow arguments to be passed to rx memo event handlers (#5021)
* allow arguments to be passed * add test case for memo components accepting event handlers * AppHarness: quit browsers early to avoid lingering events * fix ruff --------- Co-authored-by: Masen Furer <[email protected]>
1 parent 7a2ec85 commit 321545c

File tree

4 files changed

+90
-16
lines changed

4 files changed

+90
-16
lines changed

reflex/compiler/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def compile_custom_component(
318318
}
319319

320320
# Concatenate the props.
321-
props = [prop._js_expr for prop in component.get_prop_vars()]
321+
props = list(component.props)
322322

323323
# Compile the component.
324324
return (

reflex/components/component.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,13 @@
4949
from reflex.constants.compiler import SpecialAttributes
5050
from reflex.constants.state import FRONTEND_EVENT_STATE
5151
from reflex.event import (
52-
EventActionsMixin,
5352
EventCallback,
5453
EventChain,
5554
EventHandler,
5655
EventSpec,
5756
no_args_event_spec,
57+
parse_args_spec,
58+
run_script,
5859
)
5960
from reflex.style import Style, format_as_emotion
6061
from reflex.utils import console, format, imports, types
@@ -66,7 +67,7 @@
6667
Var,
6768
cached_property_no_lock,
6869
)
69-
from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar
70+
from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar, FunctionVar
7071
from reflex.vars.number import ternary_operation
7172
from reflex.vars.object import ObjectVar
7273
from reflex.vars.sequence import LiteralArrayVar
@@ -1900,7 +1901,44 @@ def _get_all_custom_components(
19001901

19011902
return custom_components
19021903

1903-
def get_prop_vars(self) -> List[Var]:
1904+
@staticmethod
1905+
def _get_event_spec_from_args_spec(name: str, event: EventChain) -> Callable:
1906+
"""Get the event spec from the args spec.
1907+
1908+
Args:
1909+
name: The name of the event
1910+
event: The args spec.
1911+
1912+
Returns:
1913+
The event spec.
1914+
"""
1915+
1916+
def fn(*args):
1917+
return run_script(Var(name).to(FunctionVar).call(*args))
1918+
1919+
if event.args_spec:
1920+
arg_spec = (
1921+
event.args_spec
1922+
if not isinstance(event.args_spec, Sequence)
1923+
else event.args_spec[0]
1924+
)
1925+
names = inspect.getfullargspec(arg_spec).args
1926+
fn.__signature__ = inspect.Signature( # pyright: ignore[reportFunctionMemberAccess]
1927+
parameters=[
1928+
inspect.Parameter(
1929+
name=name,
1930+
kind=inspect.Parameter.POSITIONAL_ONLY,
1931+
annotation=arg._var_type,
1932+
)
1933+
for name, arg in zip(
1934+
names, parse_args_spec(event.args_spec), strict=True
1935+
)
1936+
]
1937+
)
1938+
1939+
return fn
1940+
1941+
def get_prop_vars(self) -> List[Var | Callable]:
19041942
"""Get the prop vars.
19051943
19061944
Returns:
@@ -1909,16 +1947,10 @@ def get_prop_vars(self) -> List[Var]:
19091947
return [
19101948
Var(
19111949
_js_expr=name,
1912-
_var_type=(
1913-
prop._var_type
1914-
if isinstance(prop, Var)
1915-
else (
1916-
type(prop)
1917-
if not isinstance(prop, EventActionsMixin)
1918-
else EventChain
1919-
)
1920-
),
1950+
_var_type=(prop._var_type if isinstance(prop, Var) else type(prop)),
19211951
).guess_type()
1952+
if isinstance(prop, Var) or not isinstance(prop, EventChain)
1953+
else CustomComponent._get_event_spec_from_args_spec(name, prop)
19221954
for name, prop in self.props.items()
19231955
]
19241956

reflex/testing.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,10 @@ def __enter__(self) -> "AppHarness":
462462

463463
def stop(self) -> None:
464464
"""Stop the frontend and backend servers."""
465+
# Quit browsers first to avoid any lingering events being sent during shutdown.
466+
for driver in self._frontends:
467+
driver.quit()
468+
465469
self._reload_state_module()
466470

467471
if self.backend is not None:
@@ -492,8 +496,6 @@ def stop(self) -> None:
492496
self.backend_thread.join()
493497
if self.frontend_output_thread is not None:
494498
self.frontend_output_thread.join()
495-
for driver in self._frontends:
496-
driver.quit()
497499

498500
# Cleanup decorated pages added during testing
499501
for page in self._decorated_pages:

tests/integration/test_memo.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,32 @@ def foo_component(t: str):
2626
def foo_component2(t: str):
2727
return FooComponent.create(t, rx.Var("foo"))
2828

29+
class MemoState(rx.State):
30+
last_value: str = ""
31+
32+
@rx.event
33+
def set_last_value(self, value: str):
34+
self.last_value = value
35+
36+
@rx.memo
37+
def my_memoed_component(
38+
some_value: str,
39+
event: rx.EventHandler[rx.event.passthrough_event_spec(str)],
40+
) -> rx.Component:
41+
return rx.vstack(
42+
rx.button(some_value, id="memo-button", on_click=event(some_value)),
43+
rx.input(id="memo-input", on_change=event),
44+
)
45+
2946
def index() -> rx.Component:
3047
return rx.vstack(
31-
foo_component(t="foo"), foo_component2(t="bar"), id="memo-custom-code"
48+
rx.vstack(
49+
foo_component(t="foo"), foo_component2(t="bar"), id="memo-custom-code"
50+
),
51+
rx.text(MemoState.last_value, id="memo-last-value"),
52+
my_memoed_component(
53+
some_value="memod_some_value", event=MemoState.set_last_value
54+
),
3255
)
3356

3457
app = rx.App()
@@ -64,4 +87,21 @@ async def test_memo_app(memo_app: AppHarness):
6487

6588
# check that the output matches
6689
memo_custom_code_stack = driver.find_element(By.ID, "memo-custom-code")
90+
assert (
91+
memo_app.poll_for_content(memo_custom_code_stack, exp_not_equal="")
92+
== "foobarbarbar"
93+
)
6794
assert memo_custom_code_stack.text == "foobarbarbar"
95+
96+
# click the button to trigger partial event application
97+
button = driver.find_element(By.ID, "memo-button")
98+
button.click()
99+
last_value = driver.find_element(By.ID, "memo-last-value")
100+
assert memo_app.poll_for_content(last_value, exp_not_equal="") == "memod_some_value"
101+
102+
# enter text to trigger passed argument to event handler
103+
textbox = driver.find_element(By.ID, "memo-input")
104+
textbox.send_keys("new_value")
105+
assert memo_app._poll_for(
106+
lambda: memo_app.poll_for_content(last_value) == "new_value"
107+
)

0 commit comments

Comments
 (0)