
Commit 9e089bb

laithsakka authored and pytorchmergebot committed
change guard_or impl for better perf and simplicity (pytorch#153674)
PR time benchmarks have been showing regressions as we move to guard_or_false; the reason is that the previous implementation did not cache. The new approach propagates the fallback value down to eval and returns it, allowing eval to cache and reducing log spam and complexity. Pull Request resolved: pytorch#153674 Approved by: https://github.com/bobrenjc93
1 parent 4b7abce commit 9e089bb
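
The gist of the change, simplified from the symbolic_shapes.py diff below. guard_bool, evaluate_sym_node, and GuardOnDataDependentSymNode are the real names from that file, but these bodies are abbreviated sketches rather than the verbatim implementation:

# Before: a data-dependent expression raised, the exception was caught,
# and the full evaluation re-ran on every repeated call, since exceptions
# bypass the lru_cache on the evaluation path.
def _guard_or_old(a, default):
    try:
        return guard_bool(a)
    except GuardOnDataDependentSymNode:
        return default

# After: the fallback is threaded into evaluation and comes back as an
# ordinary return value, which the cache can memoize.
def _guard_or_new(a, default):
    sym_node = a.node
    r = sym_node.shape_env.evaluate_sym_node(
        sym_node, size_oblivious=False, fallback_value=default
    )
    return bool(r)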

File tree

3 files changed: +60 -76 lines

benchmarks/dynamo/pr_time_benchmarks/expected_results.csv

Lines changed: 10 additions & 26 deletions
@@ -1,16 +1,16 @@
-add_loop_eager,compile_time_instruction_count,2960000000,0.015
+add_loop_eager,compile_time_instruction_count,2953000000,0.015



-add_loop_eager_dynamic,compile_time_instruction_count,5827000000,0.025
+add_loop_eager_dynamic,compile_time_instruction_count,5808000000,0.025



 add_loop_inductor,compile_time_instruction_count,29370000000,0.015



-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44080000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44010000000,0.025



@@ -22,43 +22,27 @@ basic_modules_ListOfLinears_eager,compile_time_instruction_count,939900000,0.015



-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18240000000,0.015
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18140000000,0.015



-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16340000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16220000000,0.015



 basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10370000000,0.2



-basic_InlineMod_eager,compile_time_instruction_count,7101000000,0.015
+update_hint_regression,compile_time_instruction_count,1681000000,0.02



-update_hint_regression,compile_time_instruction_count,1683000000,0.02
+float_args,compile_time_instruction_count,449800000,0.015



-float_args,compile_time_instruction_count,455100000,0.015
-
-
-
-mm_loop_inductor_gpu,compile_time_instruction_count,4407000000,0.015
-
-
-
-mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,7381000000,0.015
-
-
-
-basic_NestedModule_eager,compile_time_instruction_count,8241000000,0.015
-
-
-
-sum_floordiv_regression,compile_time_instruction_count,1000000000,0.015
+sum_floordiv_regression,compile_time_instruction_count,998600000,0.015



@@ -78,11 +62,11 @@ aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5981000000,0



-aotdispatcher_partitioner_cpu,compile_time_instruction_count,8630000000,0.015
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,8585000000,0.015



-aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1890000000,0.015
+aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1900000000,0.015


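Each row of expected_results.csv is a benchmark name, a metric, the expected value, and a relative noise tolerance (0.015 means 1.5%). A minimal, hypothetical sketch of how such a row can be consumed; the real comparison logic in the pr_time_benchmarks harness may differ:

import csv

def within_tolerance(measured: float, expected: float, rel_tol: float) -> bool:
    # A benchmark counts as regressed (or as needing a new baseline) only
    # when the measured value drifts more than rel_tol from the record.
    return abs(measured - expected) <= rel_tol * expected

with open("expected_results.csv") as f:
    for row in csv.reader(f):
        if len(row) != 4:  # the file uses blank spacer lines
            continue
        name, metric, expected, tol = row
        measured = float(expected)  # stand-in for a fresh measurement
        print(name, within_tolerance(measured, float(expected), float(tol)))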

test/test_dynamic_shapes.py

Lines changed: 4 additions & 2 deletions
@@ -2851,10 +2851,11 @@ def func(a, b):
             else:
                 return b * 20

-        # call with guarding.
+        # eager.
         self.assertEqual(func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10]))
         self.assertEqual(func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20]))

+        # compile with unbacked.
         unbacked_func = torch.compile(func, dynamic=True, fullgraph=True)
         a = torch.tensor([1])
         b = torch.tensor([1])

@@ -2916,10 +2917,11 @@ def func(a, b):
             else:
                 return b * 20

-        # call with guarding.
+        # eager.
         self.assertEqual(func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10]))
         self.assertEqual(func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20]))

+        # compile with unbacked.
         unbacked_func = torch.compile(func, dynamic=True, fullgraph=True)
         a = torch.tensor([1])
         b = torch.tensor([1])
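
Outside of torch.compile, the fallback behavior these tests exercise can be observed directly on a ShapeEnv. A minimal sketch: create_unbacked_symint produces a hint-less symbol of the kind that .item() or nonzero() introduce under compilation:

from torch.fx.experimental.symbolic_shapes import (
    ShapeEnv,
    guard_or_false,
    guard_or_true,
)

shape_env = ShapeEnv()
u = shape_env.create_unbacked_symint()  # unbacked: no hint available

# u == 1 cannot be decided either way, so rather than raising
# GuardOnDataDependentSymNode, each helper returns its fallback and
# installs no guard and no runtime assertion.
print(guard_or_false(u == 1))  # False
print(guard_or_true(u == 1))   # True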

torch/fx/experimental/symbolic_shapes.py

Lines changed: 46 additions & 48 deletions
@@ -17,7 +17,6 @@
 import abc
 import atexit
 import collections
-import contextlib
 import dis
 import functools
 import hashlib

@@ -1218,17 +1217,6 @@ def compute_unbacked_bindings(
     return symbol_to_path


-def _log_suppressed_dde(a: SymBool, assumed_value: bool) -> None:
-    sloc, extra = a.node.shape_env._get_stack_summary(True)
-    log.info(
-        "could not evaluate %s due to data dependency, it was assumed to be %s with no runtime assertions %s %s",
-        a,
-        assumed_value,
-        sloc,
-        extra,
-    )
-
-
 # The following two functions are common utilities used while defining unbacked semantics
 # of various framework code. Those would be used in situations you prefer to guard and know
 # the result of the expression over not guarding, but in case you hit a data dependent error

@@ -1265,12 +1253,11 @@ def _guard_or(a: BoolLikeType, default: bool) -> bool:
     if shape_env is None:
         return guard_bool(a)

-    with a.node.shape_env.dde_suppressed():
-        try:
-            return guard_bool(a)
-        except GuardOnDataDependentSymNode:
-            _log_suppressed_dde(a, default)
-            return default
+    sym_node = a.node
+    r = sym_node.shape_env.evaluate_sym_node(
+        sym_node, size_oblivious=False, fallback_value=default
+    )
+    return bool(r)


 def guard_or_false(a: BoolLikeType) -> bool:

@@ -3314,10 +3301,6 @@ def __init__(
             else []
         )

-        # Set true when data dependent errors are handled by caller side and not thrown. Ex: guard_or_false
-        # and guard_or_true. When its true, a different error message is produced.
-        self._dde_suppressed = False
-
         # FakeTensor per-ShapeEnv operation cache. This is used for caching
         # operations that contain symbolic shapes which have guards on the
         # ShapeEnv (so are ShapeEnv-dependent).

@@ -3330,18 +3313,6 @@ def __init__(
             torch._subclasses.fake_tensor._DispatchCacheEntry,
         ] = {}

-    @contextlib.contextmanager
-    def dde_suppressed(self) -> Iterator[None]:
-        """Suppressed GuardOnDataDependent error logs"""
-
-        # We do not expect this to be called recursively.
-        assert not self._dde_suppressed, "not expected value for _dde_suppressed"
-        self._dde_suppressed = True
-        try:
-            yield
-        finally:
-            self._dde_suppressed = False
-
     # Pro-tip: if you add new field to ShapeEnv, this affects some accept
     # tests. Accept their output with:
     #

@@ -3643,7 +3614,6 @@ def check_equal(self, other: ShapeEnv) -> None:
         "replacements_slocs",
         "_resimplify_floor_div_axioms",
         "_expr_sym_node_id",
-        "_dde_suppressed",
         "specialization_stacks",
     )

@@ -6152,12 +6122,6 @@ def _make_data_dependent_error(
         size_oblivious_result: Optional[sympy.Basic] = None,
         expr_sym_node_id: Optional[int] = None,
     ) -> GuardOnDataDependentSymNode:
-        if self._dde_suppressed:
-            return GuardOnDataDependentSymNode(
-                expr,
-                "This data dependent error is suppressed and handled by the caller",
-            )
-
         # TODO: in a Dynamo context, having user code, and having the
         # name of the local, will be much better
         size_like_symbols = []

@@ -6846,14 +6810,19 @@ def evaluate_sym_node(
         self,
         sym_node: SymNode,
         size_oblivious: bool = False,
+        fallback_value: Optional[bool] = None,
     ) -> sympy.Basic:
         """
         Given a a SymNode, evaluates sym_node.expr, adding guards if necessary.
         """

         self._expr_sym_node_id = id(sym_node)
         return self.evaluate_expr(
-            sym_node.expr, sym_node.hint, sym_node.fx_node, size_oblivious
+            sym_node.expr,
+            sym_node.hint,
+            sym_node.fx_node,
+            size_oblivious,
+            fallback_value=fallback_value,
         )

     def _is_python_assert(self) -> bool:

@@ -6939,17 +6908,25 @@ def evaluate_expr(
         hint: Optional[Union[int, bool, float]] = None,
         fx_node: Optional[torch.fx.Node] = None,
         size_oblivious: bool = False,
+        fallback_value: Optional[bool] = None,
         *,
         forcing_spec: bool = False,
     ) -> sympy.Basic:
         """
         Given an expression, evaluates it, adding guards if necessary
+        When fallback_value is not None the function return fallback_value instead of failing with data dependent error.
         """

         # Add extra state that evaluate_expr() depends on.
         suppress_guards_tls = ShapeEnv._suppress_guards_tls()
         return self._inner_evaluate_expr(
-            orig_expr, hint, fx_node, size_oblivious, forcing_spec, suppress_guards_tls
+            orig_expr,
+            hint,
+            fx_node,
+            size_oblivious,
+            forcing_spec,
+            suppress_guards_tls,
+            fallback_value,
         )

     @lru_cache(256)

@@ -6962,17 +6939,19 @@ def _inner_evaluate_expr(
         size_oblivious: bool,
         forcing_spec: bool,
         _suppress_guards_tls: bool,
+        fallback_value: Optional[bool] = None,
     ) -> sympy.Basic:
         try:
             return self._evaluate_expr(
                 orig_expr,
                 hint,
                 fx_node,
                 size_oblivious,
+                fallback_value,
                 forcing_spec=forcing_spec,
             )
         except Exception as e:
-            if isinstance(e, GuardOnDataDependentSymNode) and self._dde_suppressed:
+            if isinstance(e, GuardOnDataDependentSymNode):
                 pass
             else:
                 self.log.warning(

@@ -6984,12 +6963,23 @@
                 )
             raise

+    def _log_suppressed_dde(self, a: SymBool, assumed_value: bool) -> None:
+        sloc, extra = self._get_stack_summary(True)
+        log.info(
+            "could not evaluate %s due to data dependency, it was assumed to be %s with no runtime assertions %s %s",
+            a,
+            assumed_value,
+            sloc,
+            extra,
+        )
+
     def _evaluate_expr(
         self,
         orig_expr: sympy.Basic,
         hint: Optional[Union[bool, int, float]] = None,
         fx_node: Optional[torch.fx.Node] = None,
         size_oblivious: bool = False,
+        fallback_value: Optional[bool] = None,
         *,
         forcing_spec: bool = False,
     ) -> sympy.Basic:

@@ -7021,7 +7011,7 @@ def compute_concrete_val() -> sympy.Basic:
         # 3. the guard should not be suppressed
         # 4. the guard doesn't contain backed symfloat symbols
         # since z3 can't handle floats
-        #
+        # 5. fallback_value is none.
         # If all of the above check, we create an FX node representing the
         # actual expression to be guarded.
         node = None

@@ -7032,6 +7022,7 @@ def compute_concrete_val() -> sympy.Basic:
             and not self._suppress_guards_tls()
             and not size_oblivious
             and not any(symbol_is_type(s, SymT.FLOAT) for s in orig_expr.free_symbols)
+            and fallback_value is None
         ):
             # TODO: does this even worked with unbacked :think:
             concrete_val = compute_concrete_val()

@@ -7113,7 +7104,7 @@ def compute_concrete_val() -> sympy.Basic:
         # Those are backed dimentions that are treated as unbacked to avoid specializations, but if
         # we fail to bypass with size oblivious reasoning we compute using the actual hint and guard.
         if (
-            not self._dde_suppressed
+            fallback_value is None  # do not do this under guard_or
             and self.oblivious_var_to_val
             and not (
                 correct_hint := orig_expr.xreplace(

@@ -7143,8 +7134,9 @@ def compute_concrete_val() -> sympy.Basic:
         # unbacked_var_to_val is not None iff propagate_real_tensors is on.
         # if propagate_real_tensors is on, we check the example values to generate (unsound_result)
         # and if they pass we add a runtime assertions and continue.
+
         if (
-            not self._dde_suppressed
+            fallback_value is None  # do not do this under guard_or
             and not ok
             and self.unbacked_var_to_val
             and not (

@@ -7165,10 +7157,16 @@ def compute_concrete_val() -> sympy.Basic:
             transmute_into_runtime_assert = True
             ok = True

+        # fallback value is set when guard_or_true, gaurd_or_false are used.
+        # whe we fail to evaluate soundly, we use the default value set by it.
+        if not ok and fallback_value is not None:
+            self._log_suppressed_dde(orig_expr, fallback_value)
+            return fallback_value
+
         if not ok:
             size_oblivious_result = None
             # compute size_oblivious_result to suggest it as a fix for the user if it works.
-            if not size_oblivious and not self._dde_suppressed:
+            if not size_oblivious:
                 size_oblivious_result = self._maybe_evaluate_static(
                     expr, size_oblivious=True
                 )
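
The performance win comes from the @lru_cache(256) already wrapping _inner_evaluate_expr: functools.lru_cache memoizes return values but never exceptions, so the old raise-and-catch path re-ran the whole evaluation on every repeated guard_or_* call, while a returned fallback_value is cacheable. A toy demonstration of that difference (not PyTorch code):

import functools

calls = 0

@functools.lru_cache(maxsize=256)
def eval_with_raise(expr: str) -> bool:
    global calls
    calls += 1
    raise ValueError("data dependent")  # raised exceptions are never cached

@functools.lru_cache(maxsize=256)
def eval_with_fallback(expr: str, fallback: bool) -> bool:
    global calls
    calls += 1
    return fallback  # return values are memoized

for _ in range(3):
    try:
        eval_with_raise("u0 == 1")
    except ValueError:
        pass
    eval_with_fallback("u0 == 1", False)

print(calls)  # 4: the raising variant ran three times, the caching one once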
