1010import sys
1111import threading
1212from textwrap import dedent
13- from typing import Any , AsyncGenerator , Callable , Dict , List , Optional , Union
13+ from typing import (
14+ Any ,
15+ AsyncGenerator ,
16+ Callable ,
17+ Dict ,
18+ List ,
19+ Optional ,
20+ Set ,
21+ Tuple ,
22+ Union ,
23+ )
1424from unittest .mock import AsyncMock , Mock
1525
1626import pytest
@@ -1828,12 +1838,11 @@ async def _coro_waiter():
18281838
18291839
18301840@pytest .fixture (scope = "function" )
1831- def mock_app (monkeypatch , state_manager : StateManager ) -> rx .App :
1832- """Mock app fixture.
1841+ def mock_app_simple (monkeypatch ) -> rx .App :
1842+ """Simple Mock app fixture.
18331843
18341844 Args:
18351845 monkeypatch: Pytest monkeypatch object.
1836- state_manager: A state manager.
18371846
18381847 Returns:
18391848 The app, after mocking out prerequisites.get_app()
@@ -1844,7 +1853,6 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
18441853
18451854 setattr (app_module , CompileVars .APP , app )
18461855 app .state = TestState
1847- app ._state_manager = state_manager
18481856 app .event_namespace .emit = AsyncMock () # type: ignore
18491857
18501858 def _mock_get_app (* args , ** kwargs ):
@@ -1854,6 +1862,21 @@ def _mock_get_app(*args, **kwargs):
18541862 return app
18551863
18561864
1865+ @pytest .fixture (scope = "function" )
1866+ def mock_app (mock_app_simple : rx .App , state_manager : StateManager ) -> rx .App :
1867+ """Mock app fixture.
1868+
1869+ Args:
1870+ mock_app_simple: A simple mock app.
1871+ state_manager: A state manager.
1872+
1873+ Returns:
1874+ The app, after mocking out prerequisites.get_app()
1875+ """
1876+ mock_app_simple ._state_manager = state_manager
1877+ return mock_app_simple
1878+
1879+
18571880@pytest .mark .asyncio
18581881async def test_state_proxy (grandchild_state : GrandchildState , mock_app : rx .App ):
18591882 """Test that the state proxy works.
@@ -3506,3 +3529,106 @@ class SubMixin(Mixin, mixin=True):
35063529
35073530 with pytest .raises (ReflexRuntimeError ):
35083531 SubMixin ()
3532+
3533+
3534+ class ReflexModel (rx .Model ):
3535+ """A model for testing."""
3536+
3537+ foo : str
3538+
3539+
3540+ class UpcastState (rx .State ):
3541+ """A state for testing upcasting."""
3542+
3543+ passed : bool = False
3544+
3545+ def rx_model (self , m : ReflexModel ): # noqa: D102
3546+ assert isinstance (m , ReflexModel )
3547+ self .passed = True
3548+
3549+ def rx_base (self , o : Object ): # noqa: D102
3550+ assert isinstance (o , Object )
3551+ self .passed = True
3552+
3553+ def rx_base_or_none (self , o : Optional [Object ]): # noqa: D102
3554+ if o is not None :
3555+ assert isinstance (o , Object )
3556+ self .passed = True
3557+
3558+ def rx_basemodelv1 (self , m : ModelV1 ): # noqa: D102
3559+ assert isinstance (m , ModelV1 )
3560+ self .passed = True
3561+
3562+ def rx_basemodelv2 (self , m : ModelV2 ): # noqa: D102
3563+ assert isinstance (m , ModelV2 )
3564+ self .passed = True
3565+
3566+ def rx_dataclass (self , dc : ModelDC ): # noqa: D102
3567+ assert isinstance (dc , ModelDC )
3568+ self .passed = True
3569+
3570+ def py_set (self , s : set ): # noqa: D102
3571+ assert isinstance (s , set )
3572+ self .passed = True
3573+
3574+ def py_Set (self , s : Set ): # noqa: D102
3575+ assert isinstance (s , Set )
3576+ self .passed = True
3577+
3578+ def py_tuple (self , t : tuple ): # noqa: D102
3579+ assert isinstance (t , tuple )
3580+ self .passed = True
3581+
3582+ def py_Tuple (self , t : Tuple ): # noqa: D102
3583+ assert isinstance (t , tuple )
3584+ self .passed = True
3585+
3586+ def py_dict (self , d : dict [str , str ]): # noqa: D102
3587+ assert isinstance (d , dict )
3588+ self .passed = True
3589+
3590+ def py_list (self , ls : list [str ]): # noqa: D102
3591+ assert isinstance (ls , list )
3592+ self .passed = True
3593+
3594+ def py_Any (self , a : Any ): # noqa: D102
3595+ assert isinstance (a , list )
3596+ self .passed = True
3597+
3598+ def py_unresolvable (self , u : "Unresolvable" ): # noqa: D102, F821 # type: ignore
3599+ assert isinstance (u , list )
3600+ self .passed = True
3601+
3602+
3603+ @pytest .mark .asyncio
3604+ @pytest .mark .usefixtures ("mock_app_simple" )
3605+ @pytest .mark .parametrize (
3606+ ("handler" , "payload" ),
3607+ [
3608+ (UpcastState .rx_model , {"m" : {"foo" : "bar" }}),
3609+ (UpcastState .rx_base , {"o" : {"foo" : "bar" }}),
3610+ (UpcastState .rx_base_or_none , {"o" : {"foo" : "bar" }}),
3611+ (UpcastState .rx_base_or_none , {"o" : None }),
3612+ (UpcastState .rx_basemodelv1 , {"m" : {"foo" : "bar" }}),
3613+ (UpcastState .rx_basemodelv2 , {"m" : {"foo" : "bar" }}),
3614+ (UpcastState .rx_dataclass , {"dc" : {"foo" : "bar" }}),
3615+ (UpcastState .py_set , {"s" : ["foo" , "foo" ]}),
3616+ (UpcastState .py_Set , {"s" : ["foo" , "foo" ]}),
3617+ (UpcastState .py_tuple , {"t" : ["foo" , "foo" ]}),
3618+ (UpcastState .py_Tuple , {"t" : ["foo" , "foo" ]}),
3619+ (UpcastState .py_dict , {"d" : {"foo" : "bar" }}),
3620+ (UpcastState .py_list , {"ls" : ["foo" , "foo" ]}),
3621+ (UpcastState .py_Any , {"a" : ["foo" ]}),
3622+ (UpcastState .py_unresolvable , {"u" : ["foo" ]}),
3623+ ],
3624+ )
3625+ async def test_upcast_event_handler_arg (handler , payload ):
3626+ """Test that upcast event handler args work correctly.
3627+
3628+ Args:
3629+ handler: The handler to test.
3630+ payload: The payload to test.
3631+ """
3632+ state = UpcastState ()
3633+ async for update in state ._process_event (handler , state , payload ):
3634+ assert update .delta == {UpcastState .get_full_name (): {"passed" : True }}
0 commit comments