66import dataclasses
77import dis
88import enum
9+ import importlib
910import inspect
1011import sys
11- from types import CellType , CodeType , FunctionType
12+ from types import CellType , CodeType , FunctionType , ModuleType
1213from typing import TYPE_CHECKING , Any , ClassVar , cast
1314
1415from reflex .utils .exceptions import VarValueError
@@ -43,9 +44,38 @@ class ScanStatus(enum.Enum):
4344 SCANNING = enum .auto ()
4445 GETTING_ATTR = enum .auto ()
4546 GETTING_STATE = enum .auto ()
47+ GETTING_STATE_POST_AWAIT = enum .auto ()
4648 GETTING_VAR = enum .auto ()
4749
4850
51+ class UntrackedLocalVarError (VarValueError ):
52+ """Raised when a local variable is referenced, but it is not tracked in the current scope."""
53+
54+
55+ def assert_base_state (
56+ local_value : Any ,
57+ local_name : str | None = None ,
58+ ) -> type [BaseState ]:
59+ """Assert that a local variable is a BaseState subclass.
60+
61+ Args:
62+ local_value: The value of the local variable to check.
63+ local_name: The name of the local variable to check.
64+
65+ Returns:
66+ The local variable value if it is a BaseState subclass.
67+
68+ Raises:
69+ VarValueError: If the object is not a BaseState subclass.
70+ """
71+ from reflex .state import BaseState
72+
73+ if not isinstance (local_value , type ) or not issubclass (local_value , BaseState ):
74+ msg = f"Cannot determine dependencies in fetched state { local_name !r} : { local_value !r} is not a BaseState."
75+ raise VarValueError (msg )
76+ return local_value
77+
78+
4979@dataclasses .dataclass
5080class DependencyTracker :
5181 """State machine for identifying state attributes that are accessed by a function."""
@@ -58,10 +88,15 @@ class DependencyTracker:
5888 scan_status : ScanStatus = dataclasses .field (default = ScanStatus .SCANNING )
5989 top_of_stack : str | None = dataclasses .field (default = None )
6090
61- tracked_locals : dict [str , type [BaseState ]] = dataclasses .field (default_factory = dict )
91+ tracked_locals : dict [str , type [BaseState ] | ModuleType ] = dataclasses .field (
92+ default_factory = dict
93+ )
6294
63- _getting_state_class : type [BaseState ] | None = dataclasses .field (default = None )
95+ _getting_state_class : type [BaseState ] | ModuleType | None = dataclasses .field (
96+ default = None
97+ )
6498 _get_var_value_positions : dis .Positions | None = dataclasses .field (default = None )
99+ _last_import_name : str | None = dataclasses .field (default = None )
65100
66101 INVALID_NAMES : ClassVar [list [str ]] = ["parent_state" , "substates" , "get_substate" ]
67102
@@ -90,6 +125,26 @@ def _merge_deps(self, tracker: DependencyTracker) -> None:
90125 for state_name , dep_name in tracker .dependencies .items ():
91126 self .dependencies .setdefault (state_name , set ()).update (dep_name )
92127
128+ def get_tracked_local (self , local_name : str ) -> type [BaseState ] | ModuleType :
129+ """Get the value of a local name tracked in the current function scope.
130+
131+ Args:
132+ local_name: The name of the local variable to fetch.
133+
134+ Returns:
135+ The value of local name tracked in the current scope (a referenced
136+ BaseState subclass or imported module).
137+
138+ Raises:
139+ UntrackedLocalVarError: If the local variable is not being tracked.
140+ """
141+ try :
142+ local_value = self .tracked_locals [local_name ]
143+ except KeyError as ke :
144+ msg = f"{ local_name !r} is not tracked in the current scope."
145+ raise UntrackedLocalVarError (msg ) from ke
146+ return local_value
147+
93148 def load_attr_or_method (self , instruction : dis .Instruction ) -> None :
94149 """Handle loading an attribute or method from the object on top of the stack.
95150
@@ -100,7 +155,8 @@ def load_attr_or_method(self, instruction: dis.Instruction) -> None:
100155 instruction: The dis instruction to process.
101156
102157 Raises:
103- VarValueError: if the attribute is an disallowed name.
158+ VarValueError: if the attribute is an disallowed name or attribute
159+ does not reference a BaseState.
104160 """
105161 from .base import ComputedVar
106162
@@ -122,7 +178,8 @@ def load_attr_or_method(self, instruction: dis.Instruction) -> None:
122178 self .scan_status = ScanStatus .SCANNING
123179 if not self .top_of_stack :
124180 return
125- target_state = self .tracked_locals [self .top_of_stack ]
181+ target_obj = self .get_tracked_local (self .top_of_stack )
182+ target_state = assert_base_state (target_obj , local_name = self .top_of_stack )
126183 try :
127184 ref_obj = getattr (target_state , instruction .argval )
128185 except AttributeError :
@@ -190,15 +247,14 @@ def handle_getting_state(self, instruction: dis.Instruction) -> None:
190247 Raises:
191248 VarValueError: if the state class cannot be determined from the instruction.
192249 """
193- from reflex .state import BaseState
194-
195- if instruction .opname in ("LOAD_FAST" , "LOAD_FAST_BORROW" ):
196- msg = f"Dependency detection cannot identify get_state class from local var { instruction .argval } ."
197- raise VarValueError (msg )
198250 if isinstance (self .func , CodeType ):
199251 msg = "Dependency detection cannot identify get_state class from a code object."
200252 raise VarValueError (msg )
201- if instruction .opname == "LOAD_GLOBAL" :
253+ if instruction .opname in ("LOAD_FAST" , "LOAD_FAST_BORROW" ):
254+ self ._getting_state_class = self .get_tracked_local (
255+ local_name = instruction .argval ,
256+ )
257+ elif instruction .opname == "LOAD_GLOBAL" :
202258 # Special case: referencing state class from global scope.
203259 try :
204260 self ._getting_state_class = self ._get_globals ()[instruction .argval ]
@@ -212,16 +268,43 @@ def handle_getting_state(self, instruction: dis.Instruction) -> None:
212268 except (ValueError , KeyError ) as ve :
213269 msg = f"Cached var { self !s} cannot access arbitrary state `{ instruction .argval } `, is it defined yet?"
214270 raise VarValueError (msg ) from ve
215- elif instruction .opname == "STORE_FAST" :
271+ elif instruction .opname in ("LOAD_ATTR" , "LOAD_METHOD" ):
272+ self ._getting_state_class = getattr (
273+ self ._getting_state_class ,
274+ instruction .argval ,
275+ )
276+ elif instruction .opname == "GET_AWAITABLE" :
277+ # Now inside the `await` machinery, subsequent instructions
278+ # operate on the result of the `get_state` call.
279+ self .scan_status = ScanStatus .GETTING_STATE_POST_AWAIT
280+ if self ._getting_state_class is not None :
281+ self .top_of_stack = "_"
282+ self .tracked_locals [self .top_of_stack ] = self ._getting_state_class
283+ self ._getting_state_class = None
284+
285+ def handle_getting_state_post_await (self , instruction : dis .Instruction ) -> None :
286+ """Handle bytecode analysis after `get_state` was called in the function.
287+
288+ This function is called _after_ awaiting self.get_state to capture the
289+ local variable holding the state instance or directly record access to
290+ attributes accessed on the result of get_state.
291+
292+ Args:
293+ instruction: The dis instruction to process.
294+
295+ Raises:
296+ VarValueError: if the state class cannot be determined from the instruction.
297+ """
298+ if instruction .opname == "STORE_FAST" and self .top_of_stack :
216299 # Storing the result of get_state in a local variable.
217- if not isinstance (self ._getting_state_class , type ) or not issubclass (
218- self ._getting_state_class , BaseState
219- ):
220- msg = f"Cached var { self !s} cannot determine dependencies in fetched state `{ instruction .argval } `."
221- raise VarValueError (msg )
222- self .tracked_locals [instruction .argval ] = self ._getting_state_class
300+ self .tracked_locals [instruction .argval ] = self .tracked_locals .pop (
301+ self .top_of_stack
302+ )
303+ self .top_of_stack = None
223304 self .scan_status = ScanStatus .SCANNING
224- self ._getting_state_class = None
305+ elif instruction .opname in ("LOAD_ATTR" , "LOAD_METHOD" ):
306+ # Attribute access on an inline `get_state`, not assigned to a variable.
307+ self .load_attr_or_method (instruction )
225308
226309 def _eval_var (self , positions : dis .Positions ) -> Var :
227310 """Evaluate instructions from the wrapped function to get the Var object.
@@ -262,8 +345,12 @@ def _eval_var(self, positions: dis.Positions) -> Var:
262345 ])
263346 else :
264347 snipped_source = source [0 ][start_column :end_column ]
265- # Evaluate the string in the context of the function's globals and closure.
266- return eval (f"({ snipped_source } )" , self ._get_globals (), self ._get_closure ())
348+ # Evaluate the string in the context of the function's globals, closure and tracked local scope.
349+ return eval (
350+ f"({ snipped_source } )" ,
351+ self ._get_globals (),
352+ {** self ._get_closure (), ** self .tracked_locals },
353+ )
267354
268355 def handle_getting_var (self , instruction : dis .Instruction ) -> None :
269356 """Handle bytecode analysis when `get_var_value` was called in the function.
@@ -304,16 +391,38 @@ def _populate_dependencies(self) -> None:
304391 for instruction in dis .get_instructions (self .func ):
305392 if self .scan_status == ScanStatus .GETTING_STATE :
306393 self .handle_getting_state (instruction )
394+ elif self .scan_status == ScanStatus .GETTING_STATE_POST_AWAIT :
395+ self .handle_getting_state_post_await (instruction )
307396 elif self .scan_status == ScanStatus .GETTING_VAR :
308397 self .handle_getting_var (instruction )
309398 elif (
310- instruction .opname in ("LOAD_FAST" , "LOAD_DEREF" , "LOAD_FAST_BORROW" )
399+ instruction .opname
400+ in (
401+ "LOAD_FAST" ,
402+ "LOAD_DEREF" ,
403+ "LOAD_FAST_BORROW" ,
404+ "LOAD_FAST_CHECK" ,
405+ "LOAD_FAST_AND_CLEAR" ,
406+ )
311407 and instruction .argval in self .tracked_locals
312408 ):
313409 # bytecode loaded the class instance to the top of stack, next load instruction
314410 # is referencing an attribute on self
315411 self .top_of_stack = instruction .argval
316412 self .scan_status = ScanStatus .GETTING_ATTR
413+ elif (
414+ instruction .opname
415+ in (
416+ "LOAD_FAST_LOAD_FAST" ,
417+ "LOAD_FAST_BORROW_LOAD_FAST_BORROW" ,
418+ "STORE_FAST_LOAD_FAST" ,
419+ )
420+ and instruction .argval [- 1 ] in self .tracked_locals
421+ ):
422+ # Double LOAD_FAST family instructions load multiple values onto the stack,
423+ # the last value in the argval list is the top of the stack.
424+ self .top_of_stack = instruction .argval [- 1 ]
425+ self .scan_status = ScanStatus .GETTING_ATTR
317426 elif self .scan_status == ScanStatus .GETTING_ATTR and instruction .opname in (
318427 "LOAD_ATTR" ,
319428 "LOAD_METHOD" ,
@@ -332,3 +441,35 @@ def _populate_dependencies(self) -> None:
332441 tracked_locals = self .tracked_locals ,
333442 )
334443 )
444+ elif instruction .opname == "IMPORT_NAME" and instruction .argval is not None :
445+ self ._last_import_name = instruction .argval
446+ importlib .import_module (instruction .argval )
447+ top_module_name = instruction .argval .split ("." )[0 ]
448+ self .tracked_locals [instruction .argval ] = sys .modules [top_module_name ]
449+ self .top_of_stack = instruction .argval
450+ elif instruction .opname == "IMPORT_FROM" :
451+ if not self ._last_import_name :
452+ msg = f"Cannot find package associated with import { instruction .argval } in { self .func !r} ."
453+ raise VarValueError (msg )
454+ if instruction .argval in self ._last_import_name .split ("." ):
455+ # `import ... as ...` case:
456+ # import from interim package, update tracked_locals for the last imported name.
457+ self .tracked_locals [self ._last_import_name ] = getattr (
458+ self .tracked_locals [self ._last_import_name ], instruction .argval
459+ )
460+ continue
461+ # Importing a name from a package/module.
462+ if self ._last_import_name is not None and self .top_of_stack :
463+ # The full import name does NOT end up in scope for a `from ... import`.
464+ self .tracked_locals .pop (self ._last_import_name )
465+ self .tracked_locals [instruction .argval ] = getattr (
466+ importlib .import_module (self ._last_import_name ),
467+ instruction .argval ,
468+ )
469+ # If we see a STORE_FAST, we can assign the top of stack to an aliased name.
470+ self .top_of_stack = instruction .argval
471+ elif instruction .opname == "STORE_FAST" and self .top_of_stack is not None :
472+ self .tracked_locals [instruction .argval ] = self .tracked_locals .pop (
473+ self .top_of_stack
474+ )
475+ self .top_of_stack = None
0 commit comments