Skip to content

Commit bb7e73d

Browse files
Lendemormasenf
andauthored
fix lifespan tasks regression (#5218)
* fix lifespan tasks regression * fix lifespan issue even when transformer is used * add_cors to top asgi app * test_lifespan with FastAPI and api_transformer used * avoid test warnings when _state_manager is not initialized * call .app_instance() in the correct directory * Call the app_instance to get the asgi object during initialization * revert unnecessary chdir scope * revert more unnecessary changes --------- Co-authored-by: Masen Furer <[email protected]>
1 parent 6eec8e3 commit bb7e73d

File tree

4 files changed

+82
-22
lines changed

4 files changed

+82
-22
lines changed

reflex/app.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def __post_init__(self):
488488
set_breakpoints(self.style.pop("breakpoints"))
489489

490490
# Set up the API.
491-
self._api = Starlette(lifespan=self._run_lifespan_tasks)
491+
self._api = Starlette()
492492
App._add_cors(self._api)
493493
self._add_default_endpoints()
494494

@@ -629,6 +629,7 @@ def __call__(self) -> ASGIApp:
629629

630630
if not self._api:
631631
raise ValueError("The app has not been initialized.")
632+
632633
if self._cached_fastapi_app is not None:
633634
asgi_app = self._cached_fastapi_app
634635
asgi_app.mount("", self._api)
@@ -653,7 +654,11 @@ def __call__(self) -> ASGIApp:
653654
# Transform the asgi app.
654655
asgi_app = api_transformer(asgi_app)
655656

656-
return asgi_app
657+
top_asgi_app = Starlette(lifespan=self._run_lifespan_tasks)
658+
top_asgi_app.mount("", asgi_app)
659+
App._add_cors(top_asgi_app)
660+
661+
return top_asgi_app
657662

658663
def _add_default_endpoints(self):
659664
"""Add default api endpoints (ping)."""

reflex/testing.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
)
4646
from reflex.utils import console
4747
from reflex.utils.export import export
48+
from reflex.utils.types import ASGIApp
4849

4950
try:
5051
from selenium import webdriver
@@ -110,6 +111,7 @@ class AppHarness:
110111
app_module_path: Path
111112
app_module: types.ModuleType | None = None
112113
app_instance: reflex.App | None = None
114+
app_asgi: ASGIApp | None = None
113115
frontend_process: subprocess.Popen | None = None
114116
frontend_url: str | None = None
115117
frontend_output_thread: threading.Thread | None = None
@@ -270,11 +272,14 @@ def _initialize_app(self):
270272
# Ensure the AppHarness test does not skip State assignment due to running via pytest
271273
os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
272274
os.environ[reflex.constants.APP_HARNESS_FLAG] = "true"
273-
self.app_module = reflex.utils.prerequisites.get_compiled_app(
274-
# Do not reload the module for pre-existing apps (only apps generated from source)
275-
reload=self.app_source is not None
275+
# Ensure we actually compile the app during first initialization.
276+
self.app_instance, self.app_module = (
277+
reflex.utils.prerequisites.get_and_validate_app(
278+
# Do not reload the module for pre-existing apps (only apps generated from source)
279+
reload=self.app_source is not None
280+
)
276281
)
277-
self.app_instance = self.app_module.app
282+
self.app_asgi = self.app_instance()
278283
if self.app_instance and isinstance(
279284
self.app_instance._state_manager, StateManagerRedis
280285
):
@@ -300,10 +305,10 @@ def _get_backend_shutdown_handler(self):
300305
async def _shutdown(*args, **kwargs) -> None:
301306
# ensure redis is closed before event loop
302307
if self.app_instance is not None and isinstance(
303-
self.app_instance.state_manager, StateManagerRedis
308+
self.app_instance._state_manager, StateManagerRedis
304309
):
305310
with contextlib.suppress(ValueError):
306-
await self.app_instance.state_manager.close()
311+
await self.app_instance._state_manager.close()
307312

308313
# socketio shutdown handler
309314
if self.app_instance is not None and self.app_instance.sio is not None:
@@ -323,11 +328,11 @@ async def _shutdown(*args, **kwargs) -> None:
323328
return _shutdown
324329

325330
def _start_backend(self, port: int = 0):
326-
if self.app_instance is None or self.app_instance._api is None:
331+
if self.app_asgi is None:
327332
raise RuntimeError("App was not initialized.")
328333
self.backend = uvicorn.Server(
329334
uvicorn.Config(
330-
app=self.app_instance._api,
335+
app=self.app_asgi,
331336
host="127.0.0.1",
332337
port=port,
333338
)
@@ -349,13 +354,13 @@ async def _reset_backend_state_manager(self):
349354
if (
350355
self.app_instance is not None
351356
and isinstance(
352-
self.app_instance.state_manager,
357+
self.app_instance._state_manager,
353358
StateManagerRedis,
354359
)
355360
and self.app_instance._state is not None
356361
):
357362
with contextlib.suppress(RuntimeError):
358-
await self.app_instance.state_manager.close()
363+
await self.app_instance._state_manager.close()
359364
self.app_instance._state_manager = StateManagerRedis.create(
360365
state=self.app_instance._state,
361366
)
@@ -959,12 +964,12 @@ def _wait_frontend(self):
959964
raise RuntimeError("Frontend did not start")
960965

961966
def _start_backend(self):
962-
if self.app_instance is None:
967+
if self.app_asgi is None:
963968
raise RuntimeError("App was not initialized.")
964969
environment.REFLEX_SKIP_COMPILE.set(True)
965970
self.backend = uvicorn.Server(
966971
uvicorn.Config(
967-
app=self.app_instance,
972+
app=self.app_asgi,
968973
host="127.0.0.1",
969974
port=0,
970975
workers=reflex.utils.processes.get_num_workers(),

tests/integration/test_connection_banner.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Test case for displaying the connection banner when the websocket drops."""
22

3-
import functools
43
from collections.abc import Generator
54

65
import pytest
@@ -77,7 +76,7 @@ def connection_banner(
7776

7877
with AppHarness.create(
7978
root=tmp_path,
80-
app_source=functools.partial(ConnectionBanner),
79+
app_source=ConnectionBanner,
8180
app_name=(
8281
"connection_banner_reflex_cloud"
8382
if simulate_compile_context == constants.CompileContext.DEPLOY

tests/integration/test_lifespan.py

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Test cases for the Starlette lifespan integration."""
22

3+
import functools
34
from collections.abc import Generator
45

56
import pytest
@@ -10,8 +11,15 @@
1011
from .utils import SessionStorage
1112

1213

13-
def LifespanApp():
14-
"""App with lifespan tasks and context."""
14+
def LifespanApp(
15+
mount_cached_fastapi: bool = False, mount_api_transformer: bool = False
16+
) -> None:
17+
"""App with lifespan tasks and context.
18+
19+
Args:
20+
mount_cached_fastapi: Whether to mount the cached FastAPI app.
21+
mount_api_transformer: Whether to mount the API transformer.
22+
"""
1523
import asyncio
1624
from contextlib import asynccontextmanager
1725

@@ -72,25 +80,68 @@ def index():
7280
),
7381
)
7482

75-
app = rx.App()
83+
from fastapi import FastAPI
84+
85+
app = rx.App(api_transformer=FastAPI() if mount_api_transformer else None)
86+
87+
if mount_cached_fastapi:
88+
assert app.api is not None
89+
7690
app.register_lifespan_task(lifespan_task)
7791
app.register_lifespan_task(lifespan_context, inc=2)
7892
app.add_page(index)
7993

8094

95+
@pytest.fixture(
96+
params=[False, True], ids=["no_api_transformer", "mount_api_transformer"]
97+
)
98+
def mount_api_transformer(request: pytest.FixtureRequest) -> bool:
99+
"""Whether to use api_transformer in the app.
100+
101+
Args:
102+
request: pytest fixture request object
103+
104+
Returns:
105+
bool: Whether to use api_transformer
106+
"""
107+
return request.param
108+
109+
110+
@pytest.fixture(params=[False, True], ids=["no_fastapi", "mount_cached_fastapi"])
111+
def mount_cached_fastapi(request: pytest.FixtureRequest) -> bool:
112+
"""Whether to use cached FastAPI in the app (app.api).
113+
114+
Args:
115+
request: pytest fixture request object
116+
117+
Returns:
118+
Whether to use cached FastAPI
119+
"""
120+
return request.param
121+
122+
81123
@pytest.fixture()
82-
def lifespan_app(tmp_path) -> Generator[AppHarness, None, None]:
124+
def lifespan_app(
125+
tmp_path, mount_api_transformer: bool, mount_cached_fastapi: bool
126+
) -> Generator[AppHarness, None, None]:
83127
"""Start LifespanApp app at tmp_path via AppHarness.
84128
85129
Args:
86130
tmp_path: pytest tmp_path fixture
131+
mount_api_transformer: Whether to mount the API transformer.
132+
mount_cached_fastapi: Whether to mount the cached FastAPI app.
87133
88134
Yields:
89135
running AppHarness instance
90136
"""
91137
with AppHarness.create(
92138
root=tmp_path,
93-
app_source=LifespanApp,
139+
app_source=functools.partial(
140+
LifespanApp,
141+
mount_cached_fastapi=mount_cached_fastapi,
142+
mount_api_transformer=mount_api_transformer,
143+
),
144+
app_name=f"lifespanapp_fastapi{mount_cached_fastapi}_transformer{mount_api_transformer}",
94145
) as harness:
95146
yield harness
96147

@@ -112,7 +163,7 @@ async def test_lifespan(lifespan_app: AppHarness):
112163
context_global = driver.find_element(By.ID, "context_global")
113164
task_global = driver.find_element(By.ID, "task_global")
114165

115-
assert context_global.text == "2"
166+
assert lifespan_app.poll_for_content(context_global, exp_not_equal="0") == "2"
116167
assert lifespan_app.app_module.lifespan_context_global == 2
117168

118169
original_task_global_text = task_global.text

0 commit comments

Comments
 (0)