
Commit a414956

Patch to disable one exception in torch
1 parent fa664f1 commit a414956

2 files changed (+288, -4 lines)

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 7 additions & 3 deletions
@@ -340,6 +340,7 @@ def torch_export_patches(
     ###############
 
     if patch_torch:
+        from torch.fx.experimental.symbolic_shapes import ShapeEnv
         from .patches.patch_torch import (
            patched_infer_size,
            patched_vmap,
@@ -349,6 +350,7 @@ def torch_export_patches(
            patch__check_input_constraints_for_graph,
            patched__broadcast_in_dim_meta,
            patched__maybe_broadcast,
+            patched_ShapeEnv,
         )
 
         if verbose:
@@ -395,6 +397,10 @@ def torch_export_patches(
         f__maybe_broadcast = torch._refs._maybe_broadcast
         torch._refs._maybe_broadcast = patched__maybe_broadcast
 
+        # ShapeEnv
+        f_shape_env__evaluate_expr = ShapeEnv._evaluate_expr
+        ShapeEnv._evaluate_expr = patched_ShapeEnv._evaluate_expr
+
     # torch._export.non_strict_utils.produce_guards_and_solve_constraints
     if patch_torch and catch_constraints:
         if verbose:
@@ -417,9 +423,6 @@ def torch_export_patches(
         )
 
     if stop_if_static:
-        from torch.fx.experimental.symbolic_shapes import ShapeEnv
-        from .patches.patch_torch import patched_ShapeEnv
-
         ShapeEnv._log_guard_remember = ShapeEnv._log_guard
 
         if verbose:
@@ -599,6 +602,7 @@ def torch_export_patches(
         torch._prims._broadcast_in_dim_meta = f__broadcast_in_dim_meta
         torch._prims.broadcast_in_dim = f_broadcast_in_dim
         torch._refs._maybe_broadcast = f__maybe_broadcast
+        ShapeEnv._evaluate_expr = f_shape_env__evaluate_expr
 
         if verbose:
             print("[torch_export_patches] restored pytorch functions")

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 281 additions & 1 deletion
@@ -1,8 +1,10 @@
+import functools
 import inspect
+import operator
 import os
 import traceback
 from functools import reduce
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
 import torch
 from torch._subclasses.fake_tensor import FakeTensorMode
 
@@ -397,6 +399,284 @@ def _log_guard(
         #     stacklevel=0,
         # )
 
+    def _evaluate_expr(
+        self,
+        orig_expr: "sympy.Basic",  # noqa: F821
+        hint: Optional[Union[bool, int, float]] = None,
+        fx_node: Optional[torch.fx.Node] = None,
+        size_oblivious: bool = False,
+        fallback_value: Optional[bool] = None,
+        *,
+        forcing_spec: bool = False,
+    ) -> "sympy.Basic":  # noqa: F821
+        # TODO: split conjunctions and evaluate them separately
+        import sympy
+        from torch.fx.experimental import _config as config
+        from torch.fx.experimental.symbolic_shapes import (
+            SympyBoolean,
+            log,
+            SymT,
+            symbol_is_type,
+        )
+        from torch._guards import ShapeGuard
+
+        if isinstance(
+            orig_expr,
+            (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse),
+        ):
+            return orig_expr
+
+        # Don't track this one. (Because this cache is inside this function the
+        # cache only lasts for the invocation of this function call)
+        @functools.cache
+        def compute_concrete_val() -> sympy.Basic:
+            if hint is None:
+                # This is only ever called for expressions WITHOUT unbacked
+                # symbols
+                r = self.size_hint(orig_expr)
+                assert r is not None
+                return r
+            else:
+                return sympy.sympify(hint)
+
+        concrete_val: Optional[sympy.Basic]
+
+        # Check if:
+        # 1. 'translation_validation' is set
+        # 2. the corresponding 'fx_node' is not 'None'
+        # 3. the guard should not be suppressed
+        # 4. the guard doesn't contain backed symfloat symbols
+        #    since z3 can't handle floats
+        # 5. fallback_value is none.
+        # If all of the above check, we create an FX node representing the
+        # actual expression to be guarded.
+        node = None
+        fresh = False
+        if (
+            self._translation_validation_enabled
+            and fx_node is not None
+            and not self._suppress_guards_tls()
+            and not size_oblivious
+            and not any(symbol_is_type(s, SymT.FLOAT) for s in orig_expr.free_symbols)
+            and fallback_value is None
+        ):
+            # TODO: does this even worked with unbacked :think:
+            concrete_val = compute_concrete_val()
+            if concrete_val is sympy.true:
+                node, fresh = self._create_fx_call_function(torch._assert, (fx_node,))
+            elif concrete_val is sympy.false:
+                neg, _ = self._create_fx_call_function(operator.not_, (fx_node,))
+                node, fresh = self._create_fx_call_function(torch._assert, (neg,))
+            else:
+                eql, _ = self._create_fx_call_function(operator.eq, (fx_node, concrete_val))
+                node, fresh = self._create_fx_call_function(torch._assert, (eql,))
+
+            assert node is not None
+            # If this is a fresh node, we have to remember the event index that
+            # corresponds to this assertion node.
+            # Reason: so that, given an assertion node, we can replay the ShapeEnv
+            # events until the point where this assertion node was freshly created.
+            if fresh:
+                self._add_fx_node_metadata(node)
+
+        # After creating the FX node corresponding to orig_expr, we must make sure that
+        # no error will be raised until the end of this function.
+        #
+        # Reason: the translation validation may become invalid otherwise.
+        #
+        # If an error is raised before the end of this function, we remove the FX node
+        # inserted, and re-raise the error.
+        guard = None
+
+        try:
+            if orig_expr.is_number:
+                self.log.debug("eval %s [trivial]", orig_expr)
+                if hint is not None:
+                    if isinstance(hint, bool):
+                        assert orig_expr == hint, f"{orig_expr} != {hint}"
+                    else:
+                        assert sympy.Eq(orig_expr, hint), f"{orig_expr} != {hint}"
+                return orig_expr
+
+            expr = orig_expr
+
+            static_expr = self._maybe_evaluate_static(expr, size_oblivious=size_oblivious)
+            if static_expr is not None:
+                self.log.debug(
+                    "eval %s == %s [statically known]",
+                    (f"size_oblivious({orig_expr})" if size_oblivious else size_oblivious),
+                    static_expr,
+                )
+                if not size_oblivious and config.backed_size_oblivious and hint is not None:
+                    # TODO: maybe reconcile this with use of counterfactual hints
+                    # in unbacked case
+                    assert static_expr == hint, f"{static_expr} != {hint}"
+                return static_expr
+
+            transmute_into_runtime_assert = False
+
+            concrete_val = None
+            if not (expr.free_symbols <= self.var_to_val.keys()):
+                # TODO: dedupe this with _maybe_evaluate_static
+                # Attempt to eliminate the unbacked SymInt
+                new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
+                assert new_expr is not None
+                if not (new_expr.free_symbols <= self.var_to_val.keys()):
+                    ok = False
+
+                    # fallback_value is set when guard_or_true or guard_or_false are used.
+                    if not ok and fallback_value is not None:
+                        self._log_suppressed_dde(orig_expr, fallback_value)
+                        return fallback_value
+
+                    # oblivious_var_to_val will be defined iff we have sizes
+                    # with DimDynamic.OBLIVIOUS_SIZE type.
+                    # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
+                    if (
+                        self.oblivious_var_to_val
+                        and not (
+                            correct_hint := orig_expr.xreplace(self.oblivious_var_to_val)
+                        ).free_symbols
+                        and not (
+                            counterfactual_hint := orig_expr.xreplace(
+                                {k: max(2, v) for k, v in self.oblivious_var_to_val.items()}
+                            )
+                        ).free_symbols
+                        and correct_hint == counterfactual_hint
+                    ):
+                        # TODO: better logging
+                        log.info(
+                            "oblivious_size %s -> %s (passed counterfactual)",
+                            orig_expr,
+                            # pyrefly: ignore # unbound-name
+                            correct_hint,
+                        )
+                        # pyrefly: ignore # unbound-name
+                        concrete_val = correct_hint
+                        # NB: do NOT transmute into runtime assert
+                        ok = True
+
+                    # unbacked_var_to_val is not None iff propagate_real_tensors is on.
+                    # if propagate_real_tensors is on, we check the example values
+                    # to generate (unsound_result)
+                    # and if they pass we add a runtime assertions and continue.
+                    if (
+                        not ok
+                        and self.unbacked_var_to_val
+                        and not (
+                            unsound_result := orig_expr.xreplace(
+                                self.unbacked_var_to_val
+                            ).xreplace(self.var_to_val)
+                        ).free_symbols
+                    ):
+                        # pyrefly: ignore # unbound-name
+                        self._log_real_tensor_propagation(orig_expr, unsound_result)
+                        transmute_into_runtime_assert = True
+                        # pyrefly: ignore # unbound-name
+                        concrete_val = unsound_result
+                        ok = True
+
+                    # Check if this is coming from a python assert statement,
+                    # if so, convert it to a runtime assertion
+                    # instead of failing.
+                    if not ok and self.trace_asserts and self._is_python_assert():
+                        concrete_val = sympy.true
+                        transmute_into_runtime_assert = True
+                        ok = True
+
+                    # PATCHED: ok -> True
+                    ok = True
+                    # if not ok:
+                    #     raise self._make_data_dependent_error(
+                    #         expr.xreplace(self.var_to_val),
+                    #         expr,
+                    #         expr_sym_node_id=self._expr_sym_node_id,
+                    #     )
+                else:
+                    expr = new_expr
+
+            if concrete_val is None:
+                concrete_val = compute_concrete_val()
+            self._check_frozen(expr, concrete_val)
+
+            if (
+                config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY
+                and isinstance(hint, bool)
+                and isinstance(expr, (sympy.Eq, sympy.Ne))
+            ):
+                expr = sympy.Not(expr)
+
+            # Turn this into a boolean expression, no longer need to consult
+            # concrete_val
+            if concrete_val is sympy.true:
+                g = cast(SympyBoolean, expr)
+            elif concrete_val is sympy.false:
+                g = sympy.Not(expr)
+            else:
+                g = sympy.Eq(expr, concrete_val)  # type: ignore[arg-type]
+
+            if transmute_into_runtime_assert:
+                self.guard_or_defer_runtime_assert(
+                    g, f"propagate_real_tensors: {orig_expr} == {concrete_val}"
+                )
+                return concrete_val
+
+            if not self._suppress_guards_tls():
+                self._log_guard("eval", g, forcing_spec=forcing_spec)
+
+                # TODO: If we successfully eliminate a symbol via equality, it
+                # is not actually necessary to save a guard for the equality,
+                # as we will implicitly generate a guard when we match that
+                # input against the symbol. Probably the easiest way to
+                # implement this is to have maybe_guard_rel return a bool
+                # saying if it "subsumed" the guard (and therefore the guard
+                # is no longer necessary)
+                self._maybe_guard_rel(g)
+
+                if (
+                    torch.compiler.is_exporting()
+                    and self.prefer_deferred_runtime_asserts_over_guards
+                ):
+                    # it's fine to defer simple guards here without checking,
+                    # the _maybe_guard_rel() call above will set replacements if possible,
+                    # and so the result here will be statically known
+                    self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}")
+                else:
+                    # at this point, we've evaluated the concrete expr value, and have
+                    # flipped/negated the guard if necessary. Now we know what to guard
+                    # or defer to runtime assert on.
+                    guard = ShapeGuard(g, self._get_sloc(), size_oblivious=size_oblivious)
+                    self.guards.append(guard)
+                    self.axioms.update(dict(self.get_implications(self.simplify(g))))
+            else:
+                self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec)
+
+        except Exception:
+            if fresh:
+                self._remove_fx_node(node)
+            raise
+
+        if not self._suppress_guards_tls():
+            if guard is not None:  # we might have deferred this to runtime assert
+                for s in g.free_symbols:
+                    self.symbol_guard_counter[s] += 1
+                    # Forcing_spec to avoid infinite recursion
+                    if (
+                        not forcing_spec
+                        and config.symbol_guard_limit_before_specialize is not None
+                        and self.symbol_guard_counter[s]
+                        > config.symbol_guard_limit_before_specialize
+                    ):
+                        # Force specialization
+                        self.log.info(
+                            "symbol_guard_limit_before_specialize=%s exceeded on %s",
+                            config.symbol_guard_limit_before_specialize,
+                            s,
+                        )
+                        self.evaluate_expr(s, forcing_spec=True)
+
+        return concrete_val
+
 
 def patched_vmap(func, in_dims=0, out_dims=0):
     """

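The only behavioral change in this copied method is the data-dependent branch: where upstream would raise via self._make_data_dependent_error, the patched version forces ok = True (the "# PATCHED: ok -> True" line) so evaluation falls through to compute_concrete_val and export continues. A hedged usage sketch, assuming torch_export_patches can be used as a context manager with patch_torch=True, as the save/restore code in onnx_export_errors.py suggests; the model and input shapes below are placeholders only:

import torch

from onnx_diagnostic.torch_export_patches.onnx_export_errors import torch_export_patches


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2 + 1


# Enable the torch patches (including the ShapeEnv._evaluate_expr swap above)
# around torch.export.export; exports that would otherwise stop on the
# disabled data-dependent exception can then proceed.
with torch_export_patches(patch_torch=True, verbose=1):
    ep = torch.export.export(TinyModel(), (torch.randn(4, 4),))
print(ep)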