Skip to content

Commit 9a7858b

Browse files
authored
ENG-8509: computed var dependency tracking for locally imported states (#6035)
* ENG-8509: computed var dependency tracking for locally imported states

Import dep tracking:

* all forms of function-local imports should be usable in `get_state`
* get deps from get_state through chained attributes
* py3.11 compatibility
* skip get_var_value test on py3.10
* Fix more edge cases with multiple imports and nested list comprehensions
1 parent 1f255e8 commit 9a7858b

File tree

3 files changed

+364
-24
lines changed

3 files changed

+364
-24
lines changed

reflex/vars/dep_tracking.py

Lines changed: 163 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import dataclasses
77
import dis
88
import enum
9+
import importlib
910
import inspect
1011
import sys
11-
from types import CellType, CodeType, FunctionType
12+
from types import CellType, CodeType, FunctionType, ModuleType
1213
from typing import TYPE_CHECKING, Any, ClassVar, cast
1314

1415
from reflex.utils.exceptions import VarValueError
@@ -43,9 +44,38 @@ class ScanStatus(enum.Enum):
4344
SCANNING = enum.auto()
4445
GETTING_ATTR = enum.auto()
4546
GETTING_STATE = enum.auto()
47+
GETTING_STATE_POST_AWAIT = enum.auto()
4648
GETTING_VAR = enum.auto()
4749

4850

51+
class UntrackedLocalVarError(VarValueError):
    """Signal that a referenced local name is absent from the tracked scope."""
53+
54+
55+
def assert_base_state(
    local_value: Any,
    local_name: str | None = None,
) -> type[BaseState]:
    """Validate that a value is a BaseState subclass and hand it back.

    Args:
        local_value: The object to validate.
        local_name: The name the object is bound to; only used to build the
            error message.

    Returns:
        ``local_value`` unchanged, when it is a class deriving from BaseState.

    Raises:
        VarValueError: If ``local_value`` is not a class, or is a class that
            does not derive from BaseState.
    """
    # Imported at call time rather than at module import time.
    from reflex.state import BaseState

    is_state_class = isinstance(local_value, type) and issubclass(
        local_value, BaseState
    )
    if is_state_class:
        return local_value

    msg = f"Cannot determine dependencies in fetched state {local_name!r}: {local_value!r} is not a BaseState."
    raise VarValueError(msg)
77+
78+
4979
@dataclasses.dataclass
5080
class DependencyTracker:
5181
"""State machine for identifying state attributes that are accessed by a function."""
@@ -58,10 +88,15 @@ class DependencyTracker:
5888
scan_status: ScanStatus = dataclasses.field(default=ScanStatus.SCANNING)
5989
top_of_stack: str | None = dataclasses.field(default=None)
6090

61-
tracked_locals: dict[str, type[BaseState]] = dataclasses.field(default_factory=dict)
91+
tracked_locals: dict[str, type[BaseState] | ModuleType] = dataclasses.field(
92+
default_factory=dict
93+
)
6294

63-
_getting_state_class: type[BaseState] | None = dataclasses.field(default=None)
95+
_getting_state_class: type[BaseState] | ModuleType | None = dataclasses.field(
96+
default=None
97+
)
6498
_get_var_value_positions: dis.Positions | None = dataclasses.field(default=None)
99+
_last_import_name: str | None = dataclasses.field(default=None)
65100

66101
INVALID_NAMES: ClassVar[list[str]] = ["parent_state", "substates", "get_substate"]
67102

@@ -90,6 +125,26 @@ def _merge_deps(self, tracker: DependencyTracker) -> None:
90125
for state_name, dep_name in tracker.dependencies.items():
91126
self.dependencies.setdefault(state_name, set()).update(dep_name)
92127

128+
def get_tracked_local(self, local_name: str) -> type[BaseState] | ModuleType:
    """Look up a name recorded in ``self.tracked_locals``.

    Args:
        local_name: The local variable name to resolve.

    Returns:
        The object stored for ``local_name`` in ``self.tracked_locals``
        (annotated as a BaseState subclass or a module).

    Raises:
        UntrackedLocalVarError: If ``local_name`` has no entry in
            ``self.tracked_locals``; the original KeyError is chained as
            the cause.
    """
    try:
        return self.tracked_locals[local_name]
    except KeyError as missing:
        msg = f"{local_name!r} is not tracked in the current scope."
        raise UntrackedLocalVarError(msg) from missing
147+
93148
def load_attr_or_method(self, instruction: dis.Instruction) -> None:
94149
"""Handle loading an attribute or method from the object on top of the stack.
95150
@@ -100,7 +155,8 @@ def load_attr_or_method(self, instruction: dis.Instruction) -> None:
100155
instruction: The dis instruction to process.
101156
102157
Raises:
103-
VarValueError: if the attribute is an disallowed name.
158+
VarValueError: if the attribute is a disallowed name or the attribute
159+
does not reference a BaseState.
104160
"""
105161
from .base import ComputedVar
106162

@@ -122,7 +178,8 @@ def load_attr_or_method(self, instruction: dis.Instruction) -> None:
122178
self.scan_status = ScanStatus.SCANNING
123179
if not self.top_of_stack:
124180
return
125-
target_state = self.tracked_locals[self.top_of_stack]
181+
target_obj = self.get_tracked_local(self.top_of_stack)
182+
target_state = assert_base_state(target_obj, local_name=self.top_of_stack)
126183
try:
127184
ref_obj = getattr(target_state, instruction.argval)
128185
except AttributeError:
@@ -190,15 +247,14 @@ def handle_getting_state(self, instruction: dis.Instruction) -> None:
190247
Raises:
191248
VarValueError: if the state class cannot be determined from the instruction.
192249
"""
193-
from reflex.state import BaseState
194-
195-
if instruction.opname in ("LOAD_FAST", "LOAD_FAST_BORROW"):
196-
msg = f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
197-
raise VarValueError(msg)
198250
if isinstance(self.func, CodeType):
199251
msg = "Dependency detection cannot identify get_state class from a code object."
200252
raise VarValueError(msg)
201-
if instruction.opname == "LOAD_GLOBAL":
253+
if instruction.opname in ("LOAD_FAST", "LOAD_FAST_BORROW"):
254+
self._getting_state_class = self.get_tracked_local(
255+
local_name=instruction.argval,
256+
)
257+
elif instruction.opname == "LOAD_GLOBAL":
202258
# Special case: referencing state class from global scope.
203259
try:
204260
self._getting_state_class = self._get_globals()[instruction.argval]
@@ -212,16 +268,43 @@ def handle_getting_state(self, instruction: dis.Instruction) -> None:
212268
except (ValueError, KeyError) as ve:
213269
msg = f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?"
214270
raise VarValueError(msg) from ve
215-
elif instruction.opname == "STORE_FAST":
271+
elif instruction.opname in ("LOAD_ATTR", "LOAD_METHOD"):
272+
self._getting_state_class = getattr(
273+
self._getting_state_class,
274+
instruction.argval,
275+
)
276+
elif instruction.opname == "GET_AWAITABLE":
277+
# Now inside the `await` machinery, subsequent instructions
278+
# operate on the result of the `get_state` call.
279+
self.scan_status = ScanStatus.GETTING_STATE_POST_AWAIT
280+
if self._getting_state_class is not None:
281+
self.top_of_stack = "_"
282+
self.tracked_locals[self.top_of_stack] = self._getting_state_class
283+
self._getting_state_class = None
284+
285+
def handle_getting_state_post_await(self, instruction: dis.Instruction) -> None:
286+
"""Handle bytecode analysis after `get_state` was called in the function.
287+
288+
This function is called _after_ awaiting self.get_state to capture the
289+
local variable holding the state instance or directly record access to
290+
attributes accessed on the result of get_state.
291+
292+
Args:
293+
instruction: The dis instruction to process.
294+
295+
Raises:
296+
VarValueError: if the state class cannot be determined from the instruction.
297+
"""
298+
if instruction.opname == "STORE_FAST" and self.top_of_stack:
216299
# Storing the result of get_state in a local variable.
217-
if not isinstance(self._getting_state_class, type) or not issubclass(
218-
self._getting_state_class, BaseState
219-
):
220-
msg = f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`."
221-
raise VarValueError(msg)
222-
self.tracked_locals[instruction.argval] = self._getting_state_class
300+
self.tracked_locals[instruction.argval] = self.tracked_locals.pop(
301+
self.top_of_stack
302+
)
303+
self.top_of_stack = None
223304
self.scan_status = ScanStatus.SCANNING
224-
self._getting_state_class = None
305+
elif instruction.opname in ("LOAD_ATTR", "LOAD_METHOD"):
306+
# Attribute access on an inline `get_state`, not assigned to a variable.
307+
self.load_attr_or_method(instruction)
225308

226309
def _eval_var(self, positions: dis.Positions) -> Var:
227310
"""Evaluate instructions from the wrapped function to get the Var object.
@@ -262,8 +345,12 @@ def _eval_var(self, positions: dis.Positions) -> Var:
262345
])
263346
else:
264347
snipped_source = source[0][start_column:end_column]
265-
# Evaluate the string in the context of the function's globals and closure.
266-
return eval(f"({snipped_source})", self._get_globals(), self._get_closure())
348+
# Evaluate the string in the context of the function's globals, closure and tracked local scope.
349+
return eval(
350+
f"({snipped_source})",
351+
self._get_globals(),
352+
{**self._get_closure(), **self.tracked_locals},
353+
)
267354

268355
def handle_getting_var(self, instruction: dis.Instruction) -> None:
269356
"""Handle bytecode analysis when `get_var_value` was called in the function.
@@ -304,16 +391,38 @@ def _populate_dependencies(self) -> None:
304391
for instruction in dis.get_instructions(self.func):
305392
if self.scan_status == ScanStatus.GETTING_STATE:
306393
self.handle_getting_state(instruction)
394+
elif self.scan_status == ScanStatus.GETTING_STATE_POST_AWAIT:
395+
self.handle_getting_state_post_await(instruction)
307396
elif self.scan_status == ScanStatus.GETTING_VAR:
308397
self.handle_getting_var(instruction)
309398
elif (
310-
instruction.opname in ("LOAD_FAST", "LOAD_DEREF", "LOAD_FAST_BORROW")
399+
instruction.opname
400+
in (
401+
"LOAD_FAST",
402+
"LOAD_DEREF",
403+
"LOAD_FAST_BORROW",
404+
"LOAD_FAST_CHECK",
405+
"LOAD_FAST_AND_CLEAR",
406+
)
311407
and instruction.argval in self.tracked_locals
312408
):
313409
# bytecode loaded the class instance to the top of stack, next load instruction
314410
# is referencing an attribute on self
315411
self.top_of_stack = instruction.argval
316412
self.scan_status = ScanStatus.GETTING_ATTR
413+
elif (
414+
instruction.opname
415+
in (
416+
"LOAD_FAST_LOAD_FAST",
417+
"LOAD_FAST_BORROW_LOAD_FAST_BORROW",
418+
"STORE_FAST_LOAD_FAST",
419+
)
420+
and instruction.argval[-1] in self.tracked_locals
421+
):
422+
# Double LOAD_FAST family instructions load multiple values onto the stack,
423+
# the last value in the argval list is the top of the stack.
424+
self.top_of_stack = instruction.argval[-1]
425+
self.scan_status = ScanStatus.GETTING_ATTR
317426
elif self.scan_status == ScanStatus.GETTING_ATTR and instruction.opname in (
318427
"LOAD_ATTR",
319428
"LOAD_METHOD",
@@ -332,3 +441,35 @@ def _populate_dependencies(self) -> None:
332441
tracked_locals=self.tracked_locals,
333442
)
334443
)
444+
elif instruction.opname == "IMPORT_NAME" and instruction.argval is not None:
445+
self._last_import_name = instruction.argval
446+
importlib.import_module(instruction.argval)
447+
top_module_name = instruction.argval.split(".")[0]
448+
self.tracked_locals[instruction.argval] = sys.modules[top_module_name]
449+
self.top_of_stack = instruction.argval
450+
elif instruction.opname == "IMPORT_FROM":
451+
if not self._last_import_name:
452+
msg = f"Cannot find package associated with import {instruction.argval} in {self.func!r}."
453+
raise VarValueError(msg)
454+
if instruction.argval in self._last_import_name.split("."):
455+
# `import ... as ...` case:
456+
# import from interim package, update tracked_locals for the last imported name.
457+
self.tracked_locals[self._last_import_name] = getattr(
458+
self.tracked_locals[self._last_import_name], instruction.argval
459+
)
460+
continue
461+
# Importing a name from a package/module.
462+
if self._last_import_name is not None and self.top_of_stack:
463+
# The full import name does NOT end up in scope for a `from ... import`.
464+
self.tracked_locals.pop(self._last_import_name)
465+
self.tracked_locals[instruction.argval] = getattr(
466+
importlib.import_module(self._last_import_name),
467+
instruction.argval,
468+
)
469+
# If we see a STORE_FAST, we can assign the top of stack to an aliased name.
470+
self.top_of_stack = instruction.argval
471+
elif instruction.opname == "STORE_FAST" and self.top_of_stack is not None:
472+
self.tracked_locals[instruction.argval] = self.tracked_locals.pop(
473+
self.top_of_stack
474+
)
475+
self.top_of_stack = None

tests/units/states/mutation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,6 @@ def reassign_mutables(self):
4646
"mod_third_key": {"key": "value"},
4747
}
4848
self.test_set = {1, 2, 3, 4, "five"}
49+
50+
def _get_array(self) -> list[str | int | list | dict[str, str]]:
51+
return self.array

0 commit comments

Comments
 (0)