|
1 | 1 | import ast |
| 2 | +import copy |
2 | 3 | import inspect |
3 | 4 | import types |
4 | 5 | import textwrap |
@@ -341,11 +342,34 @@ def visit(self, node): |
341 | 342 | return node |
342 | 343 |
|
343 | 344 |
|
| 345 | +class _SelectiveAssignNormalizer(ast.NodeTransformer): |
| 346 | + def visit_If(self, node): |
| 347 | + self.generic_visit(node) |
| 348 | + node.body = [self._transform_if_needed(stmt) for stmt in node.body] |
| 349 | + node.orelse = [self._transform_if_needed(stmt) for stmt in node.orelse] |
| 350 | + return node |
| 351 | + |
| 352 | + def _transform_if_needed(self, stmt): |
| 353 | + if isinstance(stmt, ast.AugAssign): |
| 354 | + return ast.Assign( |
| 355 | + targets=[stmt.target], |
| 356 | + value=ast.BinOp(left=copy.deepcopy(stmt.target), op=stmt.op, right=stmt.value), |
| 357 | + ) |
| 358 | + if isinstance(stmt, ast.AnnAssign) and stmt.value is not None: |
| 359 | + return ast.Assign(targets=[stmt.target], value=stmt.value) |
| 360 | + return self.visit(stmt) |
| 361 | + |
| 362 | + |
344 | 363 | def inplace_add_parent(tree: "ast.Node"): |
345 | 364 | """Adds parents to an AST tree.""" |
346 | 365 | _AddParentTransformer().visit(tree) |
347 | 366 |
|
348 | 367 |
|
| 368 | +def normalize_assignment_in_test(tree: "ast.Node"): |
| 369 | + """Split AugAssign into BinOp and Assign to simplify whatever comes after.""" |
| 370 | + _SelectiveAssignNormalizer().visit(tree) |
| 371 | + |
| 372 | + |
349 | 373 | def transform_method( |
350 | 374 | func: Callable, |
351 | 375 | prefix: str = "branch_cond", |
@@ -451,6 +475,7 @@ def forward(self, x, y): |
451 | 475 | skip_objects=modules, |
452 | 476 | args_names=set(sig.parameters), |
453 | 477 | ) |
| 478 | + normalize_assignment_in_test(tree) |
454 | 479 | inplace_add_parent(tree) |
455 | 480 | new_tree = transformer.visit(tree) |
456 | 481 | if verbose > 1: |
|
0 commit comments