Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 2 additions & 12 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
EventChain,
EventHandler,
EventSpec,
args_specs_from_fields,
no_args_event_spec,
parse_args_spec,
pointer_event_spec,
Expand Down Expand Up @@ -909,18 +910,7 @@ def get_event_triggers(cls) -> dict[str, types.ArgsSpec | Sequence[types.ArgsSpe
"""
# Look for component specific triggers,
# e.g. variable declared as EventHandler types.
return DEFAULT_TRIGGERS | {
name: (
metadata[0]
if (
(metadata := getattr(field.annotated_type, "__metadata__", None))
is not None
)
else no_args_event_spec
)
for name, field in cls.get_fields().items()
if field.type_origin is EventHandler
} # pyright: ignore [reportOperatorIssue]
return DEFAULT_TRIGGERS | args_specs_from_fields(cls.get_fields()) # pyright: ignore [reportOperatorIssue]

def __repr__(self) -> str:
"""Represent the component in React.
Expand Down
15 changes: 15 additions & 0 deletions reflex/components/props.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing_extensions import dataclass_transform

from reflex.components.field import BaseField, FieldBasedMeta
from reflex.event import EventChain, args_specs_from_fields
from reflex.utils import format
from reflex.utils.exceptions import InvalidPropValueError
from reflex.utils.serializers import serializer
Expand Down Expand Up @@ -267,6 +268,20 @@ def __init__(self, **kwargs):
setattr(self, field_name, field.default_factory())
# Note: Fields with no default and no factory remain unset (required fields)

# Convert EventHandler to EventChain
args_specs = args_specs_from_fields(self.get_fields())
for handler_name, args_spec in args_specs.items():
if (handler := getattr(self, handler_name, None)) is not None:
setattr(
self,
handler_name,
EventChain.create(
value=handler,
args_spec=args_spec,
key=handler_name,
),
)
Comment on lines +271 to +283
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: The EventChain conversion logic executes for every PropsBase instantiation even when no EventHandler fields exist. Consider adding a quick check to skip this block if args_specs is empty for better performance.


@classmethod
def get_fields(cls) -> dict[str, Any]:
"""Get the fields of the object.
Expand Down
27 changes: 27 additions & 0 deletions reflex/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from typing_extensions import Self, TypeAliasType, TypedDict, TypeVarTuple, Unpack

from reflex import constants
from reflex.components.field import BaseField
from reflex.constants.compiler import CompileVars, Hooks, Imports
from reflex.constants.state import FRONTEND_EVENT_STATE
from reflex.utils import format
Expand Down Expand Up @@ -1654,6 +1655,31 @@ def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]):
), annotations


def args_specs_from_fields(
fields_dict: Mapping[str, BaseField],
) -> dict[str, ArgsSpec | Sequence[ArgsSpec]]:
"""Get the event triggers and arg specs from the given fields.

Args:
fields_dict: The fields, keyed by name

Returns:
The args spec for any field annotated as EventHandler.
"""
return {
name: (
metadata[0]
if (
(metadata := getattr(field.annotated_type, "__metadata__", None))
is not None
)
else no_args_event_spec
)
for name, field in fields_dict.items()
if field.type_origin is EventHandler
}


def check_fn_match_arg_spec(
user_func: Callable,
user_func_parameters: Mapping[str, inspect.Parameter],
Expand Down Expand Up @@ -2406,6 +2432,7 @@ def wrapper(
check_fn_match_arg_spec = staticmethod(check_fn_match_arg_spec)
resolve_annotation = staticmethod(resolve_annotation)
parse_args_spec = staticmethod(parse_args_spec)
args_specs_from_fields = staticmethod(args_specs_from_fields)
unwrap_var_annotation = staticmethod(unwrap_var_annotation)
get_fn_signature = staticmethod(get_fn_signature)

Expand Down
36 changes: 36 additions & 0 deletions tests/units/components/test_props.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
from __future__ import annotations

import pytest
from pydantic.v1 import ValidationError

from reflex.components.props import NoExtrasAllowedProps, PropsBase
from reflex.event import (
EventChain,
EventHandler,
event,
no_args_event_spec,
passthrough_event_spec,
)
from reflex.state import State
from reflex.utils.exceptions import InvalidPropValueError


Expand Down Expand Up @@ -177,3 +187,29 @@ def test_props_base_dict_conversion(props_class, props_kwargs, expected_dict):
props = props_class(**props_kwargs)
result = props.dict()
assert result == expected_dict


class EventProps(PropsBase):
"""Test props with event handler fields."""

on_click: EventHandler[no_args_event_spec]
not_start_with_on: EventHandler[passthrough_event_spec(str)]


def test_event_handler_props():
class FooState(State):
@event
def handle_click(self):
pass

@event
def handle_input(self, value: str):
pass

props = EventProps(
on_click=FooState.handle_click, # pyright: ignore[reportArgumentType]
not_start_with_on=FooState.handle_input, # pyright: ignore[reportArgumentType]
)
props_dict = props.dict()
assert isinstance(props_dict["onClick"], EventChain)
assert isinstance(props_dict["notStartWithOn"], EventChain)
Loading