1717import abc
1818import atexit
1919import collections
20- import contextlib
2120import dis
2221import functools
2322import hashlib
@@ -1218,17 +1217,6 @@ def compute_unbacked_bindings(
12181217 return symbol_to_path
12191218
12201219
1221- def _log_suppressed_dde (a : SymBool , assumed_value : bool ) -> None :
1222- sloc , extra = a .node .shape_env ._get_stack_summary (True )
1223- log .info (
1224- "could not evaluate %s due to data dependency, it was assumed to be %s with no runtime assertions %s %s" ,
1225- a ,
1226- assumed_value ,
1227- sloc ,
1228- extra ,
1229- )
1230-
1231-
12321220# The following two functions are common utilities used while defining unbacked semantics
12331221# of various framework code. Those would be used in situations you prefer to guard and know
12341222# the result of the expression over not guarding, but in case you hit a data dependent error
@@ -1265,12 +1253,11 @@ def _guard_or(a: BoolLikeType, default: bool) -> bool:
12651253 if shape_env is None :
12661254 return guard_bool (a )
12671255
1268- with a .node .shape_env .dde_suppressed ():
1269- try :
1270- return guard_bool (a )
1271- except GuardOnDataDependentSymNode :
1272- _log_suppressed_dde (a , default )
1273- return default
1256+ sym_node = a .node
1257+ r = sym_node .shape_env .evaluate_sym_node (
1258+ sym_node , size_oblivious = False , fallback_value = default
1259+ )
1260+ return bool (r )
12741261
12751262
12761263def guard_or_false (a : BoolLikeType ) -> bool :
@@ -3314,10 +3301,6 @@ def __init__(
33143301 else []
33153302 )
33163303
3317- # Set true when data dependent errors are handled by caller side and not thrown. Ex: guard_or_false
3318- # and guard_or_true. When its true, a different error message is produced.
3319- self ._dde_suppressed = False
3320-
33213304 # FakeTensor per-ShapeEnv operation cache. This is used for caching
33223305 # operations that contain symbolic shapes which have guards on the
33233306 # ShapeEnv (so are ShapeEnv-dependent).
@@ -3330,18 +3313,6 @@ def __init__(
33303313 torch ._subclasses .fake_tensor ._DispatchCacheEntry ,
33313314 ] = {}
33323315
3333- @contextlib .contextmanager
3334- def dde_suppressed (self ) -> Iterator [None ]:
3335- """Suppressed GuardOnDataDependent error logs"""
3336-
3337- # We do not expect this to be called recursively.
3338- assert not self ._dde_suppressed , "not expected value for _dde_suppressed"
3339- self ._dde_suppressed = True
3340- try :
3341- yield
3342- finally :
3343- self ._dde_suppressed = False
3344-
33453316 # Pro-tip: if you add new field to ShapeEnv, this affects some accept
33463317 # tests. Accept their output with:
33473318 #
@@ -3643,7 +3614,6 @@ def check_equal(self, other: ShapeEnv) -> None:
36433614 "replacements_slocs" ,
36443615 "_resimplify_floor_div_axioms" ,
36453616 "_expr_sym_node_id" ,
3646- "_dde_suppressed" ,
36473617 "specialization_stacks" ,
36483618 )
36493619
@@ -6152,12 +6122,6 @@ def _make_data_dependent_error(
61526122 size_oblivious_result : Optional [sympy .Basic ] = None ,
61536123 expr_sym_node_id : Optional [int ] = None ,
61546124 ) -> GuardOnDataDependentSymNode :
6155- if self ._dde_suppressed :
6156- return GuardOnDataDependentSymNode (
6157- expr ,
6158- "This data dependent error is suppressed and handled by the caller" ,
6159- )
6160-
61616125 # TODO: in a Dynamo context, having user code, and having the
61626126 # name of the local, will be much better
61636127 size_like_symbols = []
@@ -6846,14 +6810,19 @@ def evaluate_sym_node(
68466810 self ,
68476811 sym_node : SymNode ,
68486812 size_oblivious : bool = False ,
6813+ fallback_value : Optional [bool ] = None ,
68496814 ) -> sympy .Basic :
68506815 """
68516816 Given a a SymNode, evaluates sym_node.expr, adding guards if necessary.
68526817 """
68536818
68546819 self ._expr_sym_node_id = id (sym_node )
68556820 return self .evaluate_expr (
6856- sym_node .expr , sym_node .hint , sym_node .fx_node , size_oblivious
6821+ sym_node .expr ,
6822+ sym_node .hint ,
6823+ sym_node .fx_node ,
6824+ size_oblivious ,
6825+ fallback_value = fallback_value ,
68576826 )
68586827
68596828 def _is_python_assert (self ) -> bool :
@@ -6939,17 +6908,25 @@ def evaluate_expr(
69396908 hint : Optional [Union [int , bool , float ]] = None ,
69406909 fx_node : Optional [torch .fx .Node ] = None ,
69416910 size_oblivious : bool = False ,
6911+ fallback_value : Optional [bool ] = None ,
69426912 * ,
69436913 forcing_spec : bool = False ,
69446914 ) -> sympy .Basic :
69456915 """
69466916 Given an expression, evaluates it, adding guards if necessary
6917+ When fallback_value is not None the function return fallback_value instead of failing with data dependent error.
69476918 """
69486919
69496920 # Add extra state that evaluate_expr() depends on.
69506921 suppress_guards_tls = ShapeEnv ._suppress_guards_tls ()
69516922 return self ._inner_evaluate_expr (
6952- orig_expr , hint , fx_node , size_oblivious , forcing_spec , suppress_guards_tls
6923+ orig_expr ,
6924+ hint ,
6925+ fx_node ,
6926+ size_oblivious ,
6927+ forcing_spec ,
6928+ suppress_guards_tls ,
6929+ fallback_value ,
69536930 )
69546931
69556932 @lru_cache (256 )
@@ -6962,17 +6939,19 @@ def _inner_evaluate_expr(
69626939 size_oblivious : bool ,
69636940 forcing_spec : bool ,
69646941 _suppress_guards_tls : bool ,
6942+ fallback_value : Optional [bool ] = None ,
69656943 ) -> sympy .Basic :
69666944 try :
69676945 return self ._evaluate_expr (
69686946 orig_expr ,
69696947 hint ,
69706948 fx_node ,
69716949 size_oblivious ,
6950+ fallback_value ,
69726951 forcing_spec = forcing_spec ,
69736952 )
69746953 except Exception as e :
6975- if isinstance (e , GuardOnDataDependentSymNode ) and self . _dde_suppressed :
6954+ if isinstance (e , GuardOnDataDependentSymNode ):
69766955 pass
69776956 else :
69786957 self .log .warning (
@@ -6984,12 +6963,23 @@ def _inner_evaluate_expr(
69846963 )
69856964 raise
69866965
6966+ def _log_suppressed_dde (self , a : SymBool , assumed_value : bool ) -> None :
6967+ sloc , extra = self ._get_stack_summary (True )
6968+ log .info (
6969+ "could not evaluate %s due to data dependency, it was assumed to be %s with no runtime assertions %s %s" ,
6970+ a ,
6971+ assumed_value ,
6972+ sloc ,
6973+ extra ,
6974+ )
6975+
69876976 def _evaluate_expr (
69886977 self ,
69896978 orig_expr : sympy .Basic ,
69906979 hint : Optional [Union [bool , int , float ]] = None ,
69916980 fx_node : Optional [torch .fx .Node ] = None ,
69926981 size_oblivious : bool = False ,
6982+ fallback_value : Optional [bool ] = None ,
69936983 * ,
69946984 forcing_spec : bool = False ,
69956985 ) -> sympy .Basic :
@@ -7021,7 +7011,7 @@ def compute_concrete_val() -> sympy.Basic:
70217011 # 3. the guard should not be suppressed
70227012 # 4. the guard doesn't contain backed symfloat symbols
70237013 # since z3 can't handle floats
7024- #
7014+ # 5. fallback_value is none.
70257015 # If all of the above check, we create an FX node representing the
70267016 # actual expression to be guarded.
70277017 node = None
@@ -7032,6 +7022,7 @@ def compute_concrete_val() -> sympy.Basic:
70327022 and not self ._suppress_guards_tls ()
70337023 and not size_oblivious
70347024 and not any (symbol_is_type (s , SymT .FLOAT ) for s in orig_expr .free_symbols )
7025+ and fallback_value is None
70357026 ):
70367027 # TODO: does this even worked with unbacked :think:
70377028 concrete_val = compute_concrete_val ()
@@ -7113,7 +7104,7 @@ def compute_concrete_val() -> sympy.Basic:
71137104 # Those are backed dimentions that are treated as unbacked to avoid specializations, but if
71147105 # we fail to bypass with size oblivious reasoning we compute using the actual hint and guard.
71157106 if (
7116- not self . _dde_suppressed
7107+ fallback_value is None # do not do this under guard_or
71177108 and self .oblivious_var_to_val
71187109 and not (
71197110 correct_hint := orig_expr .xreplace (
@@ -7143,8 +7134,9 @@ def compute_concrete_val() -> sympy.Basic:
71437134 # unbacked_var_to_val is not None iff propagate_real_tensors is on.
71447135 # if propagate_real_tensors is on, we check the example values to generate (unsound_result)
71457136 # and if they pass we add a runtime assertions and continue.
7137+
71467138 if (
7147- not self . _dde_suppressed
7139+ fallback_value is None # do not do this under guard_or
71487140 and not ok
71497141 and self .unbacked_var_to_val
71507142 and not (
@@ -7165,10 +7157,16 @@ def compute_concrete_val() -> sympy.Basic:
71657157 transmute_into_runtime_assert = True
71667158 ok = True
71677159
7160+ # fallback value is set when guard_or_true, gaurd_or_false are used.
7161+ # whe we fail to evaluate soundly, we use the default value set by it.
7162+ if not ok and fallback_value is not None :
7163+ self ._log_suppressed_dde (orig_expr , fallback_value )
7164+ return fallback_value
7165+
71687166 if not ok :
71697167 size_oblivious_result = None
71707168 # compute size_oblivious_result to suggest it as a fix for the user if it works.
7171- if not size_oblivious and not self . _dde_suppressed :
7169+ if not size_oblivious :
71727170 size_oblivious_result = self ._maybe_evaluate_static (
71737171 expr , size_oblivious = True
71747172 )
0 commit comments