Skip to content

Commit 49570ea

Browse files
authored
snip source correctly for dep tracker (#5613)
* snip source correctly for dep tracker * add tests * fix those version issues * remove unused code
1 parent c9c4203 commit 49570ea

File tree

3 files changed

+618
-30
lines changed

3 files changed

+618
-30
lines changed

reflex/vars/dep_tracking.py

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import dis
88
import enum
99
import inspect
10+
import sys
1011
from types import CellType, CodeType, FunctionType
1112
from typing import TYPE_CHECKING, Any, ClassVar, cast
1213

@@ -60,9 +61,7 @@ class DependencyTracker:
6061
tracked_locals: dict[str, type[BaseState]] = dataclasses.field(default_factory=dict)
6162

6263
_getting_state_class: type[BaseState] | None = dataclasses.field(default=None)
63-
_getting_var_instructions: list[dis.Instruction] = dataclasses.field(
64-
default_factory=list
65-
)
64+
_get_var_value_positions: dis.Positions | None = dataclasses.field(default=None)
6665

6766
INVALID_NAMES: ClassVar[list[str]] = ["parent_state", "substates", "get_substate"]
6867

@@ -114,6 +113,8 @@ def load_attr_or_method(self, instruction: dis.Instruction) -> None:
114113
return
115114
if instruction.argval == "get_var_value":
116115
# Special case: arbitrary var access requested.
116+
if sys.version_info >= (3, 11):
117+
self._get_var_value_positions = instruction.positions
117118
self.scan_status = ScanStatus.GETTING_VAR
118119
return
119120

@@ -222,9 +223,12 @@ def handle_getting_state(self, instruction: dis.Instruction) -> None:
222223
self.scan_status = ScanStatus.SCANNING
223224
self._getting_state_class = None
224225

225-
def _eval_var(self) -> Var:
226+
def _eval_var(self, positions: dis.Positions) -> Var:
226227
"""Evaluate instructions from the wrapped function to get the Var object.
227228
229+
Args:
230+
positions: The disassembly positions of the get_var_value call.
231+
228232
Returns:
229233
The Var object.
230234
@@ -233,15 +237,13 @@ def _eval_var(self) -> Var:
233237
"""
234238
# Get the original source code and eval it to get the Var.
235239
module = inspect.getmodule(self.func)
236-
positions0 = self._getting_var_instructions[0].positions
237-
positions1 = self._getting_var_instructions[-1].positions
238-
if module is None or positions0 is None or positions1 is None:
240+
if module is None or self._get_var_value_positions is None:
239241
msg = f"Cannot determine the source code for the var in {self.func!r}."
240242
raise VarValueError(msg)
241-
start_line = positions0.lineno
242-
start_column = positions0.col_offset
243-
end_line = positions1.end_lineno
244-
end_column = positions1.end_col_offset
243+
start_line = self._get_var_value_positions.end_lineno
244+
start_column = self._get_var_value_positions.end_col_offset
245+
end_line = positions.end_lineno
246+
end_column = positions.end_col_offset
245247
if (
246248
start_line is None
247249
or start_column is None
@@ -254,14 +256,10 @@ def _eval_var(self) -> Var:
254256
# Create a python source string snippet.
255257
if len(source) > 1:
256258
snipped_source = "".join(
257-
[
258-
*source[0][start_column:],
259-
*(source[1:-2] if len(source) > 2 else []),
260-
*source[-1][: end_column - 1],
261-
]
259+
[*source[0][start_column:], *source[1:-1], *source[-1][:end_column]]
262260
)
263261
else:
264-
snipped_source = source[0][start_column : end_column - 1]
262+
snipped_source = source[0][start_column:end_column]
265263
# Evaluate the string in the context of the function's globals and closure.
266264
return eval(f"({snipped_source})", self._get_globals(), self._get_closure())
267265

@@ -279,20 +277,19 @@ def handle_getting_var(self, instruction: dis.Instruction) -> None:
279277
Raises:
280278
VarValueError: if the source code for the var cannot be determined.
281279
"""
282-
if instruction.opname == "CALL" and self._getting_var_instructions:
283-
if self._getting_var_instructions:
284-
the_var = self._eval_var()
285-
the_var_data = the_var._get_all_var_data()
286-
if the_var_data is None:
287-
msg = f"Cannot determine the source code for the var in {self.func!r}."
288-
raise VarValueError(msg)
289-
self.dependencies.setdefault(the_var_data.state, set()).add(
290-
the_var_data.field_name
291-
)
292-
self._getting_var_instructions.clear()
280+
if instruction.opname == "CALL":
281+
if instruction.positions is None:
282+
msg = f"Cannot determine the source code for the var in {self.func!r}."
283+
raise VarValueError(msg)
284+
the_var = self._eval_var(instruction.positions)
285+
the_var_data = the_var._get_all_var_data()
286+
if the_var_data is None:
287+
msg = f"Cannot determine the source code for the var in {self.func!r}."
288+
raise VarValueError(msg)
289+
self.dependencies.setdefault(the_var_data.state, set()).add(
290+
the_var_data.field_name
291+
)
293292
self.scan_status = ScanStatus.SCANNING
294-
else:
295-
self._getting_var_instructions.append(instruction)
296293

297294
def _populate_dependencies(self) -> None:
298295
"""Update self.dependencies based on the disassembly of self.func.

0 commit comments

Comments
 (0)