Skip to content

Commit 247c021

Browse files
authored
PropsBase converts EventHandler-annotated props to EventChain (#5765)
* PropsBase converts EventHandler-annotated props to EventChain * Move _resolve_annotations to FieldBasedMeta Allow both Component and PropsBase to resolve annotations from the module namespace.
1 parent f0b2075 commit 247c021

File tree

5 files changed

+85
-21
lines changed

5 files changed

+85
-21
lines changed

reflex/components/component.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
EventChain,
4444
EventHandler,
4545
EventSpec,
46+
args_specs_from_fields,
4647
no_args_event_spec,
4748
parse_args_spec,
4849
pointer_event_spec,
@@ -143,14 +144,6 @@ class BaseComponentMeta(FieldBasedMeta, ABCMeta):
143144
_fields: Mapping[str, ComponentField]
144145
_js_fields: Mapping[str, ComponentField]
145146

146-
@classmethod
147-
def _resolve_annotations(
148-
cls, namespace: dict[str, Any], name: str
149-
) -> dict[str, Any]:
150-
return types.resolve_annotations(
151-
namespace.get("__annotations__", {}), namespace["__module__"]
152-
)
153-
154147
@classmethod
155148
def _process_annotated_fields(
156149
cls,
@@ -909,18 +902,7 @@ def get_event_triggers(cls) -> dict[str, types.ArgsSpec | Sequence[types.ArgsSpe
909902
"""
910903
# Look for component specific triggers,
911904
# e.g. variable declared as EventHandler types.
912-
return DEFAULT_TRIGGERS | {
913-
name: (
914-
metadata[0]
915-
if (
916-
(metadata := getattr(field.annotated_type, "__metadata__", None))
917-
is not None
918-
)
919-
else no_args_event_spec
920-
)
921-
for name, field in cls.get_fields().items()
922-
if field.type_origin is EventHandler
923-
} # pyright: ignore [reportOperatorIssue]
905+
return DEFAULT_TRIGGERS | args_specs_from_fields(cls.get_fields()) # pyright: ignore [reportOperatorIssue]
924906

925907
def __repr__(self) -> str:
926908
"""Represent the component in React.

reflex/components/field.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from dataclasses import _MISSING_TYPE, MISSING
77
from typing import Annotated, Any, Generic, TypeVar, get_origin
88

9+
from reflex.utils import types
10+
911
FIELD_TYPE = TypeVar("FIELD_TYPE")
1012

1113

@@ -114,7 +116,9 @@ def _collect_inherited_fields(cls, bases: tuple[type]) -> dict[str, Any]:
114116
def _resolve_annotations(
115117
cls, namespace: dict[str, Any], name: str
116118
) -> dict[str, Any]:
117-
return namespace.get("__annotations__", {})
119+
return types.resolve_annotations(
120+
namespace.get("__annotations__", {}), namespace["__module__"]
121+
)
118122

119123
@classmethod
120124
def _process_field_overrides(

reflex/components/props.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing_extensions import dataclass_transform
1010

1111
from reflex.components.field import BaseField, FieldBasedMeta
12+
from reflex.event import EventChain, args_specs_from_fields
1213
from reflex.utils import format
1314
from reflex.utils.exceptions import InvalidPropValueError
1415
from reflex.utils.serializers import serializer
@@ -267,6 +268,20 @@ def __init__(self, **kwargs):
267268
setattr(self, field_name, field.default_factory())
268269
# Note: Fields with no default and no factory remain unset (required fields)
269270

271+
# Convert EventHandler to EventChain
272+
args_specs = args_specs_from_fields(self.get_fields())
273+
for handler_name, args_spec in args_specs.items():
274+
if (handler := getattr(self, handler_name, None)) is not None:
275+
setattr(
276+
self,
277+
handler_name,
278+
EventChain.create(
279+
value=handler,
280+
args_spec=args_spec,
281+
key=handler_name,
282+
),
283+
)
284+
270285
@classmethod
271286
def get_fields(cls) -> dict[str, Any]:
272287
"""Get the fields of the object.

reflex/event.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from typing_extensions import Self, TypeAliasType, TypedDict, TypeVarTuple, Unpack
2727

2828
from reflex import constants
29+
from reflex.components.field import BaseField
2930
from reflex.constants.compiler import CompileVars, Hooks, Imports
3031
from reflex.constants.state import FRONTEND_EVENT_STATE
3132
from reflex.utils import format
@@ -1684,6 +1685,31 @@ def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]):
16841685
), annotations
16851686

16861687

1688+
def args_specs_from_fields(
1689+
fields_dict: Mapping[str, BaseField],
1690+
) -> dict[str, ArgsSpec | Sequence[ArgsSpec]]:
1691+
"""Get the event triggers and arg specs from the given fields.
1692+
1693+
Args:
1694+
fields_dict: The fields, keyed by name
1695+
1696+
Returns:
1697+
The args spec for any field annotated as EventHandler.
1698+
"""
1699+
return {
1700+
name: (
1701+
metadata[0]
1702+
if (
1703+
(metadata := getattr(field.annotated_type, "__metadata__", None))
1704+
is not None
1705+
)
1706+
else no_args_event_spec
1707+
)
1708+
for name, field in fields_dict.items()
1709+
if field.type_origin is EventHandler
1710+
}
1711+
1712+
16871713
def check_fn_match_arg_spec(
16881714
user_func: Callable,
16891715
user_func_parameters: Mapping[str, inspect.Parameter],
@@ -2436,6 +2462,7 @@ def wrapper(
24362462
check_fn_match_arg_spec = staticmethod(check_fn_match_arg_spec)
24372463
resolve_annotation = staticmethod(resolve_annotation)
24382464
parse_args_spec = staticmethod(parse_args_spec)
2465+
args_specs_from_fields = staticmethod(args_specs_from_fields)
24392466
unwrap_var_annotation = staticmethod(unwrap_var_annotation)
24402467
get_fn_signature = staticmethod(get_fn_signature)
24412468

tests/units/components/test_props.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
1+
from __future__ import annotations
2+
13
import pytest
24
from pydantic.v1 import ValidationError
35

46
from reflex.components.props import NoExtrasAllowedProps, PropsBase
7+
from reflex.event import (
8+
EventChain,
9+
EventHandler,
10+
event,
11+
no_args_event_spec,
12+
passthrough_event_spec,
13+
)
14+
from reflex.state import State
515
from reflex.utils.exceptions import InvalidPropValueError
616

717

@@ -177,3 +187,29 @@ def test_props_base_dict_conversion(props_class, props_kwargs, expected_dict):
177187
props = props_class(**props_kwargs)
178188
result = props.dict()
179189
assert result == expected_dict
190+
191+
192+
class EventProps(PropsBase):
193+
"""Test props with event handler fields."""
194+
195+
on_click: EventHandler[no_args_event_spec]
196+
not_start_with_on: EventHandler[passthrough_event_spec(str)]
197+
198+
199+
def test_event_handler_props():
200+
class FooState(State):
201+
@event
202+
def handle_click(self):
203+
pass
204+
205+
@event
206+
def handle_input(self, value: str):
207+
pass
208+
209+
props = EventProps(
210+
on_click=FooState.handle_click, # pyright: ignore[reportArgumentType]
211+
not_start_with_on=FooState.handle_input, # pyright: ignore[reportArgumentType]
212+
)
213+
props_dict = props.dict()
214+
assert isinstance(props_dict["onClick"], EventChain)
215+
assert isinstance(props_dict["notStartWithOn"], EventChain)

0 commit comments

Comments
 (0)