|
| 1 | +import functools |
1 | 2 | import inspect |
| 3 | +import operator |
2 | 4 | import os |
3 | 5 | import traceback |
4 | 6 | from functools import reduce |
5 | | -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union |
| 7 | +from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union |
6 | 8 | import torch |
7 | 9 | from torch._subclasses.fake_tensor import FakeTensorMode |
8 | 10 |
|
@@ -397,6 +399,284 @@ def _log_guard( |
397 | 399 | # stacklevel=0, |
398 | 400 | # ) |
399 | 401 |
|
| 402 | + def _evaluate_expr( |
| 403 | + self, |
| 404 | + orig_expr: "sympy.Basic", # noqa: F821 |
| 405 | + hint: Optional[Union[bool, int, float]] = None, |
| 406 | + fx_node: Optional[torch.fx.Node] = None, |
| 407 | + size_oblivious: bool = False, |
| 408 | + fallback_value: Optional[bool] = None, |
| 409 | + *, |
| 410 | + forcing_spec: bool = False, |
| 411 | + ) -> "sympy.Basic": # noqa: F821 |
| 412 | + # TODO: split conjunctions and evaluate them separately |
| 413 | + import sympy |
| 414 | + from torch.fx.experimental import _config as config |
| 415 | + from torch.fx.experimental.symbolic_shapes import ( |
| 416 | + SympyBoolean, |
| 417 | + log, |
| 418 | + SymT, |
| 419 | + symbol_is_type, |
| 420 | + ) |
| 421 | + from torch._guards import ShapeGuard |
| 422 | + |
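| | + # Constant boolean expressions (sympy.true / sympy.false) need no |
| | + # guard at all; return them unchanged. |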
| 423 | + if isinstance( |
| 424 | + orig_expr, |
| 425 | + (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse), |
| 426 | + ): |
| 427 | + return orig_expr |
| 428 | + |
| 429 | + # Don't track this one. (Because this cache is defined inside this |
| 430 | + # function, it only lasts for a single invocation of the call.) |
| 431 | + @functools.cache |
| 432 | + def compute_concrete_val() -> sympy.Basic: |
| 433 | + if hint is None: |
| 434 | + # This is only ever called for expressions WITHOUT unbacked |
| 435 | + # symbols |
| 436 | + r = self.size_hint(orig_expr) |
| 437 | + assert r is not None |
| 438 | + return r |
| 439 | + else: |
| 440 | + return sympy.sympify(hint) |
| 441 | + |
| 442 | + concrete_val: Optional[sympy.Basic] |
| 443 | + |
| 444 | + # Check if: |
| 445 | + # 1. 'translation_validation' is set |
| 446 | + # 2. the corresponding 'fx_node' is not 'None' |
| 447 | + # 3. the guard should not be suppressed |
| 448 | + # 4. the guard doesn't contain backed symfloat symbols |
| 449 | + # since z3 can't handle floats |
| 450 | + # 5. fallback_value is None. |
| 451 | + # If all of the above hold, we create an FX node representing the |
| 452 | + # actual expression to be guarded. |
| 453 | + node = None |
| 454 | + fresh = False |
| 455 | + if ( |
| 456 | + self._translation_validation_enabled |
| 457 | + and fx_node is not None |
| 458 | + and not self._suppress_guards_tls() |
| 459 | + and not size_oblivious |
| 460 | + and not any(symbol_is_type(s, SymT.FLOAT) for s in orig_expr.free_symbols) |
| 461 | + and fallback_value is None |
| 462 | + ): |
| 463 | + # TODO: does this even work with unbacked? :think: |
| 464 | + concrete_val = compute_concrete_val() |
| 465 | + if concrete_val is sympy.true: |
| 466 | + node, fresh = self._create_fx_call_function(torch._assert, (fx_node,)) |
| 467 | + elif concrete_val is sympy.false: |
| 468 | + neg, _ = self._create_fx_call_function(operator.not_, (fx_node,)) |
| 469 | + node, fresh = self._create_fx_call_function(torch._assert, (neg,)) |
| 470 | + else: |
| 471 | + eql, _ = self._create_fx_call_function(operator.eq, (fx_node, concrete_val)) |
| 472 | + node, fresh = self._create_fx_call_function(torch._assert, (eql,)) |
| 473 | + |
| 474 | + assert node is not None |
| 475 | + # If this is a fresh node, we have to remember the event index that |
| 476 | + # corresponds to this assertion node. |
| 477 | + # Reason: so that, given an assertion node, we can replay the ShapeEnv |
| 478 | + # events until the point where this assertion node was freshly created. |
| 479 | + if fresh: |
| 480 | + self._add_fx_node_metadata(node) |
| 481 | + |
| 482 | + # After creating the FX node corresponding to orig_expr, we must make sure that |
| 483 | + # no error will be raised until the end of this function. |
| 484 | + # |
| 485 | + # Reason: the translation validation may become invalid otherwise. |
| 486 | + # |
| 487 | + # If an error is raised before the end of this function, we remove the FX node |
| 488 | + # inserted, and re-raise the error. |
| 489 | + guard = None |
| 490 | + |
| 491 | + try: |
| 492 | + if orig_expr.is_number: |
| 493 | + self.log.debug("eval %s [trivial]", orig_expr) |
| 494 | + if hint is not None: |
| 495 | + if isinstance(hint, bool): |
| 496 | + assert orig_expr == hint, f"{orig_expr} != {hint}" |
| 497 | + else: |
| 498 | + assert sympy.Eq(orig_expr, hint), f"{orig_expr} != {hint}" |
| 499 | + return orig_expr |
| 500 | + |
| 501 | + expr = orig_expr |
| 502 | + |
| 503 | + static_expr = self._maybe_evaluate_static(expr, size_oblivious=size_oblivious) |
| 504 | + if static_expr is not None: |
| 505 | + self.log.debug( |
| 506 | + "eval %s == %s [statically known]", |
| 507 | + (f"size_oblivious({orig_expr})" if size_oblivious else size_oblivious), |
| 508 | + static_expr, |
| 509 | + ) |
| 510 | + if not size_oblivious and config.backed_size_oblivious and hint is not None: |
| 511 | + # TODO: maybe reconcile this with use of counterfactual hints |
| 512 | + # in unbacked case |
| 513 | + assert static_expr == hint, f"{static_expr} != {hint}" |
| 514 | + return static_expr |
| 515 | + |
| 516 | + transmute_into_runtime_assert = False |
| 517 | + |
| 518 | + concrete_val = None |
| 519 | + if not (expr.free_symbols <= self.var_to_val.keys()): |
| 520 | + # TODO: dedupe this with _maybe_evaluate_static |
| 521 | + # Attempt to eliminate the unbacked SymInt |
| 522 | + new_expr = self._maybe_evaluate_static(expr, unbacked_only=True) |
| 523 | + assert new_expr is not None |
| 524 | + if not (new_expr.free_symbols <= self.var_to_val.keys()): |
| 525 | + ok = False |
| 526 | + |
| 527 | + # fallback_value is set when guard_or_true or guard_or_false is used. |
| 528 | + if not ok and fallback_value is not None: |
| 529 | + self._log_suppressed_dde(orig_expr, fallback_value) |
| 530 | + return fallback_value |
| 531 | + |
| 532 | + # oblivious_var_to_val will be defined iff we have sizes |
| 533 | + # with DimDynamic.OBLIVIOUS_SIZE type. |
| 534 | + # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113 |
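| | + # Worked example (hypothetical u0 with real hint 1): for Ne(u0, 1), |
| | + # correct_hint is False (u0 -> 1) but counterfactual_hint is True |
| | + # (u0 -> max(2, 1) = 2), so they disagree and we refuse the hint; |
| | + # for u0 >= 0 both are True and the counterfactual check passes. |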
| 535 | + if ( |
| 536 | + self.oblivious_var_to_val |
| 537 | + and not ( |
| 538 | + correct_hint := orig_expr.xreplace(self.oblivious_var_to_val) |
| 539 | + ).free_symbols |
| 540 | + and not ( |
| 541 | + counterfactual_hint := orig_expr.xreplace( |
| 542 | + {k: max(2, v) for k, v in self.oblivious_var_to_val.items()} |
| 543 | + ) |
| 544 | + ).free_symbols |
| 545 | + and correct_hint == counterfactual_hint |
| 546 | + ): |
| 547 | + # TODO: better logging |
| 548 | + log.info( |
| 549 | + "oblivious_size %s -> %s (passed counterfactual)", |
| 550 | + orig_expr, |
| 551 | + # pyrefly: ignore # unbound-name |
| 552 | + correct_hint, |
| 553 | + ) |
| 554 | + # pyrefly: ignore # unbound-name |
| 555 | + concrete_val = correct_hint |
| 556 | + # NB: do NOT transmute into runtime assert |
| 557 | + ok = True |
| 558 | + |
| 559 | + # unbacked_var_to_val is not None iff propagate_real_tensors is on. |
| 560 | + # If propagate_real_tensors is on, we check the example values to |
| 561 | + # generate unsound_result, and if that yields a concrete value we |
| 562 | + # add a runtime assertion and continue. |
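| | + # Worked example (hypothetical): if unbacked u0 has example value 5, |
| | + # Eq(Mod(u0, 2), 1).xreplace({u0: 5}) reduces to True with no free |
| | + # symbols, so we assert it at runtime instead of failing here. |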
| 563 | + if ( |
| 564 | + not ok |
| 565 | + and self.unbacked_var_to_val |
| 566 | + and not ( |
| 567 | + unsound_result := orig_expr.xreplace( |
| 568 | + self.unbacked_var_to_val |
| 569 | + ).xreplace(self.var_to_val) |
| 570 | + ).free_symbols |
| 571 | + ): |
| 572 | + # pyrefly: ignore # unbound-name |
| 573 | + self._log_real_tensor_propagation(orig_expr, unsound_result) |
| 574 | + transmute_into_runtime_assert = True |
| 575 | + # pyrefly: ignore # unbound-name |
| 576 | + concrete_val = unsound_result |
| 577 | + ok = True |
| 578 | + |
| 579 | + # Check if this is coming from a Python assert statement; |
| 580 | + # if so, convert it to a runtime assertion |
| 581 | + # instead of failing. |
| 582 | + if not ok and self.trace_asserts and self._is_python_assert(): |
| 583 | + concrete_val = sympy.true |
| 584 | + transmute_into_runtime_assert = True |
| 585 | + ok = True |
| 586 | + |
| 587 | + # PATCHED: ok -> True |
| 588 | + ok = True |
| 589 | + # if not ok: |
| 590 | + # raise self._make_data_dependent_error( |
| 591 | + # expr.xreplace(self.var_to_val), |
| 592 | + # expr, |
| 593 | + # expr_sym_node_id=self._expr_sym_node_id, |
| 594 | + # ) |
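| | + # With the PATCH above, the data-dependent error is suppressed and |
| | + # evaluation falls through; compute_concrete_val() below may still |
| | + # need a usable hint for the remaining unbacked symbols. |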
| 595 | + else: |
| 596 | + expr = new_expr |
| 597 | + |
| 598 | + if concrete_val is None: |
| 599 | + concrete_val = compute_concrete_val() |
| 600 | + self._check_frozen(expr, concrete_val) |
| 601 | + |
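| | + # Testing-only fault injection: deliberately negate Eq/Ne guards so |
| | + # guard-failure paths can be exercised. |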
| 602 | + if ( |
| 603 | + config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY |
| 604 | + and isinstance(hint, bool) |
| 605 | + and isinstance(expr, (sympy.Eq, sympy.Ne)) |
| 606 | + ): |
| 607 | + expr = sympy.Not(expr) |
| 608 | + |
| 609 | + # Turn this into a boolean expression; after this we no longer |
| 610 | + # need to consult concrete_val |
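| | + # E.g. a non-boolean expr s0 with concrete_val 8 becomes the guard |
| | + # Eq(s0, 8); a boolean expr with concrete_val False becomes Not(expr). |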
| 611 | + if concrete_val is sympy.true: |
| 612 | + g = cast(SympyBoolean, expr) |
| 613 | + elif concrete_val is sympy.false: |
| 614 | + g = sympy.Not(expr) |
| 615 | + else: |
| 616 | + g = sympy.Eq(expr, concrete_val) # type: ignore[arg-type] |
| 617 | + |
| 618 | + if transmute_into_runtime_assert: |
| 619 | + self.guard_or_defer_runtime_assert( |
| 620 | + g, f"propagate_real_tensors: {orig_expr} == {concrete_val}" |
| 621 | + ) |
| 622 | + return concrete_val |
| 623 | + |
| 624 | + if not self._suppress_guards_tls(): |
| 625 | + self._log_guard("eval", g, forcing_spec=forcing_spec) |
| 626 | + |
| 627 | + # TODO: If we successfully eliminate a symbol via equality, it |
| 628 | + # is not actually necessary to save a guard for the equality, |
| 629 | + # as we will implicitly generate a guard when we match that |
| 630 | + # input against the symbol. Probably the easiest way to |
| 631 | + # implement this is to have maybe_guard_rel return a bool |
| 632 | + # saying if it "subsumed" the guard (and therefore the guard |
| 633 | + # is no longer necessary) |
| 634 | + self._maybe_guard_rel(g) |
| 635 | + |
| 636 | + if ( |
| 637 | + torch.compiler.is_exporting() |
| 638 | + and self.prefer_deferred_runtime_asserts_over_guards |
| 639 | + ): |
| 640 | + # It's fine to defer simple guards here without checking: |
| 641 | + # the _maybe_guard_rel() call above will set replacements if |
| 642 | + # possible, and so the result here will be statically known |
| 643 | + self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}") |
| 644 | + else: |
| 645 | + # at this point, we've evaluated the concrete expr value, and have |
| 646 | + # flipped/negated the guard if necessary. Now we know what to guard |
| 647 | + # or defer to runtime assert on. |
| 648 | + guard = ShapeGuard(g, self._get_sloc(), size_oblivious=size_oblivious) |
| 649 | + self.guards.append(guard) |
| 650 | + self.axioms.update(dict(self.get_implications(self.simplify(g)))) |
| 651 | + else: |
| 652 | + self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec) |
| 653 | + |
| 654 | + except Exception: |
| 655 | + if fresh: |
| 656 | + self._remove_fx_node(node) |
| 657 | + raise |
| 658 | + |
| 659 | + if not self._suppress_guards_tls(): |
| 660 | + if guard is not None: # we might have deferred this to runtime assert |
| 661 | + for s in g.free_symbols: |
| 662 | + self.symbol_guard_counter[s] += 1 |
| 663 | + # Pass forcing_spec to avoid infinite recursion |
| 664 | + if ( |
| 665 | + not forcing_spec |
| 666 | + and config.symbol_guard_limit_before_specialize is not None |
| 667 | + and self.symbol_guard_counter[s] |
| 668 | + > config.symbol_guard_limit_before_specialize |
| 669 | + ): |
| 670 | + # Force specialization |
| 671 | + self.log.info( |
| 672 | + "symbol_guard_limit_before_specialize=%s exceeded on %s", |
| 673 | + config.symbol_guard_limit_before_specialize, |
| 674 | + s, |
| 675 | + ) |
| 676 | + self.evaluate_expr(s, forcing_spec=True) |
| 677 | + |
| 678 | + return concrete_val |
| 679 | + |
400 | 680 |
|
401 | 681 | def patched_vmap(func, in_dims=0, out_dims=0): |
402 | 682 | """ |
|
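Below is a minimal sketch of how a patched `_evaluate_expr` like the one in this diff could be installed; the `apply_patch`/`restore` helpers are hypothetical names for illustration and are not part of the diff, and `ShapeEnv._evaluate_expr` is a private PyTorch API that may change across versions.

```python
# Minimal sketch (assumption: `_evaluate_expr` from the diff above is in scope).
from torch.fx.experimental.symbolic_shapes import ShapeEnv

# Keep a handle to the stock implementation so the patch is reversible.
_original_evaluate_expr = ShapeEnv._evaluate_expr

def apply_patch():
    # Replace the method on the class: every ShapeEnv instance now uses the
    # patched version, whose `ok = True` change suppresses data-dependent errors.
    ShapeEnv._evaluate_expr = _evaluate_expr

def restore():
    # Put the original private method back.
    ShapeEnv._evaluate_expr = _original_evaluate_expr
```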