Skip to content

Commit deb1f4f

Browse files
authored
[ENG-4713] Cache pages which add states when evaluating (#4788)
* cache order of imports that create BaseState subclasses * Track which pages create State subclasses during evaluation These need to be replayed on the backend to ensure state alignment. * Clean up: use constants, remove unused code Handle closing files with contextmanager * Expose app.add_all_routes_endpoint for flexgen * Include .web/backend directory in backend.zip when exporting
1 parent 7a6c712 commit deb1f4f

File tree

7 files changed

+89
-2
lines changed

7 files changed

+89
-2
lines changed

reflex/app.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
StateManager,
101101
StateUpdate,
102102
_substate_key,
103+
all_base_state_classes,
103104
code_uses_state_contexts,
104105
)
105106
from reflex.utils import (
@@ -117,6 +118,7 @@
117118
if TYPE_CHECKING:
118119
from reflex.vars import Var
119120

121+
120122
# Define custom types.
121123
ComponentCallable = Callable[[], Component]
122124
Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
@@ -375,6 +377,9 @@ class App(MiddlewareMixin, LifespanMixin):
375377
# A map from a page route to the component to render. Users should use `add_page`.
376378
_pages: Dict[str, Component] = dataclasses.field(default_factory=dict)
377379

380+
# A mapping of pages which created states as they were being evaluated.
381+
_stateful_pages: Dict[str, None] = dataclasses.field(default_factory=dict)
382+
378383
# The backend API object.
379384
_api: FastAPI | None = None
380385

@@ -592,8 +597,10 @@ def _add_optional_endpoints(self):
592597
"""Add optional api endpoints (_upload)."""
593598
if not self.api:
594599
return
595-
596-
if Upload.is_used:
600+
upload_is_used_marker = (
601+
prerequisites.get_backend_dir() / constants.Dirs.UPLOAD_IS_USED
602+
)
603+
if Upload.is_used or upload_is_used_marker.exists():
597604
# To upload files.
598605
self.api.post(str(constants.Endpoint.UPLOAD))(upload(self))
599606

@@ -603,10 +610,15 @@ def _add_optional_endpoints(self):
603610
StaticFiles(directory=get_upload_dir()),
604611
name="uploaded_files",
605612
)
613+
614+
upload_is_used_marker.parent.mkdir(parents=True, exist_ok=True)
615+
upload_is_used_marker.touch()
606616
if codespaces.is_running_in_codespaces():
607617
self.api.get(str(constants.Endpoint.AUTH_CODESPACE))(
608618
codespaces.auth_codespace
609619
)
620+
if environment.REFLEX_ADD_ALL_ROUTES_ENDPOINT.get():
621+
self.add_all_routes_endpoint()
610622

611623
def _add_cors(self):
612624
"""Add CORS middleware to the app."""
@@ -747,13 +759,19 @@ def _compile_page(self, route: str, save_page: bool = True):
747759
route: The route of the page to compile.
748760
save_page: If True, the compiled page is saved to self._pages.
749761
"""
762+
n_states_before = len(all_base_state_classes)
750763
component, enable_state = compiler.compile_unevaluated_page(
751764
route, self._unevaluated_pages[route], self._state, self.style, self.theme
752765
)
753766

767+
# Indicate that the app should use state.
754768
if enable_state:
755769
self._enable_state()
756770

771+
# Indicate that evaluating this page creates one or more state classes.
772+
if len(all_base_state_classes) > n_states_before:
773+
self._stateful_pages[route] = None
774+
757775
# Add the page.
758776
self._check_routes_conflict(route)
759777
if save_page:
@@ -1042,6 +1060,20 @@ def _compile(self, export: bool = False):
10421060
def get_compilation_time() -> str:
10431061
return str(datetime.now().time()).split(".")[0]
10441062

1063+
should_compile = self._should_compile()
1064+
backend_dir = prerequisites.get_backend_dir()
1065+
if not should_compile and backend_dir.exists():
1066+
stateful_pages_marker = backend_dir / constants.Dirs.STATEFUL_PAGES
1067+
if stateful_pages_marker.exists():
1068+
with stateful_pages_marker.open("r") as f:
1069+
stateful_pages = json.load(f)
1070+
for route in stateful_pages:
1071+
console.info(f"BE Evaluating stateful page: {route}")
1072+
self._compile_page(route, save_page=False)
1073+
self._enable_state()
1074+
self._add_optional_endpoints()
1075+
return
1076+
10451077
# Render a default 404 page if the user didn't supply one
10461078
if constants.Page404.SLUG not in self._unevaluated_pages:
10471079
self.add_page(route=constants.Page404.SLUG)
@@ -1343,6 +1375,24 @@ def _submit_work(fn: Callable[..., tuple[str, str]], *args, **kwargs):
13431375
for output_path, code in compile_results:
13441376
compiler_utils.write_page(output_path, code)
13451377

1378+
# Write list of routes that create dynamic states for backend to use.
1379+
if self._state is not None:
1380+
stateful_pages_marker = (
1381+
prerequisites.get_backend_dir() / constants.Dirs.STATEFUL_PAGES
1382+
)
1383+
stateful_pages_marker.parent.mkdir(parents=True, exist_ok=True)
1384+
with stateful_pages_marker.open("w") as f:
1385+
json.dump(list(self._stateful_pages), f)
1386+
1387+
def add_all_routes_endpoint(self):
1388+
"""Add an endpoint to the app that returns all the routes."""
1389+
if not self.api:
1390+
return
1391+
1392+
@self.api.get(str(constants.Endpoint.ALL_ROUTES))
1393+
async def all_routes():
1394+
return list(self._unevaluated_pages.keys())
1395+
13461396
@contextlib.asynccontextmanager
13471397
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
13481398
"""Modify the state out of band.

reflex/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,9 @@ class EnvironmentVariables:
713713
# Paths to exclude from the hot reload. Takes precedence over include paths. Separated by a colon.
714714
REFLEX_HOT_RELOAD_EXCLUDE_PATHS: EnvVar[List[Path]] = env_var([])
715715

716+
# Used by flexgen to enumerate the pages.
717+
REFLEX_ADD_ALL_ROUTES_ENDPOINT: EnvVar[bool] = env_var(False)
718+
716719

717720
environment = EnvironmentVariables()
718721

reflex/constants/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ class Dirs(SimpleNamespace):
5353
POSTCSS_JS = "postcss.config.js"
5454
# The name of the states directory.
5555
STATES = ".states"
56+
# Where compilation artifacts for the backend are stored.
57+
BACKEND = "backend"
58+
# JSON-encoded list of page routes that need to be evaluated on the backend.
59+
STATEFUL_PAGES = "stateful_pages.json"
60+
# Marker file indicating that upload component was used in the frontend.
61+
UPLOAD_IS_USED = "upload_is_used"
5662

5763

5864
class Reflex(SimpleNamespace):

reflex/constants/event.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class Endpoint(Enum):
1212
UPLOAD = "_upload"
1313
AUTH_CODESPACE = "auth-codespace"
1414
HEALTH = "_health"
15+
ALL_ROUTES = "_all_routes"
1516

1617
def __str__(self) -> str:
1718
"""Get the string representation of the endpoint.

reflex/state.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,9 @@ async def _resolve_delta(delta: Delta) -> Delta:
327327
return delta
328328

329329

330+
all_base_state_classes: dict[str, None] = {}
331+
332+
330333
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
331334
"""The state of the app."""
332335

@@ -624,6 +627,8 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs):
624627
cls._var_dependencies = {}
625628
cls._init_var_dependency_dicts()
626629

630+
all_base_state_classes[cls.get_full_name()] = None
631+
627632
@staticmethod
628633
def _copy_fn(fn: Callable) -> Callable:
629634
"""Copy a function. Used to copy ComputedVars and EventHandlers from mixins.
@@ -4087,6 +4092,7 @@ def reload_state_module(
40874092
for subclass in tuple(state.class_subclasses):
40884093
reload_state_module(module=module, state=subclass)
40894094
if subclass.__module__ == module and module is not None:
4095+
all_base_state_classes.pop(subclass.get_full_name(), None)
40904096
state.class_subclasses.remove(subclass)
40914097
state._always_dirty_substates.discard(subclass.get_name())
40924098
state._var_dependencies = {}

reflex/utils/build.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def _zip(
6060
dirs_to_exclude: set[str] | None = None,
6161
files_to_exclude: set[str] | None = None,
6262
top_level_dirs_to_exclude: set[str] | None = None,
63+
globs_to_include: list[str] | None = None,
6364
) -> None:
6465
"""Zip utility function.
6566
@@ -72,6 +73,7 @@ def _zip(
7273
dirs_to_exclude: The directories to exclude.
7374
files_to_exclude: The files to exclude.
7475
top_level_dirs_to_exclude: The top level directory names immediately under root_dir to exclude. Do not exclude folders by these names further in the sub-directories.
76+
globs_to_include: Apply these globs from the root_dir and always include them in the zip.
7577
7678
"""
7779
target = Path(target)
@@ -103,6 +105,13 @@ def _zip(
103105
files_to_zip += [
104106
str(root / file) for file in files if file not in files_to_exclude
105107
]
108+
if globs_to_include:
109+
for glob in globs_to_include:
110+
files_to_zip += [
111+
str(file)
112+
for file in root_dir.glob(glob)
113+
if file.name not in files_to_exclude
114+
]
106115

107116
# Create a progress bar for zipping the component.
108117
progress = Progress(
@@ -160,6 +169,9 @@ def zip_app(
160169
top_level_dirs_to_exclude={"assets"},
161170
exclude_venv_dirs=True,
162171
upload_db_file=upload_db_file,
172+
globs_to_include=[
173+
str(Path(constants.Dirs.WEB) / constants.Dirs.BACKEND / "*")
174+
],
163175
)
164176

165177

reflex/utils/prerequisites.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,15 @@ def get_states_dir() -> Path:
9999
return environment.REFLEX_STATES_WORKDIR.get()
100100

101101

102+
def get_backend_dir() -> Path:
103+
"""Get the working directory for the backend.
104+
105+
Returns:
106+
The working directory.
107+
"""
108+
return get_web_dir() / constants.Dirs.BACKEND
109+
110+
102111
def check_latest_package_version(package_name: str):
103112
"""Check if the latest version of the package is installed.
104113

0 commit comments

Comments
 (0)