Skip to content

Commit bea266b

Browse files
authored
make object var handle all mapping instead of just dict (#4602)
* make object var handle all mapping instead of just dict * unbreak ci * get it right pyright * create generic variable for field * add support for typeddict (to some degree) * import from extensions
1 parent abaaa22 commit bea266b

File tree

8 files changed

+100
-64
lines changed

8 files changed

+100
-64
lines changed

reflex/utils/types.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,22 @@ def wrapper(*args, **kwargs):
829829
StateIterBases = get_base_class(StateIterVar)
830830

831831

832+
def safe_issubclass(cls: Type, cls_check: Type | Tuple[Type, ...]):
833+
"""Check if a class is a subclass of another class. Returns False if internal error occurs.
834+
835+
Args:
836+
cls: The class to check.
837+
cls_check: The class to check against.
838+
839+
Returns:
840+
Whether the class is a subclass of the other class.
841+
"""
842+
try:
843+
return issubclass(cls, cls_check)
844+
except TypeError:
845+
return False
846+
847+
832848
def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool:
833849
"""Check if a type hint is a subclass of another type hint.
834850

reflex/vars/base.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Iterable,
2727
List,
2828
Literal,
29+
Mapping,
2930
NoReturn,
3031
Optional,
3132
Set,
@@ -64,6 +65,7 @@
6465
_isinstance,
6566
get_origin,
6667
has_args,
68+
safe_issubclass,
6769
unionize,
6870
)
6971

@@ -127,7 +129,7 @@ def __init__(
127129
state: str = "",
128130
field_name: str = "",
129131
imports: ImportDict | ParsedImportDict | None = None,
130-
hooks: dict[str, VarData | None] | None = None,
132+
hooks: Mapping[str, VarData | None] | None = None,
131133
deps: list[Var] | None = None,
132134
position: Hooks.HookPosition | None = None,
133135
):
@@ -643,8 +645,8 @@ def to(
643645
@overload
644646
def to(
645647
self,
646-
output: type[dict],
647-
) -> ObjectVar[dict]: ...
648+
output: type[Mapping],
649+
) -> ObjectVar[Mapping]: ...
648650

649651
@overload
650652
def to(
@@ -686,7 +688,9 @@ def to(
686688

687689
# If the first argument is a python type, we map it to the corresponding Var type.
688690
for var_subclass in _var_subclasses[::-1]:
689-
if fixed_output_type in var_subclass.python_types:
691+
if fixed_output_type in var_subclass.python_types or safe_issubclass(
692+
fixed_output_type, var_subclass.python_types
693+
):
690694
return self.to(var_subclass.var_subclass, output)
691695

692696
if fixed_output_type is None:
@@ -820,7 +824,7 @@ def _get_default_value(self) -> Any:
820824
return False
821825
if issubclass(type_, list):
822826
return []
823-
if issubclass(type_, dict):
827+
if issubclass(type_, Mapping):
824828
return {}
825829
if issubclass(type_, tuple):
826830
return ()
@@ -1026,7 +1030,7 @@ def _as_ref(self) -> Var:
10261030
f"$/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")]
10271031
}
10281032
),
1029-
).to(ObjectVar, Dict[str, str])
1033+
).to(ObjectVar, Mapping[str, str])
10301034
return refs[LiteralVar.create(str(self))]
10311035

10321036
@deprecated("Use `.js_type()` instead.")
@@ -1373,7 +1377,7 @@ def create(
13731377

13741378
serialized_value = serializers.serialize(value)
13751379
if serialized_value is not None:
1376-
if isinstance(serialized_value, dict):
1380+
if isinstance(serialized_value, Mapping):
13771381
return LiteralObjectVar.create(
13781382
serialized_value,
13791383
_var_type=type(value),
@@ -1498,7 +1502,7 @@ def var_operation(
14981502
) -> Callable[P, ArrayVar[LIST_T]]: ...
14991503

15001504

1501-
OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict)
1505+
OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Mapping)
15021506

15031507

15041508
@overload
@@ -1573,8 +1577,8 @@ def figure_out_type(value: Any) -> types.GenericType:
15731577
return Set[unionize(*(figure_out_type(v) for v in value))]
15741578
if isinstance(value, tuple):
15751579
return Tuple[unionize(*(figure_out_type(v) for v in value)), ...]
1576-
if isinstance(value, dict):
1577-
return Dict[
1580+
if isinstance(value, Mapping):
1581+
return Mapping[
15781582
unionize(*(figure_out_type(k) for k in value)),
15791583
unionize(*(figure_out_type(v) for v in value.values())),
15801584
]
@@ -2002,10 +2006,10 @@ def __get__(
20022006

20032007
@overload
20042008
def __get__(
2005-
self: ComputedVar[dict[DICT_KEY, DICT_VAL]],
2009+
self: ComputedVar[Mapping[DICT_KEY, DICT_VAL]],
20062010
instance: None,
20072011
owner: Type,
2008-
) -> ObjectVar[dict[DICT_KEY, DICT_VAL]]: ...
2012+
) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ...
20092013

20102014
@overload
20112015
def __get__(
@@ -2915,11 +2919,14 @@ def dispatch(
29152919

29162920
BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)
29172921

2922+
FIELD_TYPE = TypeVar("FIELD_TYPE")
2923+
MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping)
2924+
29182925

2919-
class Field(Generic[T]):
2926+
class Field(Generic[FIELD_TYPE]):
29202927
"""Shadow class for Var to allow for type hinting in the IDE."""
29212928

2922-
def __set__(self, instance, value: T):
2929+
def __set__(self, instance, value: FIELD_TYPE):
29232930
"""Set the Var.
29242931
29252932
Args:
@@ -2931,7 +2938,9 @@ def __set__(self, instance, value: T):
29312938
def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ...
29322939

29332940
@overload
2934-
def __get__(self: Field[int], instance: None, owner) -> NumberVar: ...
2941+
def __get__(
2942+
self: Field[int] | Field[float] | Field[int | float], instance: None, owner
2943+
) -> NumberVar: ...
29352944

29362945
@overload
29372946
def __get__(self: Field[str], instance: None, owner) -> StringVar: ...
@@ -2948,19 +2957,19 @@ def __get__(
29482957

29492958
@overload
29502959
def __get__(
2951-
self: Field[Dict[str, V]], instance: None, owner
2952-
) -> ObjectVar[Dict[str, V]]: ...
2960+
self: Field[MAPPING_TYPE], instance: None, owner
2961+
) -> ObjectVar[MAPPING_TYPE]: ...
29532962

29542963
@overload
29552964
def __get__(
29562965
self: Field[BASE_TYPE], instance: None, owner
29572966
) -> ObjectVar[BASE_TYPE]: ...
29582967

29592968
@overload
2960-
def __get__(self, instance: None, owner) -> Var[T]: ...
2969+
def __get__(self, instance: None, owner) -> Var[FIELD_TYPE]: ...
29612970

29622971
@overload
2963-
def __get__(self, instance, owner) -> T: ...
2972+
def __get__(self, instance, owner) -> FIELD_TYPE: ...
29642973

29652974
def __get__(self, instance, owner): # type: ignore
29662975
"""Get the Var.
@@ -2971,7 +2980,7 @@ def __get__(self, instance, owner): # type: ignore
29712980
"""
29722981

29732982

2974-
def field(value: T) -> Field[T]:
2983+
def field(value: FIELD_TYPE) -> Field[FIELD_TYPE]:
29752984
"""Create a Field with a value.
29762985
29772986
Args:

0 commit comments

Comments
 (0)