Skip to content

Commit 9ebf16c

Browse files
committed
[ENG-4137] Handle generic alias passing inspect.isclass check (#4427)
On py3.9 and py3.10, `dict[str, str]` and other typing forms are kinda considered classes, but they still fail when doing `issubclass`, so specifically exclude generic aliases before calling issubclass. Fix #4424 Bonus fix: support upcasting of pydantic v1 and v2 models
1 parent 6d0fae3 commit 9ebf16c

File tree

2 files changed

+137
-7
lines changed

2 files changed

+137
-7
lines changed

reflex/state.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1748,7 +1748,11 @@ async def _process_event(
17481748
if value is None:
17491749
continue
17501750
hinted_args = value_inside_optional(hinted_args)
1751-
if isinstance(value, dict) and inspect.isclass(hinted_args):
1751+
if (
1752+
isinstance(value, dict)
1753+
and inspect.isclass(hinted_args)
1754+
and not types.is_generic_alias(hinted_args) # py3.9-py3.10
1755+
):
17521756
if issubclass(hinted_args, Model):
17531757
# Remove non-fields from the payload
17541758
payload[arg] = hinted_args(
@@ -1759,7 +1763,7 @@ async def _process_event(
17591763
}
17601764
)
17611765
elif dataclasses.is_dataclass(hinted_args) or issubclass(
1762-
hinted_args, Base
1766+
hinted_args, (Base, BaseModelV1, BaseModelV2)
17631767
):
17641768
payload[arg] = hinted_args(**value)
17651769
if isinstance(value, list) and (hinted_args is set or hinted_args is Set):

tests/units/test_state.py

Lines changed: 131 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,17 @@
1010
import sys
1111
import threading
1212
from textwrap import dedent
13-
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union
13+
from typing import (
14+
Any,
15+
AsyncGenerator,
16+
Callable,
17+
Dict,
18+
List,
19+
Optional,
20+
Set,
21+
Tuple,
22+
Union,
23+
)
1424
from unittest.mock import AsyncMock, Mock
1525

1626
import pytest
@@ -1828,12 +1838,11 @@ async def _coro_waiter():
18281838

18291839

18301840
@pytest.fixture(scope="function")
1831-
def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
1832-
"""Mock app fixture.
1841+
def mock_app_simple(monkeypatch) -> rx.App:
1842+
"""Simple Mock app fixture.
18331843
18341844
Args:
18351845
monkeypatch: Pytest monkeypatch object.
1836-
state_manager: A state manager.
18371846
18381847
Returns:
18391848
The app, after mocking out prerequisites.get_app()
@@ -1844,7 +1853,6 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
18441853

18451854
setattr(app_module, CompileVars.APP, app)
18461855
app.state = TestState
1847-
app._state_manager = state_manager
18481856
app.event_namespace.emit = AsyncMock() # type: ignore
18491857

18501858
def _mock_get_app(*args, **kwargs):
@@ -1854,6 +1862,21 @@ def _mock_get_app(*args, **kwargs):
18541862
return app
18551863

18561864

1865+
@pytest.fixture(scope="function")
1866+
def mock_app(mock_app_simple: rx.App, state_manager: StateManager) -> rx.App:
1867+
"""Mock app fixture.
1868+
1869+
Args:
1870+
mock_app_simple: A simple mock app.
1871+
state_manager: A state manager.
1872+
1873+
Returns:
1874+
The app, after mocking out prerequisites.get_app()
1875+
"""
1876+
mock_app_simple._state_manager = state_manager
1877+
return mock_app_simple
1878+
1879+
18571880
@pytest.mark.asyncio
18581881
async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
18591882
"""Test that the state proxy works.
@@ -3506,3 +3529,106 @@ class SubMixin(Mixin, mixin=True):
35063529

35073530
with pytest.raises(ReflexRuntimeError):
35083531
SubMixin()
3532+
3533+
3534+
class ReflexModel(rx.Model):
3535+
"""A model for testing."""
3536+
3537+
foo: str
3538+
3539+
3540+
class UpcastState(rx.State):
3541+
"""A state for testing upcasting."""
3542+
3543+
passed: bool = False
3544+
3545+
def rx_model(self, m: ReflexModel): # noqa: D102
3546+
assert isinstance(m, ReflexModel)
3547+
self.passed = True
3548+
3549+
def rx_base(self, o: Object): # noqa: D102
3550+
assert isinstance(o, Object)
3551+
self.passed = True
3552+
3553+
def rx_base_or_none(self, o: Optional[Object]): # noqa: D102
3554+
if o is not None:
3555+
assert isinstance(o, Object)
3556+
self.passed = True
3557+
3558+
def rx_basemodelv1(self, m: ModelV1): # noqa: D102
3559+
assert isinstance(m, ModelV1)
3560+
self.passed = True
3561+
3562+
def rx_basemodelv2(self, m: ModelV2): # noqa: D102
3563+
assert isinstance(m, ModelV2)
3564+
self.passed = True
3565+
3566+
def rx_dataclass(self, dc: ModelDC): # noqa: D102
3567+
assert isinstance(dc, ModelDC)
3568+
self.passed = True
3569+
3570+
def py_set(self, s: set): # noqa: D102
3571+
assert isinstance(s, set)
3572+
self.passed = True
3573+
3574+
def py_Set(self, s: Set): # noqa: D102
3575+
assert isinstance(s, Set)
3576+
self.passed = True
3577+
3578+
def py_tuple(self, t: tuple): # noqa: D102
3579+
assert isinstance(t, tuple)
3580+
self.passed = True
3581+
3582+
def py_Tuple(self, t: Tuple): # noqa: D102
3583+
assert isinstance(t, tuple)
3584+
self.passed = True
3585+
3586+
def py_dict(self, d: dict[str, str]): # noqa: D102
3587+
assert isinstance(d, dict)
3588+
self.passed = True
3589+
3590+
def py_list(self, ls: list[str]): # noqa: D102
3591+
assert isinstance(ls, list)
3592+
self.passed = True
3593+
3594+
def py_Any(self, a: Any): # noqa: D102
3595+
assert isinstance(a, list)
3596+
self.passed = True
3597+
3598+
def py_unresolvable(self, u: "Unresolvable"): # noqa: D102, F821 # type: ignore
3599+
assert isinstance(u, list)
3600+
self.passed = True
3601+
3602+
3603+
@pytest.mark.asyncio
3604+
@pytest.mark.usefixtures("mock_app_simple")
3605+
@pytest.mark.parametrize(
3606+
("handler", "payload"),
3607+
[
3608+
(UpcastState.rx_model, {"m": {"foo": "bar"}}),
3609+
(UpcastState.rx_base, {"o": {"foo": "bar"}}),
3610+
(UpcastState.rx_base_or_none, {"o": {"foo": "bar"}}),
3611+
(UpcastState.rx_base_or_none, {"o": None}),
3612+
(UpcastState.rx_basemodelv1, {"m": {"foo": "bar"}}),
3613+
(UpcastState.rx_basemodelv2, {"m": {"foo": "bar"}}),
3614+
(UpcastState.rx_dataclass, {"dc": {"foo": "bar"}}),
3615+
(UpcastState.py_set, {"s": ["foo", "foo"]}),
3616+
(UpcastState.py_Set, {"s": ["foo", "foo"]}),
3617+
(UpcastState.py_tuple, {"t": ["foo", "foo"]}),
3618+
(UpcastState.py_Tuple, {"t": ["foo", "foo"]}),
3619+
(UpcastState.py_dict, {"d": {"foo": "bar"}}),
3620+
(UpcastState.py_list, {"ls": ["foo", "foo"]}),
3621+
(UpcastState.py_Any, {"a": ["foo"]}),
3622+
(UpcastState.py_unresolvable, {"u": ["foo"]}),
3623+
],
3624+
)
3625+
async def test_upcast_event_handler_arg(handler, payload):
3626+
"""Test that upcast event handler args work correctly.
3627+
3628+
Args:
3629+
handler: The handler to test.
3630+
payload: The payload to test.
3631+
"""
3632+
state = UpcastState()
3633+
async for update in state._process_event(handler, state, payload):
3634+
assert update.delta == {UpcastState.get_full_name(): {"passed": True}}

0 commit comments

Comments
 (0)