Skip to content

Commit ab01997

Browse files
authored
delay api mounting until app finishes compile (#5184)
* delay api mounting until app finishes compile * mock _compile * add cors * add cors to asgi app * dang it darglint * refactor code to make it more readable thanks to @masenf
1 parent 27bb987 commit ab01997

File tree

2 files changed

+22
-22
lines changed

2 files changed

+22
-22
lines changed

reflex/app.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ def __post_init__(self):
489489

490490
# Set up the API.
491491
self._api = Starlette(lifespan=self._run_lifespan_tasks)
492-
self._add_cors()
492+
App._add_cors(self._api)
493493
self._add_default_endpoints()
494494

495495
for clz in App.__mro__:
@@ -613,19 +613,6 @@ def __call__(self) -> ASGIApp:
613613
Returns:
614614
The backend api.
615615
"""
616-
if self._cached_fastapi_app is not None:
617-
asgi_app = self._cached_fastapi_app
618-
619-
if not asgi_app or not self._api:
620-
raise ValueError("The app has not been initialized.")
621-
622-
asgi_app.mount("", self._api)
623-
else:
624-
asgi_app = self._api
625-
626-
if not asgi_app:
627-
raise ValueError("The app has not been initialized.")
628-
629616
# For py3.9 compatibility when redis is used, we MUST add any decorator pages
630617
# before compiling the app in a thread to avoid event loop error (REF-2172).
631618
self._apply_decorated_pages()
@@ -637,9 +624,17 @@ def __call__(self) -> ASGIApp:
637624
# Force background compile errors to print eagerly
638625
lambda f: f.result()
639626
)
640-
# Wait for the compile to finish in prod mode to ensure all optional endpoints are mounted.
641-
if is_prod_mode():
642-
compile_future.result()
627+
# Wait for the compile to finish to ensure all optional endpoints are mounted.
628+
compile_future.result()
629+
630+
if not self._api:
631+
raise ValueError("The app has not been initialized.")
632+
if self._cached_fastapi_app is not None:
633+
asgi_app = self._cached_fastapi_app
634+
asgi_app.mount("", self._api)
635+
App._add_cors(asgi_app)
636+
else:
637+
asgi_app = self._api
643638

644639
if self.api_transformer is not None:
645640
api_transformers: Sequence[Starlette | Callable[[ASGIApp], ASGIApp]] = (
@@ -651,6 +646,7 @@ def __call__(self) -> ASGIApp:
651646
for api_transformer in api_transformers:
652647
if isinstance(api_transformer, Starlette):
653648
# Mount the api to the fastapi app.
649+
App._add_cors(api_transformer)
654650
api_transformer.mount("", asgi_app)
655651
asgi_app = api_transformer
656652
else:
@@ -709,11 +705,14 @@ def _add_optional_endpoints(self):
709705
if environment.REFLEX_ADD_ALL_ROUTES_ENDPOINT.get():
710706
self.add_all_routes_endpoint()
711707

712-
def _add_cors(self):
713-
"""Add CORS middleware to the app."""
714-
if not self._api:
715-
return
716-
self._api.add_middleware(
708+
@staticmethod
709+
def _add_cors(api: Starlette):
710+
"""Add CORS middleware to the app.
711+
712+
Args:
713+
api: The Starlette app to add CORS middleware to.
714+
"""
715+
api.add_middleware(
717716
cors.CORSMiddleware,
718717
allow_credentials=True,
719718
allow_methods=["*"],

tests/units/test_app.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1502,6 +1502,7 @@ def test_raise_on_state():
15021502
def test_call_app():
15031503
"""Test that the app can be called."""
15041504
app = App()
1505+
app._compile = unittest.mock.Mock()
15051506
api = app()
15061507
assert isinstance(api, Starlette)
15071508

0 commit comments

Comments
 (0)