Skip to content

Commit 12a42b6

Browse files
authored
var_data fixes with hooks values (#4717)
* var_data fixes with hooks values * remove the raise error
1 parent 335816c commit 12a42b6

File tree

4 files changed

+62
-4
lines changed

4 files changed

+62
-4
lines changed

reflex/event.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
)
3838

3939
from reflex import constants
40+
from reflex.constants.compiler import CompileVars, Hooks, Imports
4041
from reflex.constants.state import FRONTEND_EVENT_STATE
4142
from reflex.utils import console, format
4243
from reflex.utils.exceptions import (
@@ -1729,7 +1730,13 @@ def create(
17291730
arg_def_expr = Var(_js_expr="args")
17301731

17311732
if value.invocation is None:
1732-
invocation = FunctionStringVar.create("addEvents")
1733+
invocation = FunctionStringVar.create(
1734+
CompileVars.ADD_EVENTS,
1735+
_var_data=VarData(
1736+
imports=Imports.EVENTS,
1737+
hooks={Hooks.EVENTS: None},
1738+
),
1739+
)
17331740
else:
17341741
invocation = value.invocation
17351742

reflex/vars/base.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
Mapping,
3030
NoReturn,
3131
Optional,
32+
Sequence,
3233
Set,
3334
Tuple,
3435
Type,
@@ -131,7 +132,7 @@ def __init__(
131132
state: str = "",
132133
field_name: str = "",
133134
imports: ImportDict | ParsedImportDict | None = None,
134-
hooks: Mapping[str, VarData | None] | None = None,
135+
hooks: Mapping[str, VarData | None] | Sequence[str] | str | None = None,
135136
deps: list[Var] | None = None,
136137
position: Hooks.HookPosition | None = None,
137138
):
@@ -145,6 +146,10 @@ def __init__(
145146
deps: Dependencies of the var for useCallback.
146147
position: Position of the hook in the component.
147148
"""
149+
if isinstance(hooks, str):
150+
hooks = [hooks]
151+
if not isinstance(hooks, dict):
152+
hooks = {hook: None for hook in (hooks or [])}
148153
immutable_imports: ImmutableParsedImportDict = tuple(
149154
(k, tuple(v)) for k, v in parse_imports(imports or {}).items()
150155
)
@@ -155,6 +160,16 @@ def __init__(
155160
object.__setattr__(self, "deps", tuple(deps or []))
156161
object.__setattr__(self, "position", position or None)
157162

163+
if hooks and any(hooks.values()):
164+
merged_var_data = VarData.merge(self, *hooks.values())
165+
if merged_var_data is not None:
166+
object.__setattr__(self, "state", merged_var_data.state)
167+
object.__setattr__(self, "field_name", merged_var_data.field_name)
168+
object.__setattr__(self, "imports", merged_var_data.imports)
169+
object.__setattr__(self, "hooks", merged_var_data.hooks)
170+
object.__setattr__(self, "deps", merged_var_data.deps)
171+
object.__setattr__(self, "position", merged_var_data.position)
172+
158173
def old_school_imports(self) -> ImportDict:
159174
"""Return the imports as a mutable dict.
160175

tests/units/test_event.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44

55
import reflex as rx
6+
from reflex.constants.compiler import Hooks, Imports
67
from reflex.event import (
78
Event,
89
EventChain,
@@ -14,7 +15,7 @@
1415
)
1516
from reflex.state import BaseState
1617
from reflex.utils import format
17-
from reflex.vars.base import Field, LiteralVar, Var, field
18+
from reflex.vars.base import Field, LiteralVar, Var, VarData, field
1819

1920

2021
def make_var(value) -> Var:
@@ -443,9 +444,28 @@ def _args_spec(value: Var[int]) -> tuple[Var[int]]:
443444
return (value,)
444445

445446
# Ensure chain carries _var_data
446-
chain_var = Var.create(EventChain(events=[S.s(S.x)], args_spec=_args_spec))
447+
chain_var = Var.create(
448+
EventChain(
449+
events=[S.s(S.x)],
450+
args_spec=_args_spec,
451+
invocation=rx.vars.FunctionStringVar.create(""),
452+
)
453+
)
447454
assert chain_var._get_all_var_data() == S.x._get_all_var_data()
448455

456+
chain_var_data = Var.create(
457+
EventChain(
458+
events=[],
459+
args_spec=_args_spec,
460+
)
461+
)._get_all_var_data()
462+
assert chain_var_data is not None
463+
464+
assert chain_var_data == VarData(
465+
imports=Imports.EVENTS,
466+
hooks={Hooks.EVENTS: None},
467+
)
468+
449469

450470
def test_event_bound_method() -> None:
451471
class S(BaseState):

tests/units/test_var.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1862,3 +1862,19 @@ class TestState(BaseState):
18621862

18631863
single_var = Var.create(Email())
18641864
assert single_var._var_type == Email
1865+
1866+
1867+
def test_var_data_hooks():
1868+
var_data_str = VarData(hooks="what")
1869+
var_data_list = VarData(hooks=["what"])
1870+
var_data_dict = VarData(hooks={"what": None})
1871+
assert var_data_str == var_data_list == var_data_dict
1872+
1873+
var_data_list_multiple = VarData(hooks=["what", "whot"])
1874+
var_data_dict_multiple = VarData(hooks={"what": None, "whot": None})
1875+
assert var_data_list_multiple == var_data_dict_multiple
1876+
1877+
1878+
def test_var_data_with_hooks_value():
1879+
var_data = VarData(hooks={"what": VarData(hooks={"whot": VarData(hooks="whott")})})
1880+
assert var_data == VarData(hooks=["what", "whot", "whott"])

0 commit comments

Comments
 (0)