1818from dataclasses import _MISSING_TYPE , MISSING
1919from decimal import Decimal
2020from types import CodeType , FunctionType
21- from typing import ( # noqa: UP035
21+ from typing import (
2222 TYPE_CHECKING ,
2323 Annotated ,
2424 Any ,
2525 ClassVar ,
26- Dict ,
27- FrozenSet ,
2826 Generic ,
29- List ,
3027 Literal ,
3128 NoReturn ,
3229 ParamSpec ,
3330 Protocol ,
34- Set ,
35- Tuple ,
3631 TypeGuard ,
3732 TypeVar ,
3833 cast ,
7873
7974if TYPE_CHECKING :
8075 from reflex .components .component import BaseComponent
76+ from reflex .constants .colors import Color
8177 from reflex .state import BaseState
8278
79+ from .color import LiteralColorVar
8380 from .number import BooleanVar , LiteralBooleanVar , LiteralNumberVar , NumberVar
8481 from .object import LiteralObjectVar , ObjectVar
8582 from .sequence import ArrayVar , LiteralArrayVar , LiteralStringVar , StringVar
@@ -642,6 +639,14 @@ def create(
642639 _var_data : VarData | None = None ,
643640 ) -> LiteralNumberVar [Decimal ]: ...
644641
642+ @overload
643+ @classmethod
644+ def create ( # pyright: ignore [reportOverlappingOverload]
645+ cls ,
646+ value : Color ,
647+ _var_data : VarData | None = None ,
648+ ) -> LiteralColorVar : ...
649+
645650 @overload
646651 @classmethod
647652 def create ( # pyright: ignore [reportOverlappingOverload]
@@ -3080,7 +3085,8 @@ def transform(fn: Callable[[Var], Var]) -> Callable[[Var], Var]:
30803085 TypeError: If the Var return type does not have a generic type.
30813086 ValueError: If a function for the generic type is already registered.
30823087 """
3083- return_type = fn .__annotations__ ["return" ]
3088+ types = get_type_hints (fn )
3089+ return_type = types ["return" ]
30843090
30853091 origin = get_origin (return_type )
30863092
@@ -3105,101 +3111,6 @@ def transform(fn: Callable[[Var], Var]) -> Callable[[Var], Var]:
31053111 return fn
31063112
31073113
3108- def generic_type_to_actual_type_map (
3109- generic_type : GenericType , actual_type : GenericType
3110- ) -> dict [TypeVar , GenericType ]:
3111- """Map the generic type to the actual type.
3112-
3113- Args:
3114- generic_type: The generic type.
3115- actual_type: The actual type.
3116-
3117- Returns:
3118- The mapping of type variables to actual types.
3119-
3120- Raises:
3121- TypeError: If the generic type and actual type do not match.
3122- TypeError: If the number of generic arguments and actual arguments do not match.
3123- """
3124- generic_origin = get_origin (generic_type ) or generic_type
3125- actual_origin = get_origin (actual_type ) or actual_type
3126-
3127- if generic_origin is not actual_origin :
3128- if isinstance (generic_origin , TypeVar ):
3129- return {generic_origin : actual_origin }
3130- msg = f"Type mismatch: expected { generic_origin } , got { actual_origin } ."
3131- raise TypeError (msg )
3132-
3133- generic_args = get_args (generic_type )
3134- actual_args = get_args (actual_type )
3135-
3136- if len (generic_args ) != len (actual_args ):
3137- msg = f"Number of generic arguments mismatch: expected { len (generic_args )} , got { len (actual_args )} ."
3138- raise TypeError (msg )
3139-
3140- # call recursively for nested generic types and merge the results
3141- return {
3142- k : v
3143- for generic_arg , actual_arg in zip (generic_args , actual_args , strict = True )
3144- for k , v in generic_type_to_actual_type_map (generic_arg , actual_arg ).items ()
3145- }
3146-
3147-
3148- def resolve_generic_type_with_mapping (
3149- generic_type : GenericType , type_mapping : dict [TypeVar , GenericType ]
3150- ):
3151- """Resolve a generic type with a type mapping.
3152-
3153- Args:
3154- generic_type: The generic type.
3155- type_mapping: The type mapping.
3156-
3157- Returns:
3158- The resolved generic type.
3159- """
3160- if isinstance (generic_type , TypeVar ):
3161- return type_mapping .get (generic_type , generic_type )
3162-
3163- generic_origin = get_origin (generic_type ) or generic_type
3164-
3165- generic_args = get_args (generic_type )
3166-
3167- if not generic_args :
3168- return generic_type
3169-
3170- mapping_for_older_python = {
3171- list : List , # noqa: UP006
3172- set : Set , # noqa: UP006
3173- dict : Dict , # noqa: UP006
3174- tuple : Tuple , # noqa: UP006
3175- frozenset : FrozenSet , # noqa: UP006
3176- }
3177-
3178- return mapping_for_older_python .get (generic_origin , generic_origin )[
3179- tuple (
3180- resolve_generic_type_with_mapping (arg , type_mapping ) for arg in generic_args
3181- )
3182- ]
3183-
3184-
3185- def resolve_arg_type_from_return_type (
3186- arg_type : GenericType , return_type : GenericType , actual_return_type : GenericType
3187- ) -> GenericType :
3188- """Resolve the argument type from the return type.
3189-
3190- Args:
3191- arg_type: The argument type.
3192- return_type: The return type.
3193- actual_return_type: The requested return type.
3194-
3195- Returns:
3196- The argument type without the generics that are resolved.
3197- """
3198- return resolve_generic_type_with_mapping (
3199- arg_type , generic_type_to_actual_type_map (return_type , actual_return_type )
3200- )
3201-
3202-
32033114def dispatch (
32043115 field_name : str ,
32053116 var_data : VarData ,
@@ -3227,11 +3138,12 @@ def dispatch(
32273138
32283139 if result_origin_var_type in dispatchers :
32293140 fn = dispatchers [result_origin_var_type ]
3230- fn_first_arg_type = next (
3231- iter (inspect .signature (fn ).parameters .values ())
3232- ).annotation
3141+ fn_types = get_type_hints (fn )
3142+ fn_first_arg_type = fn_types .get (
3143+ next (iter (inspect .signature (fn ).parameters .values ())).name , Any
3144+ )
32333145
3234- fn_return = inspect . signature ( fn ). return_annotation
3146+ fn_return = fn_types . get ( "return" , Any )
32353147
32363148 fn_return_origin = get_origin (fn_return ) or fn_return
32373149
@@ -3257,22 +3169,17 @@ def dispatch(
32573169 msg = f"Expected generic type of { fn_first_arg_type } to be a type."
32583170 raise TypeError (msg )
32593171
3260- arg_type = arg_generic_args [0 ]
32613172 fn_return_type = fn_return_generic_args [0 ]
32623173
32633174 var = (
32643175 Var (
32653176 field_name ,
32663177 _var_data = var_data ,
3267- _var_type = resolve_arg_type_from_return_type (
3268- arg_type , fn_return_type , result_var_type
3269- ),
3178+ _var_type = fn_return_type ,
32703179 ).guess_type ()
32713180 if existing_var is None
32723181 else existing_var ._replace (
3273- _var_type = resolve_arg_type_from_return_type (
3274- arg_type , fn_return_type , result_var_type
3275- ),
3182+ _var_type = fn_return_type ,
32763183 _var_data = var_data ,
32773184 _js_expr = field_name ,
32783185 ).guess_type ()
0 commit comments