Skip to content

Commit 3a6f747

Browse files
authored
relax foreach to handle optional (#4901)
* relax foreach to handle optional * simplify get index
1 parent 1e07ec7 commit 3a6f747

File tree

8 files changed

+138
-60
lines changed

8 files changed

+138
-60
lines changed

reflex/components/core/cond.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from reflex.components.tags import CondTag, Tag
1010
from reflex.constants import Dirs
1111
from reflex.style import LIGHT_COLOR_MODE, resolved_color_mode
12+
from reflex.utils import types
1213
from reflex.utils.imports import ImportDict, ImportVar
1314
from reflex.vars import VarData
1415
from reflex.vars.base import LiteralVar, Var
@@ -145,20 +146,20 @@ def cond(condition: Any, c1: Any, c2: Any = None) -> Component | Var:
145146
if c2 is None:
146147
raise ValueError("For conditional vars, the second argument must be set.")
147148

148-
def create_var(cond_part: Any) -> Var[Any]:
149-
return LiteralVar.create(cond_part)
150-
151149
# convert the truth and false cond parts into vars so the _var_data can be obtained.
152-
c1 = create_var(c1)
153-
c2 = create_var(c2)
150+
c1_var = Var.create(c1)
151+
c2_var = Var.create(c2)
152+
153+
if condition is c1_var:
154+
c1_var = c1_var.to(types.value_inside_optional(c1_var._var_type))
154155

155156
# Create the conditional var.
156157
return ternary_operation(
157158
cond_var.bool()._replace(
158159
merge_var_data=VarData(imports=_IS_TRUE_IMPORT),
159160
),
160-
c1,
161-
c2,
161+
c1_var,
162+
c2_var,
162163
)
163164

164165

reflex/components/core/foreach.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88

99
from reflex.components.base.fragment import Fragment
1010
from reflex.components.component import Component
11+
from reflex.components.core.cond import cond
1112
from reflex.components.tags import IterTag
1213
from reflex.constants import MemoizationMode
1314
from reflex.state import ComponentState
15+
from reflex.utils import types
1416
from reflex.utils.exceptions import UntypedVarError
1517
from reflex.vars.base import LiteralVar, Var
1618

@@ -85,6 +87,9 @@ def create(
8587
"See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
8688
)
8789

90+
if types.is_optional(iterable._var_type):
91+
iterable = cond(iterable, iterable, [])
92+
8893
component = cls(
8994
iterable=iterable,
9095
render_fn=render_fn,
@@ -164,7 +169,6 @@ def render(self):
164169
iterable_state=str(tag.iterable),
165170
arg_name=tag.arg_var_name,
166171
arg_index=tag.get_index_var_arg(),
167-
iterable_type=tag.iterable._var_type.mro()[0].__name__,
168172
)
169173

170174

reflex/components/tags/iter_tag.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44

55
import dataclasses
66
import inspect
7-
from typing import TYPE_CHECKING, Any, Callable, Iterable, Type, Union, get_args
7+
from typing import TYPE_CHECKING, Callable, Iterable
88

99
from reflex.components.tags.tag import Tag
10+
from reflex.utils.types import GenericType
1011
from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name
12+
from reflex.vars.sequence import _determine_value_of_array_index
1113

1214
if TYPE_CHECKING:
1315
from reflex.components.component import Component
@@ -31,24 +33,13 @@ class IterTag(Tag):
3133
# The name of the index var.
3234
index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
3335

34-
def get_iterable_var_type(self) -> Type:
36+
def get_iterable_var_type(self) -> GenericType:
3537
"""Get the type of the iterable var.
3638
3739
Returns:
3840
The type of the iterable var.
3941
"""
40-
iterable = self.iterable
41-
try:
42-
if iterable._var_type.mro()[0] is dict:
43-
# Arg is a tuple of (key, value).
44-
return tuple[get_args(iterable._var_type)] # pyright: ignore [reportReturnType]
45-
elif iterable._var_type.mro()[0] is tuple:
46-
# Arg is a union of any possible values in the tuple.
47-
return Union[get_args(iterable._var_type)] # pyright: ignore [reportReturnType]
48-
else:
49-
return get_args(iterable._var_type)[0]
50-
except Exception:
51-
return Any # pyright: ignore [reportReturnType]
42+
return _determine_value_of_array_index(self.iterable._var_type)
5243

5344
def get_index_var(self) -> Var:
5445
"""Get the index var for the tag (with curly braces).

reflex/vars/base.py

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1598,7 +1598,14 @@ def var_operation( # pyright: ignore [reportOverlappingOverload]
15981598

15991599
@overload
16001600
def var_operation(
1601-
func: Callable[P, CustomVarOperationReturn[bool]],
1601+
func: Callable[P, CustomVarOperationReturn[None]],
1602+
) -> Callable[P, NoneVar]: ...
1603+
1604+
1605+
@overload
1606+
def var_operation( # pyright: ignore [reportOverlappingOverload]
1607+
func: Callable[P, CustomVarOperationReturn[bool]]
1608+
| Callable[P, CustomVarOperationReturn[bool | None]],
16021609
) -> Callable[P, BooleanVar]: ...
16031610

16041611

@@ -1607,13 +1614,15 @@ def var_operation(
16071614

16081615
@overload
16091616
def var_operation(
1610-
func: Callable[P, CustomVarOperationReturn[NUMBER_T]],
1617+
func: Callable[P, CustomVarOperationReturn[NUMBER_T]]
1618+
| Callable[P, CustomVarOperationReturn[NUMBER_T | None]],
16111619
) -> Callable[P, NumberVar[NUMBER_T]]: ...
16121620

16131621

16141622
@overload
16151623
def var_operation(
1616-
func: Callable[P, CustomVarOperationReturn[str]],
1624+
func: Callable[P, CustomVarOperationReturn[str]]
1625+
| Callable[P, CustomVarOperationReturn[str | None]],
16171626
) -> Callable[P, StringVar]: ...
16181627

16191628

@@ -1622,7 +1631,8 @@ def var_operation(
16221631

16231632
@overload
16241633
def var_operation(
1625-
func: Callable[P, CustomVarOperationReturn[LIST_T]],
1634+
func: Callable[P, CustomVarOperationReturn[LIST_T]]
1635+
| Callable[P, CustomVarOperationReturn[LIST_T | None]],
16261636
) -> Callable[P, ArrayVar[LIST_T]]: ...
16271637

16281638

@@ -1631,13 +1641,15 @@ def var_operation(
16311641

16321642
@overload
16331643
def var_operation(
1634-
func: Callable[P, CustomVarOperationReturn[OBJECT_TYPE]],
1644+
func: Callable[P, CustomVarOperationReturn[OBJECT_TYPE]]
1645+
| Callable[P, CustomVarOperationReturn[OBJECT_TYPE | None]],
16351646
) -> Callable[P, ObjectVar[OBJECT_TYPE]]: ...
16361647

16371648

16381649
@overload
16391650
def var_operation(
1640-
func: Callable[P, CustomVarOperationReturn[T]],
1651+
func: Callable[P, CustomVarOperationReturn[T]]
1652+
| Callable[P, CustomVarOperationReturn[T | None]],
16411653
) -> Callable[P, Var[T]]: ...
16421654

16431655

@@ -3278,53 +3290,71 @@ def __set__(self, instance: Any, value: FIELD_TYPE):
32783290
"""
32793291

32803292
@overload
3281-
def __get__(self: Field[bool], instance: None, owner: Any) -> BooleanVar: ...
3293+
def __get__(self: Field[None], instance: None, owner: Any) -> NoneVar: ...
32823294

32833295
@overload
32843296
def __get__(
3285-
self: Field[int] | Field[float] | Field[int | float], instance: None, owner: Any
3286-
) -> NumberVar: ...
3297+
self: Field[bool] | Field[bool | None], instance: None, owner: Any
3298+
) -> BooleanVar: ...
32873299

32883300
@overload
3289-
def __get__(self: Field[str], instance: None, owner: Any) -> StringVar: ...
3301+
def __get__(
3302+
self: Field[int]
3303+
| Field[float]
3304+
| Field[int | float]
3305+
| Field[int | None]
3306+
| Field[float | None]
3307+
| Field[int | float | None],
3308+
instance: None,
3309+
owner: Any,
3310+
) -> NumberVar: ...
32903311

32913312
@overload
3292-
def __get__(self: Field[None], instance: None, owner: Any) -> NoneVar: ...
3313+
def __get__(
3314+
self: Field[str] | Field[str | None], instance: None, owner: Any
3315+
) -> StringVar: ...
32933316

32943317
@overload
32953318
def __get__(
3296-
self: Field[list[V]] | Field[set[V]],
3319+
self: Field[list[V]]
3320+
| Field[set[V]]
3321+
| Field[list[V] | None]
3322+
| Field[set[V] | None],
32973323
instance: None,
32983324
owner: Any,
32993325
) -> ArrayVar[Sequence[V]]: ...
33003326

33013327
@overload
33023328
def __get__(
3303-
self: Field[SEQUENCE_TYPE],
3329+
self: Field[SEQUENCE_TYPE] | Field[SEQUENCE_TYPE | None],
33043330
instance: None,
33053331
owner: Any,
33063332
) -> ArrayVar[SEQUENCE_TYPE]: ...
33073333

33083334
@overload
33093335
def __get__(
3310-
self: Field[MAPPING_TYPE], instance: None, owner: Any
3336+
self: Field[MAPPING_TYPE] | Field[MAPPING_TYPE | None],
3337+
instance: None,
3338+
owner: Any,
33113339
) -> ObjectVar[MAPPING_TYPE]: ...
33123340

33133341
@overload
33143342
def __get__(
3315-
self: Field[BASE_TYPE], instance: None, owner: Any
3343+
self: Field[BASE_TYPE] | Field[BASE_TYPE | None], instance: None, owner: Any
33163344
) -> ObjectVar[BASE_TYPE]: ...
33173345

33183346
@overload
33193347
def __get__(
3320-
self: Field[SQLA_TYPE], instance: None, owner: Any
3348+
self: Field[SQLA_TYPE] | Field[SQLA_TYPE | None], instance: None, owner: Any
33213349
) -> ObjectVar[SQLA_TYPE]: ...
33223350

33233351
if TYPE_CHECKING:
33243352

33253353
@overload
33263354
def __get__(
3327-
self: Field[DATACLASS_TYPE], instance: None, owner: Any
3355+
self: Field[DATACLASS_TYPE] | Field[DATACLASS_TYPE | None],
3356+
instance: None,
3357+
owner: Any,
33283358
) -> ObjectVar[DATACLASS_TYPE]: ...
33293359

33303360
@overload

reflex/vars/object.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -441,9 +441,14 @@ def object_keys_operation(value: ObjectVar):
441441
Returns:
442442
The keys of the object.
443443
"""
444+
if not types.is_optional(value._var_type):
445+
return var_operation_return(
446+
js_expression=f"Object.keys({value})",
447+
var_type=list[str],
448+
)
444449
return var_operation_return(
445-
js_expression=f"Object.keys({value})",
446-
var_type=list[str],
450+
js_expression=f"((value) => value ?? undefined === undefined ? undefined : Object.keys(value))({value})",
451+
var_type=(list[str] | None),
447452
)
448453

449454

@@ -457,9 +462,14 @@ def object_values_operation(value: ObjectVar):
457462
Returns:
458463
The values of the object.
459464
"""
465+
if not types.is_optional(value._var_type):
466+
return var_operation_return(
467+
js_expression=f"Object.values({value})",
468+
var_type=list[value._value_type()],
469+
)
460470
return var_operation_return(
461-
js_expression=f"Object.values({value})",
462-
var_type=list[value._value_type()],
471+
js_expression=f"((value) => value ?? undefined === undefined ? undefined : Object.values(value))({value})",
472+
var_type=(list[value._value_type()] | None),
463473
)
464474

465475

@@ -473,9 +483,14 @@ def object_entries_operation(value: ObjectVar):
473483
Returns:
474484
The entries of the object.
475485
"""
486+
if not types.is_optional(value._var_type):
487+
return var_operation_return(
488+
js_expression=f"Object.entries({value})",
489+
var_type=list[tuple[str, value._value_type()]],
490+
)
476491
return var_operation_return(
477-
js_expression=f"Object.entries({value})",
478-
var_type=list[tuple[str, value._value_type()]],
492+
js_expression=f"((value) => value ?? undefined === undefined ? undefined : Object.entries(value))({value})",
493+
var_type=(list[tuple[str, value._value_type()]] | None),
479494
)
480495

481496

tests/integration/test_var_operations.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ class VarOperationState(rx.State):
3333
list2: rx.Field[list] = rx.field([3, 4])
3434
list3: rx.Field[list] = rx.field(["first", "second", "third"])
3535
list4: rx.Field[list] = rx.field([Object(name="obj_1"), Object(name="obj_2")])
36+
optional_list: rx.Field[list | None] = rx.field(None)
37+
optional_dict: rx.Field[dict[str, str] | None] = rx.field(None)
38+
optional_list_value: rx.Field[list[str] | None] = rx.field(["red", "yellow"])
39+
optional_dict_value: rx.Field[dict[str, str] | None] = rx.field({"name": "red"})
3640
str_var1: rx.Field[str] = rx.field("first")
3741
str_var2: rx.Field[str] = rx.field("second")
3842
str_var3: rx.Field[str] = rx.field("ThIrD")
@@ -645,6 +649,22 @@ def index():
645649
),
646650
id="typed_dict_in_foreach",
647651
),
652+
rx.box(
653+
rx.foreach(VarOperationState.optional_list, rx.text.span),
654+
id="optional_list",
655+
),
656+
rx.box(
657+
rx.foreach(VarOperationState.optional_dict, rx.text.span),
658+
id="optional_dict",
659+
),
660+
rx.box(
661+
rx.foreach(VarOperationState.optional_list_value, rx.text.span),
662+
id="optional_list_value",
663+
),
664+
rx.box(
665+
rx.foreach(VarOperationState.optional_dict_value, rx.text.span),
666+
id="optional_dict_value",
667+
),
648668
)
649669

650670

0 commit comments

Comments
 (0)